tol-sdk 1.6.37__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. tol/api_base/blueprint.py +29 -6
  2. tol/api_base/controller.py +14 -5
  3. tol/api_client/api_datasource.py +15 -7
  4. tol/api_client/client.py +12 -6
  5. tol/api_client/converter.py +22 -8
  6. tol/api_client/factory.py +5 -3
  7. tol/api_client/view.py +75 -205
  8. tol/cli/cli.py +1 -1
  9. tol/core/__init__.py +1 -0
  10. tol/core/http_client.py +4 -2
  11. tol/core/operator/cursor.py +5 -3
  12. tol/core/operator/detail_getter.py +7 -15
  13. tol/core/operator/list_getter.py +3 -1
  14. tol/core/operator/page_getter.py +3 -1
  15. tol/core/operator/relational.py +9 -4
  16. tol/core/requested_fields.py +189 -0
  17. tol/elastic/elastic_datasource.py +2 -1
  18. tol/flows/converters/benchling_extraction_to_elastic_extraction_converter.py +25 -6
  19. tol/flows/converters/benchling_extraction_to_elastic_sequencing_request_converter.py +28 -7
  20. tol/flows/converters/benchling_sequencing_request_to_elastic_sequencing_request_converter.py +30 -9
  21. tol/flows/converters/benchling_tissue_prep_to_elastic_tissue_prep_converter.py +14 -3
  22. tol/flows/converters/elastic_sample_to_benchling_tissue_update_converter.py +1 -1
  23. tol/flows/converters/elastic_sample_to_elastic_sequencing_request_update_converter.py +4 -1
  24. tol/flows/converters/elastic_tolid_to_elastic_genome_note_update_converter.py +4 -1
  25. tol/flows/converters/elastic_tolid_to_elastic_sample_update_converter.py +4 -1
  26. tol/sources/sts.py +6 -2
  27. tol/sql/database.py +80 -44
  28. tol/sql/factory.py +2 -2
  29. tol/sql/filter.py +22 -20
  30. tol/sql/model.py +43 -38
  31. tol/sql/relationship.py +1 -1
  32. tol/sql/sql_converter.py +49 -142
  33. tol/sql/sql_datasource.py +85 -180
  34. tol/sql/{board → standard}/__init__.py +1 -1
  35. tol/sql/standard/factory.py +549 -0
  36. {tol_sdk-1.6.37.dist-info → tol_sdk-1.7.0.dist-info}/METADATA +1 -1
  37. {tol_sdk-1.6.37.dist-info → tol_sdk-1.7.0.dist-info}/RECORD +41 -42
  38. tol/sql/board/factory.py +0 -341
  39. tol/sql/loader/__init__.py +0 -6
  40. tol/sql/loader/factory.py +0 -246
  41. {tol_sdk-1.6.37.dist-info → tol_sdk-1.7.0.dist-info}/WHEEL +0 -0
  42. {tol_sdk-1.6.37.dist-info → tol_sdk-1.7.0.dist-info}/entry_points.txt +0 -0
  43. {tol_sdk-1.6.37.dist-info → tol_sdk-1.7.0.dist-info}/licenses/LICENSE +0 -0
  44. {tol_sdk-1.6.37.dist-info → tol_sdk-1.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,189 @@
1
+ # SPDX-FileCopyrightText: 2025 Genome Research Ltd.
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Iterable, Iterator
8
+
9
+ from . import DataSourceError, OperableDataSource
10
+ from .operator import Relational
11
+ from .relationship import RelationshipConfig
12
+
13
+
14
+ class ReqFieldsTree:
15
+ """
16
+ Acts as a template for which related objects and attributes to fetch from
17
+ the DataSource, and for serialzing them in the response.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ object_type: str,
23
+ data_source: OperableDataSource,
24
+ requested_fields: list[str] | None = None,
25
+ include_all_to_ones: bool = True,
26
+ ) -> None:
27
+ self.__object_type: str = object_type
28
+ self.__attributes: dict[str, bool] = {}
29
+ self.__sub_trees: dict[str, ReqFieldsTree] = {}
30
+ self.__rel_conf: RelationshipConfig | None = (
31
+ data_source.relationship_config.get(object_type)
32
+ if isinstance(data_source, Relational)
33
+ else None
34
+ )
35
+ if requested_fields:
36
+ self.add_requested_tree(data_source, requested_fields)
37
+ elif include_all_to_ones:
38
+ self.add_all_to_ones(data_source)
39
+
40
+ def __eq__(self, other):
41
+ if isinstance(other, self.__class__):
42
+ return self.__dict__ == other.__dict__
43
+ return False
44
+
45
+ @property
46
+ def object_type(self) -> str:
47
+ return self.__object_type
48
+
49
+ @property
50
+ def attribute_names(self) -> list[str]:
51
+ return list(self.__attributes)
52
+
53
+ def add_attribute(self, name: str) -> None:
54
+ """
55
+ Add attribute `name`. Attributes are implemented as a dict rather than
56
+ a set to preserve order.
57
+ """
58
+ self.__attributes[name] = True
59
+
60
+ def has_attribute(self, name: str) -> bool:
61
+ return name in self.__attributes
62
+
63
+ def add_sub_tree(self, name: str, sub: ReqFieldsTree) -> None:
64
+ self.__sub_trees[name] = sub
65
+
66
+ def get_sub_tree(self, name: str) -> ReqFieldsTree | None:
67
+ return self.__sub_trees.get(name)
68
+
69
+ def sub_trees(self) -> Iterator[str, ReqFieldsTree]:
70
+ """
71
+ Iterator over `name, sub_tree`.
72
+ """
73
+ yield from self.__sub_trees.items()
74
+
75
+ @property
76
+ def is_leaf(self) -> bool:
77
+ """
78
+ This tree is a leaf if it has no sub-trees.
79
+ """
80
+ return not self.__sub_trees
81
+
82
+ @property
83
+ def is_stub(self) -> bool:
84
+ """
85
+ This tree is a stub if it has no sub-trees and has one attribute which
86
+ is "id".
87
+ """
88
+ if self.__sub_trees:
89
+ return False
90
+ return len(self.__attributes) == 1 and 'id' in self.__attributes
91
+
92
+ @property
93
+ def has_relationships(self) -> bool:
94
+ return self.__rel_conf is not None
95
+
96
+ def to_one_names(self) -> Iterable[str]:
97
+ return x.keys() if (x := self.__rel_conf.to_one) else ()
98
+
99
+ def to_many_names(self) -> Iterable[str]:
100
+ return x.keys() if (x := self.__rel_conf.to_many) else ()
101
+
102
+ def get_relationship(self, name: str) -> str:
103
+ """
104
+ Fetches the related object type from the `to_one` or `to_many` fields
105
+ of the attached `RelationshipConfig`.
106
+ """
107
+ if rel_conf := self.__rel_conf:
108
+ for attr in 'to_one', 'to_many':
109
+ if (cfg := getattr(rel_conf, attr)) and (type_name := cfg.get(name)):
110
+ return type_name
111
+
112
+ def get_attribute_type(self, data_source, name) -> type | None:
113
+ return data_source.attribute_types[self.object_type].get(name)
114
+
115
+ def add_all_to_ones(self, data_source: OperableDataSource) -> None:
116
+ if not self.__rel_conf:
117
+ return
118
+ for name in self.to_one_names():
119
+ sub_type = self.get_relationship(name)
120
+ self.add_sub_tree(
121
+ name,
122
+ self.__class__(
123
+ sub_type,
124
+ data_source,
125
+ include_all_to_ones=False,
126
+ ),
127
+ )
128
+
129
+ def add_requested_tree(
130
+ self,
131
+ data_source: OperableDataSource,
132
+ requested_fields: list[str],
133
+ ) -> None:
134
+ err_title = 'Bad Requested Fields Path Element'
135
+ for path_str in requested_fields:
136
+ tree = self
137
+ for name in path_str.split('.'):
138
+ if name == '':
139
+ msg = f"Empty element in path path '{path_str}'"
140
+ raise DataSourceError(title=err_title, detail=msg)
141
+ elif not tree:
142
+ msg = f"Element '{name}' appears after an attribute name in path '{path_str}'"
143
+ raise DataSourceError(title=err_title, detail=msg)
144
+ elif sub_tree := tree.get_sub_tree(name):
145
+ # `name` is a relationship we already have so move pointer
146
+ # to the sub tree.
147
+ tree = sub_tree
148
+ elif tree.has_attribute(name):
149
+ # Already have attribute `name`. Unset `tree` to trap path
150
+ # elements follwing an attribute.
151
+ tree = None
152
+ elif sub_type := tree.get_relationship(name):
153
+ # New sub-tree from relationship name.
154
+ sub_tree = self.__class__(
155
+ sub_type,
156
+ data_source,
157
+ include_all_to_ones=False,
158
+ )
159
+ tree.add_sub_tree(name, sub_tree)
160
+ tree = sub_tree
161
+ elif tree.get_attribute_type(data_source, name):
162
+ tree.add_attribute(name)
163
+ # Unset `tree` to trap path elements follwing an
164
+ # attribute.
165
+ tree = None
166
+ else:
167
+ msg = (
168
+ f'{name!r} in path {path_str!r} is not a known relationship'
169
+ f' or attribute of {tree.object_type!r} objects'
170
+ )
171
+ raise DataSourceError(title=err_title, detail=msg)
172
+
173
+ def __str__(self):
174
+ return ','.join(self.__to_strings_iter())
175
+
176
+ def to_paths(self) -> list[str]:
177
+ return list(self.__to_strings_iter())
178
+
179
+ def __to_strings_iter(self, *path: list[str]) -> Iterator[str]:
180
+ for name in self.__attributes:
181
+ yield '.'.join([*path, name])
182
+ if self.is_leaf:
183
+ # We only need to return the path to self if it hasn't already
184
+ # appeared as a root path under an attribtue.
185
+ if path and not self.__attributes:
186
+ yield '.'.join(path)
187
+ else:
188
+ for name, tree in self.sub_trees():
189
+ yield from tree.__to_strings_iter(*path, name)
@@ -215,7 +215,8 @@ class ElasticDataSource(
215
215
  page_size: Optional[int] = None,
216
216
  object_filters: Optional[DataSourceFilter] = None,
217
217
  search_after: list[str] | None = None,
218
- session: Optional[OperableSession] = None
218
+ session: Optional[OperableSession] = None,
219
+ **kwargs,
219
220
  ) -> tuple[Iterable[DataObject], list[str] | None]:
220
221
 
221
222
  resp = self.__get_page_response(
@@ -18,13 +18,32 @@ class BenchlingExtractionToElasticExtractionConverter(
18
18
  'extraction',
19
19
  data_object.id,
20
20
  attributes={
21
- 'sample': {'id': data_object.sts_id},
22
- 'species': {'id': data_object.taxon_id},
23
- 'specimen': {'id': data_object.specimen_id},
24
- 'tolid': {'id': data_object.programme_id},
25
- 'tissue_prep': {'id': data_object.eln_tissue_prep_id},
26
21
  **{k: v
27
22
  for k, v in data_object.attributes.items()
28
23
  if k not in ['sts_id', 'specimen_id', 'taxon_id',
29
- 'programme_id', 'eln_tissue_prep_id']}})
24
+ 'programme_id', 'eln_tissue_prep_id']}
25
+ },
26
+ to_one={
27
+ 'sample': self._data_object_factory(
28
+ 'sample',
29
+ data_object.sts_id
30
+ ) if data_object.sts_id is not None else None,
31
+ 'species': self._data_object_factory(
32
+ 'species',
33
+ data_object.taxon_id
34
+ ) if data_object.taxon_id is not None else None,
35
+ 'specimen': self._data_object_factory(
36
+ 'specimen',
37
+ data_object.specimen_id
38
+ ) if data_object.specimen_id is not None else None,
39
+ 'tolid': self._data_object_factory(
40
+ 'tolid',
41
+ data_object.programme_id
42
+ ) if data_object.programme_id is not None else None,
43
+ 'tissue_prep': self._data_object_factory(
44
+ 'tissue_prep',
45
+ data_object.eln_tissue_prep_id
46
+ ) if data_object.eln_tissue_prep_id is not None else None,
47
+ }
48
+ )
30
49
  yield ret
@@ -18,12 +18,33 @@ class BenchlingExtractionToElasticSequencingRequestConverter(
18
18
  'sequencing_request',
19
19
  data_object.id,
20
20
  attributes={
21
- 'extraction': {'id': data_object.id},
22
- 'sample': {'id': data_object.sts_id},
23
- 'species': {'id': data_object.taxon_id},
24
- 'specimen': {'id': data_object.specimen_id},
25
- 'tolid': {'id': data_object.programme_id},
26
- 'tissue_prep': {'id': data_object.eln_tissue_prep_id},
27
21
  'sequencing_platform': 'pacbio'
28
- })
22
+ },
23
+ to_one={
24
+ 'extraction': self._data_object_factory(
25
+ 'extraction',
26
+ data_object.id
27
+ ) if data_object.id is not None else None,
28
+ 'sample': self._data_object_factory(
29
+ 'sample',
30
+ data_object.sts_id
31
+ ) if data_object.sts_id is not None else None,
32
+ 'species': self._data_object_factory(
33
+ 'species',
34
+ data_object.taxon_id
35
+ ) if data_object.taxon_id is not None else None,
36
+ 'specimen': self._data_object_factory(
37
+ 'specimen',
38
+ data_object.specimen_id
39
+ ) if data_object.specimen_id is not None else None,
40
+ 'tolid': self._data_object_factory(
41
+ 'tolid',
42
+ data_object.programme_id
43
+ ) if data_object.programme_id is not None else None,
44
+ 'tissue_prep': self._data_object_factory(
45
+ 'tissue_prep',
46
+ data_object.eln_tissue_prep_id
47
+ ) if data_object.eln_tissue_prep_id is not None else None,
48
+ }
49
+ )
29
50
  yield ret
@@ -17,22 +17,43 @@ class BenchlingSequencingRequestToElasticSequencingRequestConverter(
17
17
  extraction = None
18
18
  tissue_prep = None
19
19
  if 'extraction_id' in data_object.attributes:
20
- extraction = {'id': data_object.extraction_id}
20
+ extraction = self._data_object_factory(
21
+ 'extraction',
22
+ data_object.extraction_id
23
+ )
21
24
  if 'tissue_prep_id' in data_object.attributes:
22
- tissue_prep = {'id': data_object.tissue_prep_id}
25
+ tissue_prep = self._data_object_factory(
26
+ 'tissue_prep',
27
+ data_object.tissue_prep_id
28
+ )
23
29
  ret = self._data_object_factory(
24
30
  'sequencing_request',
25
31
  data_object.sanger_sample_id,
26
32
  attributes={
27
- 'sample': {'id': str(data_object.sts_id)},
28
- 'specimen': {'id': str(data_object.specimen_id)},
29
- 'species': {'id': str(data_object.taxon_id)},
30
- 'tolid': {'id': data_object.programme_id},
31
- 'extraction': extraction,
32
- 'tissue_prep': tissue_prep,
33
33
  **{k: v
34
34
  for k, v in data_object.attributes.items()
35
35
  if k not in ['sanger_sample_id', 'sts_id',
36
36
  'specimen_id', 'taxon_id', 'extraction_id',
37
- 'programme_id', 'tissue_prep_id']}})
37
+ 'programme_id', 'tissue_prep_id']}
38
+ },
39
+ to_one={
40
+ 'sample': self._data_object_factory(
41
+ 'sample',
42
+ data_object.sts_id
43
+ ) if data_object.sts_id is not None else None,
44
+ 'specimen': self._data_object_factory(
45
+ 'specimen',
46
+ data_object.specimen_id
47
+ ) if data_object.specimen_id is not None else None,
48
+ 'species': self._data_object_factory(
49
+ 'species',
50
+ data_object.taxon_id
51
+ ) if data_object.taxon_id is not None else None,
52
+ 'tolid': self._data_object_factory(
53
+ 'tolid',
54
+ data_object.programme_id
55
+ ) if data_object.programme_id is not None else None,
56
+ 'extraction': extraction,
57
+ 'tissue_prep': tissue_prep,
58
+ })
38
59
  yield ret
@@ -17,15 +17,26 @@ class BenchlingTissuePrepToElasticTissuePrepConverter(
17
17
  'tissue_prep',
18
18
  data_object.eln_tissue_prep_id,
19
19
  attributes={
20
- 'sample': {'id': str(data_object.sts_id)},
21
- 'species': {'id': str(data_object.taxon_id)},
22
- 'tolid': {'id': data_object.programme_id},
23
20
  **{k: v
24
21
  for k, v in data_object.attributes.items()
25
22
  if k not in ['eln_tissue_prep_id',
26
23
  'sts_id',
27
24
  'taxon_id',
28
25
  'programme_id']}
26
+ },
27
+ to_one={
28
+ 'sample': self._data_object_factory(
29
+ 'sample',
30
+ data_object.sts_id
31
+ ) if data_object.sts_id is not None else None,
32
+ 'species': self._data_object_factory(
33
+ 'species',
34
+ data_object.taxon_id
35
+ ) if data_object.taxon_id is not None else None,
36
+ 'tolid': self._data_object_factory(
37
+ 'tolid',
38
+ data_object.programme_id
39
+ ) if data_object.programme_id is not None else None,
29
40
  }
30
41
  )
31
42
  return iter([ret])
@@ -34,7 +34,7 @@ class ElasticSampleToBenchlingTissueUpdateConverter(
34
34
  if species.sts_taxon_group else 'NA',
35
35
  'genome_size': str(species.sts_genome_size),
36
36
  # 'freezer': None,
37
- 'location': data_object.sts_labwhere_parentage, # Previously shelf
37
+ 'location': data_object.sts_labwhere_parentage,
38
38
  'tray': data_object.sts_labwhere_name,
39
39
  'specimen_id': specimen.id,
40
40
  'programme_id': data_object.sts_tolid.id,
@@ -19,7 +19,10 @@ class ElasticSampleToElasticSequencingRequestUpdateConverter(
19
19
  yield (
20
20
  None,
21
21
  {
22
- 'mlwh_sample': {'id': data_object.id},
22
+ 'mlwh_sample': self._data_object_factory(
23
+ 'sample',
24
+ data_object.id
25
+ ),
23
26
  'mlwh_specimen.id': specimen.id
24
27
  }
25
28
  )
@@ -19,7 +19,10 @@ class ElasticTolidToElasticGenomeNoteUpdateConverter(
19
19
  yield (
20
20
  None,
21
21
  {
22
- 'gn_species': {'id': species.id},
22
+ 'gn_species': self._data_object_factory(
23
+ 'species',
24
+ species.id
25
+ ),
23
26
  'gn_tolid.id': data_object.id
24
27
  }
25
28
  )
@@ -21,7 +21,10 @@ class ElasticTolidToElasticSampleUpdateConverter(
21
21
  yield (
22
22
  None,
23
23
  {
24
- 'tolid_tolid': {'id': data_object.id},
24
+ 'tolid_tolid': self._data_object_factory(
25
+ 'tolid',
26
+ data_object.id
27
+ ),
25
28
  'sts_species.id':
26
29
  data_object.requested_taxonomy_id
27
30
  if data_object.requested_taxonomy_id is not None
tol/sources/sts.py CHANGED
@@ -14,13 +14,17 @@ from ..core import (
14
14
  )
15
15
 
16
16
 
17
- def sts(retries: int = 5, **kwargs) -> ApiDataSource:
17
+ def sts(
18
+ retries: int = 5,
19
+ status_forcelist: list[int] | None = [429, 500, 502, 503, 504]
20
+ ) -> ApiDataSource:
18
21
  sts = create_api_datasource(
19
22
  api_url=os.getenv('STS_URL', Defaults.STS_URL)
20
23
  + os.getenv('STS_API_PATH', Defaults.STS_API_PATH),
21
24
  token=os.getenv('STS_API_KEY'),
22
25
  data_prefix=os.getenv('STS_API_DATA_PATH', Defaults.STS_API_DATA_PATH),
23
- retries=retries
26
+ retries=retries,
27
+ status_forcelist=status_forcelist,
24
28
  )
25
29
  core_data_object(sts)
26
30
  return sts
tol/sql/database.py CHANGED
@@ -5,18 +5,22 @@
5
5
  from __future__ import annotations
6
6
 
7
7
  from abc import ABC, abstractmethod
8
- from typing import Any, Dict, Iterable, List, Optional, Type
8
+ from collections.abc import Iterable
9
+ from typing import Any, Dict, List, Optional, Type
9
10
 
10
11
  from sqlalchemy import distinct, func
11
12
  from sqlalchemy.exc import IntegrityError
12
- from sqlalchemy.orm import MappedColumn, Query, Session, joinedload
13
+ from sqlalchemy.orm import Load, MappedColumn, Query, Session, joinedload, load_only, raiseload
13
14
  from sqlalchemy.orm.attributes import flag_modified
14
15
 
15
16
  from .filter import DatabaseFilter
16
17
  from .model import Model
17
18
  from .session import SessionFactory
18
19
  from .sort import DatabaseSorter
19
- from ..core import DataSourceError
20
+ from ..core import DataSourceError, ReqFieldsTree
21
+
22
+
23
+ SubPath = tuple[str | None, ReqFieldsTree]
20
24
 
21
25
 
22
26
  class Database(ABC):
@@ -28,7 +32,7 @@ class Database(ABC):
28
32
  tablename: str,
29
33
  instance_id: Any,
30
34
  in_session: Session,
31
- requested_relationships: dict[str, str] | None = None,
35
+ requested_tree: ReqFieldsTree | None = None,
32
36
  ) -> Optional[Model]:
33
37
  """
34
38
  Gets a single instance by its instance-ID, or None if not found.
@@ -46,7 +50,7 @@ class Database(ABC):
46
50
  sort_by: Optional[DatabaseSorter] = None,
47
51
  offset: Optional[int] = None,
48
52
  limit: Optional[int] = None,
49
- requested_relationships: dict[str, str] | None = None
53
+ requested_tree: ReqFieldsTree | None = None,
50
54
  ) -> Iterable[Model]:
