Coverage for middle_layer/common/application_layer/orm_repositories/__init__.py: 96.30%

27 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2026-04-13 06:13 +0000

1# Copyright 2024 Associated Universities, Inc. 

2# 

3# This file is part of Telescope Time Allocation Tools (TTAT). 

4# 

5# TTAT is free software: you can redistribute it and/or modify 

6# it under the terms of the GNU General Public License as published by 

7# the Free Software Foundation, either version 3 of the License, or 

8# any later version. 

9# 

10# TTAT is distributed in the hope that it will be useful, 

11# but WITHOUT ANY WARRANTY; without even the implied warranty of 

12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

13# GNU General Public License for more details. 

14# 

15# You should have received a copy of the GNU General Public License 

16# along with TTAT. If not, see <https://www.gnu.org/licenses/>. 

17# 

18from typing import Type, TypeVar 

19 

20from sqlalchemy.exc import IntegrityError, NoResultFound 

21from sqlalchemy.orm import Session 

22 

23 

def add_entity(session: Session, entity: object, integrity_error_msgs: dict[type, str] | None = None) -> None:
    """
    Method used by every ORM SubRepository to safely try adding an object to the database, rolling back
    if there is an integrity error while flushing it

    :param session: The ORM Session
    :param entity: Entity to add to the database
    :param integrity_error_msgs: Dictionary that maps DBAPI exception classes (the class of
        ``IntegrityError.orig``, e.g. a psycopg2 ``UniqueViolation``) to the error messages to
        raise instead of the default one
    :raises ValueError: When an IntegrityError is raised by SQLAlchemy; the original error is
        chained as ``__cause__``
    """
    try:
        session.add(entity)
        session.flush()
    except IntegrityError as e:
        # Session.rollback() rolls back the session's whole current transaction, so any other
        # pending changes in this session are discarded as well, not just ``entity``
        session.rollback()
        # Check for a caller-supplied message first so it is used even when the driver error
        # does not expose psycopg2-style diagnostics
        custom_msg = (integrity_error_msgs or {}).get(e.orig.__class__)
        if custom_msg is not None:
            raise ValueError(custom_msg) from e
        # psycopg2 errors carry ``diag.message_primary``; fall back to str() for other drivers
        # instead of failing with AttributeError
        detail = getattr(getattr(e.orig, "diag", None), "message_primary", None)
        if detail is None:
            detail = str(e.orig)
        raise ValueError(f"Failed to add {entity.__class__.__name__} to the database: {detail}") from e

47 

48 

T = TypeVar("T")


def list_entities(session: Session, entity_type: Type[T], *order_by_attributes) -> list[T]:
    """
    Fetch every stored entity of one domainmodel class; shared by all ORM SubRepositories

    :param session: the orm session
    :param entity_type: Type of entity to list
    :param order_by_attributes: Optional attributes passed to ``order_by``; when none are given
        the query is left unordered
    :return: list of domainmodel objects
    """
    base_query = session.query(entity_type)
    # Only apply ordering when the caller actually supplied attributes to sort on
    final_query = base_query.order_by(*order_by_attributes) if order_by_attributes else base_query
    return final_query.all()

65 

66 

U = TypeVar("U")


def get_object_by_id(session: Session, entity_id: U, entity_type: Type[T], id_attr: U) -> T:
    """
    Method used by every ORM SubRepository to retrieve a stored entity of a given type with a given ID

    :param session: An ORM session
    :param entity_id: The ID of the entity to retrieve
    :param entity_type: Type of entity to retrieve
    :param id_attr: Attribute of entity_type representing its ID (compared with ``==`` against
        entity_id to build the filter); there should be at most 1 entity of entity_type with any
        given value of id_attr
    :return: The single entity of entity_type whose id_attr equals entity_id
    :raises ValueError: When no entity of the given type and ID is found in the database
    :raises sqlalchemy.exc.MultipleResultsFound: Propagated unchanged from ``Query.one()`` if more
        than one entity matches
    """
    # NOTE(review): id_attr shares the TypeVar U with entity_id, but it is a mapped attribute rather
    # than an ID value; consider a more accurate annotation
    try:
        return session.query(entity_type).filter(id_attr == entity_id).one()
    except NoResultFound as err:
        raise ValueError(f"{entity_type.__name__} id {entity_id} not found.") from err