Coverage for middle_layer/common/application_layer/orm_repositories/__init__.py: 96.30%
27 statements
« prev ^ index » next coverage.py v7.10.5, created at 2026-04-13 06:13 +0000
1# Copyright 2024 Associated Universities, Inc.
2#
3# This file is part of Telescope Time Allocation Tools (TTAT).
4#
5# TTAT is free software: you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation, either version 3 of the License, or
8# any later version.
9#
10# TTAT is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with TTAT. If not, see <https://www.gnu.org/licenses/>.
17#
18from typing import Type, TypeVar
20from sqlalchemy.exc import IntegrityError, NoResultFound
21from sqlalchemy.orm import Session
def add_entity(session: Session, entity: object, integrity_error_msgs: dict[type, str] | None = None) -> None:
    """
    Add an entity to the session and flush it, rolling back if the database rejects it.

    Used by the ORM SubRepositories so that integrity-constraint failures surface as a ``ValueError``.
    The entity is flushed, not committed; committing remains the caller's responsibility.

    :param session: The ORM Session
    :param entity: Entity to add to the database
    :param integrity_error_msgs: Optional mapping from the exact class of the underlying DBAPI exception
        (``IntegrityError.orig``) to the message to raise when that class occurs; classes not present
        in the mapping get a generic "Failed to add ..." message
    :raises ValueError: When SQLAlchemy raises an IntegrityError during the flush; the original
        IntegrityError is attached as ``__cause__``
    """
    try:
        session.add(entity)
        session.flush()
    except IntegrityError as e:
        # Roll back so the session is usable again after the failed flush. Session.rollback() discards
        # the whole current transaction, not just ``entity``.
        session.rollback()
        driver_error = e.orig  # the DBAPI-level exception wrapped by SQLAlchemy
        custom_msg = integrity_error_msgs.get(type(driver_error)) if integrity_error_msgs else None
        if custom_msg is not None:
            raise ValueError(custom_msg) from e
        # psycopg2 exposes the server's primary error text via ``diag.message_primary``. Other DBAPI
        # drivers (e.g. sqlite3) have no ``diag``; fall back to the driver exception's own text instead
        # of letting an AttributeError mask the integrity failure.
        diag = getattr(driver_error, "diag", None)
        detail = getattr(diag, "message_primary", None) or str(driver_error if driver_error is not None else e)
        raise ValueError(f"Failed to add {entity.__class__.__name__} to the database: {detail}") from e
# Type variable for the ORM-mapped entity class accepted and returned by the generic helpers in this module.
T = TypeVar("T")
def list_entities(session: Session, entity_type: Type[T], *order_by_attributes) -> list[T]:
    """
    Fetch every stored entity of the given ORM-mapped class.

    Shared by the ORM SubRepositories for their "list all" operations.

    :param session: The ORM Session
    :param entity_type: ORM-mapped class whose stored instances should be listed
    :param order_by_attributes: Optional attributes/expressions forwarded to ``order_by``; when none are
        given, no ordering is requested and the order is whatever the database returns
    :return: List of ``entity_type`` instances (empty if none are stored)
    """
    base_query = session.query(entity_type)
    # Only apply ORDER BY when the caller actually asked for an ordering.
    final_query = base_query.order_by(*order_by_attributes) if order_by_attributes else base_query
    return final_query.all()
# Type variable used to annotate the ID-related parameters (entity_id, id_attr) of get_object_by_id.
U = TypeVar("U")
def get_object_by_id(session: Session, entity_id: U, entity_type: Type[T], id_attr: U) -> T:
    """
    Retrieve the single stored entity of ``entity_type`` whose ``id_attr`` equals ``entity_id``.

    Shared by the ORM SubRepositories for their "get by ID" operations.

    :param session: An ORM session
    :param entity_id: The ID of the entity to retrieve
    :param entity_type: Type of entity to retrieve
    :param id_attr: Mapped attribute of entity_type representing its ID;
        there should be at most 1 entity of entity_type with any given value of id_attr
    :return: The matching entity
    :raises ValueError: When no entity of the given type and ID is found in the database
        (the SQLAlchemy NoResultFound is attached as ``__cause__``)
    :raises sqlalchemy.exc.MultipleResultsFound: If more than one row matches, i.e. the
        at-most-one assumption about id_attr does not hold (propagated unchanged from ``Query.one()``)
    """
    try:
        return session.query(entity_type).filter(id_attr == entity_id).one()
    except NoResultFound as e:
        raise ValueError(f"{entity_type.__name__} id {entity_id} not found.") from e