51
55
  """
52
56
  Returns an Iterable of `Model` instances according
@@ -191,14 +195,13 @@ class DefaultDatabase(Database):
191
195
  tablename: str,
192
196
  instance_id: Any,
193
197
  in_session: Session,
194
- requested_relationships: dict[str, str] | None = None,
198
+ requested_tree: ReqFieldsTree | None = None,
195
199
  ) -> Optional[Model]:
196
-
197
200
  result = self.__get_instance_by_id(
198
201
  tablename,
199
202
  instance_id,
200
203
  in_session,
201
- requested_relationships,
204
+ requested_tree,
202
205
  )
203
206
  return result
204
207
 
@@ -210,13 +213,13 @@ class DefaultDatabase(Database):
210
213
  sort_by: Optional[DatabaseSorter] = None,
211
214
  offset: Optional[int] = None,
212
215
  limit: Optional[int] = None,
213
- requested_relationships: dict[str, str] | None = None,
216
+ requested_tree: ReqFieldsTree | None = None,
214
217
  ) -> Iterable[Model]:
215
218
 
216
219
  _, query = self.__get_model_query(
217
220
  tablename,
218
221
  in_session,
219
- requested_relationships,
222
+ requested_tree=requested_tree,
220
223
  filters=filters,
221
224
  )
222
225
  if filters is not None:
@@ -225,7 +228,8 @@ class DefaultDatabase(Database):
225
228
  query = filters.filter(query, tablename, self.__tablename_model_dict)
226
229
  if sort_by is not None:
227
230
  query = sort_by.sort(query, tablename, self.__tablename_model_dict, filters)
228
- query = query.limit(limit).offset(offset)
231
+ if limit is not None and offset is not None:
232
+ query = query.limit(limit).offset(offset)
229
233
  results = query.all()
230
234
  return results
231
235
 
@@ -236,7 +240,7 @@ class DefaultDatabase(Database):
236
240
  filters: Optional[DatabaseFilter] = None
237
241
  ) -> int:
238
242
 
239
- _, query = self.__get_model_query(tablename, in_session, None, filters=filters)
243
+ _, query = self.__get_model_query(tablename, in_session, filters=filters)
240
244
  if filters is not None:
241
245
  query = filters.filter(query, tablename, self.__tablename_model_dict)
242
246
  count = query.count()
@@ -346,7 +350,6 @@ class DefaultDatabase(Database):
346
350
  model, query = self.__get_model_query(
347
351
  tablename,
348
352
  in_session,
349
- None,
350
353
  filters=filters,
351
354
  )
352
355
 
@@ -605,43 +608,18 @@ class DefaultDatabase(Database):
605
608
  self,
606
609
  tablename: str,
607
610
  in_session: Session,
608
- requested_relationships: dict | None,
611
+ requested_tree: ReqFieldsTree | None = None,
609
612
  filters: DatabaseFilter | None = None,
610
613
  ) -> tuple[Type[Model], Query]:
611
614
 
612
615
  model = self.__tablename_model_dict[tablename]
613
- query = in_session.query(model) if not filters else filters.get_query(in_session, model)
616
+ query = filters.get_query(in_session, model) if filters else in_session.query(model)
614
617
 
615
- if requested_relationships:
616
- query = self.__apply_requested_relationships(
617
- query,
618
- requested_relationships
619
- )
618
+ if requested_tree:
619
+ query = self.add_options_to_query(query, tablename, requested_tree)
620
620
 
621
621
  return model, query
622
622
 
623
- def __apply_requested_relationships(
624
- self,
625
- query: Query,
626
- requested_relationships: dict[str, str]
627
- ) -> Query:
628
-
629
- tablename = requested_relationships.pop('__tablename__')
630
- if not requested_relationships:
631
- return query
632
-
633
- model = self.__tablename_model_dict[tablename]
634
-
635
- for r_name, r_dict in requested_relationships.items():
636
- relationship = getattr(model, r_name)
637
- query.options(
638
- joinedload(relationship)
639
- )
640
-
641
- query = self.__apply_requested_relationships(query, r_dict)
642
-
643
- return query
644
-
645
623
  def __commit_session(
646
624
  self,
647
625
  in_session: Session,
@@ -677,7 +655,7 @@ class DefaultDatabase(Database):
677
655
  tablename: str,
678
656
  instance_id: str,
679
657
  in_session: Session,
680
- requested_relationships: dict[str, str] | None = None,
658
+ requested_tree: ReqFieldsTree | None = None,
681
659
  ) -> Optional[Model]:
682
660
  """
683
661
  Gets an instance by its tablename and id.
@@ -686,7 +664,7 @@ class DefaultDatabase(Database):
686
664
  model, query = self.__get_model_query(
687
665
  tablename,
688
666
  in_session,
689
- requested_relationships
667
+ requested_tree=requested_tree,
690
668
  )
691
669
  id_column = getattr(model, model.get_id_column_name())
692
670
  result = query.filter(id_column == instance_id).one_or_none()
@@ -854,3 +832,61 @@ class DefaultDatabase(Database):
854
832
  f'Hint - check the following tables: "{relationship_names}".'
855
833
  )
856
834
  )
835
+
836
+ def add_options_to_query(
837
+ self,
838
+ query: Query,
839
+ tablename: str,
840
+ requested_tree: ReqFieldsTree,
841
+ ):
842
+ options = self.joinedload_options(requested_tree)
843
+ # `raiseload(*)` acts as a trap, raising an exception if any methods
844
+ # on the returned objects are called which would trigger loading data
845
+ # from the database via another SELECT.
846
+ return query.options(options, raiseload('*')) if options else query
847
+
848
+ def joinedload_options(self, tree: ReqFieldsTree) -> list[Load]:
849
+ """
850
+ Returns a list of SQLAlchemy `Load` objects based on the supplied
851
+ `ReqFieldsTree` which specify which related tables to join into and
852
+ which attribute columns to select. This list of `Load` objects can
853
+ then be added to a SQLAlchemy `Query` via a call to `options()`.
854
+ """
855
+ sub: SubPath = None, tree
856
+ return list(self.__joinedload_iter(sub))
857
+
858
+ def __joinedload_iter(self, sub: SubPath, *path: list[SubPath]):
859
+ path = [*path, sub]
860
+ tree = sub[1]
861
+ if tree.is_leaf:
862
+ if options := self.__joinedload_options_from_path(path):
863
+ yield options
864
+ else:
865
+ for sub in tree.sub_trees():
866
+ yield from self.__joinedload_iter(sub, *path)
867
+
868
+ def __joinedload_options_from_path(self, path: list[SubPath]):
869
+ load = None
870
+ prev_model = None
871
+ for rel_name, tree in path:
872
+ model = self.__tablename_model_dict[tree.object_type]
873
+ if prev_model:
874
+ # Add a joinedload for this model if `tree` isn't the root
875
+ relation = getattr(prev_model, rel_name)
876
+ load = load.joinedload(relation) if load else joinedload(relation)
877
+ if names := tree.attribute_names:
878
+ for col_name in model.get_all_foreign_key_names():
879
+ # Always load any to-one ID columns where the relation
880
+ # isn't being fetched so that we can create stub objects
881
+ # for them.
882
+ if not tree.has_attribute(col_name):
883
+ names.append(col_name)
884
+ cols = [getattr(model, x) for x in names if x != 'id']
885
+ load = (
886
+ load.load_only(*cols, raiseload=True)
887
+ if load
888
+ else load_only(*cols, raiseload=True)
889
+ )
890
+ prev_model = model
891
+
892
+ return load
tol/sql/factory.py CHANGED
@@ -31,10 +31,10 @@ def __model_converter_factory(
31
31
  type_function: TypeFunction
32
32
  ) -> ConverterFactory:
33
33
 
34
- return lambda do_factory, requested_fields: DefaultModelConverter(
34
+ return lambda do_factory, req_fields_tree: DefaultModelConverter(
35
35
  type_function,
36
36
  do_factory,
37
- requested_fields=requested_fields
37
+ requested_tree=req_fields_tree
38
38
  )
39
39
 
40
40