deriva-ml 1.16.0__py3-none-any.whl → 1.17.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
deriva_ml/.DS_Store ADDED
Binary file
deriva_ml/__init__.py CHANGED
@@ -25,9 +25,6 @@ from deriva_ml.core.exceptions import (
     DerivaMLInvalidTerm,
     DerivaMLTableTypeError,
 )
-from deriva_ml.dataset.aux_classes import DatasetConfig, DatasetConfigList, DatasetSpec, DatasetVersion
-
-from .execution import Execution, ExecutionConfiguration, Workflow
 
 # Type-checking only - avoid circular import at runtime
 if TYPE_CHECKING:
@@ -51,13 +48,6 @@ def __getattr__(name):
 __all__ = [
     "DerivaML",  # Lazy-loaded
     "DerivaMLConfig",
-    "DatasetConfig",
-    "DatasetConfigList",
-    "DatasetSpec",
-    "DatasetVersion",
-    "Execution",
-    "ExecutionConfiguration",
-    "Workflow",
     # Exceptions
     "DerivaMLException",
     "DerivaMLInvalidTerm",
deriva_ml/core/base.py CHANGED
@@ -19,7 +19,7 @@ import logging
 from datetime import datetime
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterable, List, cast, TYPE_CHECKING, Any
+from typing import Dict, Iterable, List, cast, TYPE_CHECKING, Any, Self
 from urllib.parse import urlsplit
 
 
@@ -28,13 +28,14 @@ import requests
 from pydantic import ConfigDict, validate_call
 
 # Deriva imports
-from deriva.core import DEFAULT_SESSION_CONFIG, format_exception, get_credential, urlquote, init_logging
+from deriva.core import DEFAULT_SESSION_CONFIG, format_exception, get_credential, urlquote
 
 import deriva.core.datapath as datapath
 from deriva.core.datapath import DataPathException, _SchemaWrapper as SchemaWrapper
 from deriva.core.deriva_server import DerivaServer
 from deriva.core.ermrest_catalog import ResolveRidResult
 from deriva.core.ermrest_model import Key, Table
+from deriva.core.utils.core_utils import DEFAULT_LOGGER_OVERRIDES
 from deriva.core.utils.globus_auth_utils import GlobusNativeLogin
 
 from deriva_ml.core.exceptions import DerivaMLInvalidTerm
@@ -103,6 +104,10 @@ class DerivaML(Dataset):
         >>> ml.add_term('vocabulary_table', 'new_term', description='Description of term')
     """
 
+    @classmethod
+    def instantiate(cls, config: DerivaMLConfig) -> Self:
+        return cls(**config.model_dump())
+
     def __init__(
         self,
         hostname: str,
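
A minimal usage sketch of the new classmethod, assuming DerivaMLConfig is a pydantic model whose fields mirror the constructor arguments (the field names below are illustrative, not confirmed by this diff):

    from deriva_ml import DerivaML, DerivaMLConfig

    config = DerivaMLConfig(hostname="example.derivacloud.org", catalog_id="1")  # assumed field names
    ml = DerivaML.instantiate(config)  # equivalent to DerivaML(**config.model_dump())
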
@@ -149,7 +154,6 @@ class DerivaML(Dataset):
             credentials=self.credential,
             session_config=self._get_session_config(),
         )
-
         try:
             if check_auth and server.get_authn_session():
                 pass
@@ -158,7 +162,6 @@
                 "You are not authorized to access this catalog. "
                 "Please check your credentials and make sure you have logged in."
             )
-
         self.catalog = server.connect_ermrest(catalog_id)
         self.model = DerivaModel(self.catalog.getCatalogModel(), domain_schema=domain_schema)
 
@@ -176,9 +179,13 @@
         # Set up logging
         self._logger = logging.getLogger("deriva_ml")
         self._logger.setLevel(logging_level)
+        self._logging_level = logging_level
+        self._deriva_logging_level = deriva_logging_level
 
         # Configure deriva logging level
-        init_logging(deriva_logging_level)
+        logger_config = DEFAULT_LOGGER_OVERRIDES
+        # allow for reconfiguration of module-specific logging levels
+        [logging.getLogger(name).setLevel(level) for name, level in logger_config.items()]
         logging.getLogger("bagit").setLevel(deriva_logging_level)
         logging.getLogger("bdbag").setLevel(deriva_logging_level)
 
@@ -1081,7 +1088,12 @@
         return self._download_dataset_bag(
             dataset=dataset,
             execution_rid=execution_rid,
-            snapshot_catalog=DerivaML(self.host_name, self._version_snapshot(dataset)),
+            snapshot_catalog=DerivaML(
+                self.host_name,
+                self._version_snapshot(dataset),
+                logging_level=self._logging_level,
+                deriva_logging_level=self._deriva_logging_level,
+            ),
         )
 
     def _update_status(self, new_status: Status, status_detail: str, execution_rid: RID):
@@ -1,16 +1,11 @@
-from typing import Protocol, runtime_checkable
-
-from deriva_ml.core.definitions import RID
-
-from .aux_classes import DatasetConfig, DatasetConfigList, DatasetSpec, DatasetVersion, VersionPart
+from .aux_classes import DatasetSpec, DatasetSpecConfig, DatasetVersion, VersionPart
 from .dataset import Dataset
 from .dataset_bag import DatasetBag
 
 __all__ = [
     "Dataset",
     "DatasetSpec",
-    "DatasetConfig",
-    "DatasetConfigList",
+    "DatasetSpecConfig",
     "DatasetBag",
     "DatasetVersion",
     "VersionPart",
@@ -212,18 +212,10 @@ class DatasetSpec(BaseModel):
         return version.to_dict()
 
 
+# Interface for hydra-zen
 @hydrated_dataclass(DatasetSpec)
-class DatasetConfig:
+class DatasetSpecConfig:
     rid: str
     version: str
     materialize: bool = True
     description: str = ""
-
-class DatasetList(BaseModel):
-    datasets: list[DatasetSpec]
-    description: str = ""
-
-@hydrated_dataclass(DatasetList)
-class DatasetConfigList:
-    datasets: list[DatasetConfig]
-    description: str = ""
@@ -31,6 +31,7 @@ from graphlib import TopologicalSorter
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator
+from urllib.parse import urlparse
 
 import deriva.core.utils.hash_utils as hash_utils
 import requests
@@ -1040,7 +1041,6 @@ class Dataset:
                 envars={"RID": dataset.rid},
             )
             minid_page_url = exporter.export()[0]  # Get the MINID launch page
-
         except (
             DerivaDownloadError,
             DerivaDownloadConfigurationError,
@@ -1096,7 +1096,8 @@
 
         # Check or create MINID
         minid_url = version_record.minid
-        if not minid_url:
+        # If we either don't have a MINID, or we have a MINID, but we don't want to use it, generate a new one.
+        if (not minid_url) or (not self._use_minid):
             if not create:
                 raise DerivaMLException(f"Minid for dataset {rid} doesn't exist")
             if self._use_minid:
@@ -1106,7 +1107,6 @@
         # Return based on MINID usage
         if self._use_minid:
             return self._fetch_minid_metadata(minid_url, dataset.version)
-
         return DatasetMinid(
             dataset_version=dataset.version,
             RID=f"{rid}@{version_record.snapshot}",
@@ -1139,7 +1139,8 @@
         with TemporaryDirectory() as tmp_dir:
             if self._use_minid:
                 # Get bag from S3
-                archive_path = fetch_single_file(minid.bag_url, output_path=tmp_dir)
+                bag_path = Path(tmp_dir) / Path(urlparse(minid.bag_url).path).name
+                archive_path = fetch_single_file(minid.bag_url, output_path=bag_path)
             else:
                 exporter = DerivaExport(host=self._model.catalog.deriva_server.server, output_dir=tmp_dir)
                 archive_path = exporter.retrieve_file(minid.bag_url)
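
The fetch change above derives an explicit file name for the downloaded archive instead of passing a bare directory. A self-contained sketch of that derivation (the URL and directory are illustrative):

    from pathlib import Path
    from urllib.parse import urlparse

    bag_url = "https://s3.example.org/bags/Dataset_1-ABCD.zip"
    # Take the last path component of the URL and place the archive inside the working directory.
    bag_path = Path("/tmp/work") / Path(urlparse(bag_url).path).name
    print(bag_path)  # /tmp/work/Dataset_1-ABCD.zip
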
@@ -4,8 +4,6 @@ The module implements the sqllite interface to a set of directories representing
 
 from __future__ import annotations
 
-import sqlite3
-
 # Standard library imports
 from collections import defaultdict
 from copy import copy
@@ -16,15 +14,18 @@ import deriva.core.datapath as datapath
 # Third-party imports
 import pandas as pd
 
+# Local imports
+from deriva.core.ermrest_model import Table
+
 # Deriva imports
-from deriva.core.ermrest_model import Column, Table
 from pydantic import ConfigDict, validate_call
+from sqlalchemy import CompoundSelect, Engine, RowMapping, Select, and_, inspect, select, union
+from sqlalchemy.orm import RelationshipProperty, Session
+from sqlalchemy.orm.util import AliasedClass
 
-# Local imports
 from deriva_ml.core.definitions import RID, VocabularyTerm
 from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm
 from deriva_ml.feature import Feature
-from deriva_ml.model.sql_mapper import SQLMapper
 
 if TYPE_CHECKING:
     from deriva_ml.model.database import DatabaseModel
@@ -64,7 +65,8 @@ class DatasetBag:
             dataset_rid: Optional RID for the dataset.
         """
         self.model = database_model
-        self.database = cast(sqlite3.Connection, self.model.dbase)
+        self.engine = cast(Engine, self.model.engine)
+        self.metadata = self.model.metadata
 
         self.dataset_rid = dataset_rid or self.model.dataset_rid
         if not self.dataset_rid:
@@ -86,54 +88,48 @@ class DatasetBag:
         """
         return self.model.list_tables()
 
-    def _dataset_table_view(self, table: str) -> str:
-        """Return a SQL command that will return all of the elements in the specified table that are associated with
-        dataset_rid"""
-
-        table_name = self.model.normalize_table_name(table)
-
-        # Get the names of the columns in the table.
-        with self.database as dbase:
-            select_args = ",".join(
-                [f'"{table_name}"."{c[1]}"' for c in dbase.execute(f'PRAGMA table_info("{table_name}")').fetchall()]
-            )
+    @staticmethod
+    def _find_relationship_attr(source, target):
+        """
+        Return the relationship attribute (InstrumentedAttribute) on `source`
+        that points to `target`. Works with classes or AliasedClass.
+        Raises LookupError if not found.
+        """
+        src_mapper = inspect(source).mapper
+        tgt_mapper = inspect(target).mapper
 
-        # Get the list of datasets in the bag including the dataset itself.
-        datasets = ",".join(
-            [f'"{self.dataset_rid}"'] + [f'"{ds.dataset_rid}"' for ds in self.list_dataset_children(recurse=True)]
-        )
+        # collect relationships on the *class* mapper (not on alias)
+        candidates: list[RelationshipProperty] = [rel for rel in src_mapper.relationships if rel.mapper is tgt_mapper]
 
-        # Find the paths that terminate in the table we are looking for
-        # Assemble the ON clause by looking at each table pair, and looking up the FK columns that connect them.
-        paths = [
-            (
-                [f'"{self.model.normalize_table_name(t.name)}"' for t in p],
-                [self.model._table_relationship(t1, t2) for t1, t2 in zip(p, p[1:])],
-            )
-            for p in self.model._schema_to_paths()
-            if p[-1].name == table
-        ]
+        if not candidates:
+            raise LookupError(f"No relationship from {src_mapper.class_.__name__} {tgt_mapper.class_.__name__}")
 
-        sql = []
-        dataset_table_name = f'"{self.model.normalize_table_name(self._dataset_table.name)}"'
+        # Prefer MANYTOONE when multiple paths exist (often best for joins)
+        candidates.sort(key=lambda r: r.direction.name != "MANYTOONE")
+        rel = candidates[0]
 
-        def column_name(col: Column) -> str:
-            return f'"{self.model.normalize_table_name(col.table.name)}"."{col.name}"'
+        # Bind to the actual source (alias or class)
+        return getattr(source, rel.key) if isinstance(source, AliasedClass) else rel.class_attribute
 
-        for ts, on in paths:
-            tables = " JOIN ".join(ts)
-            on_expression = " and ".join([f"{column_name(left)}={column_name(right)}" for left, right in on])
-            sql.append(
-                f"SELECT {select_args} FROM {tables} "
-                f"{'ON ' + on_expression if on_expression else ''} "
-                f"WHERE {dataset_table_name}.RID IN ({datasets})"
-            )
-        if table_name == self.model.normalize_table_name(self._dataset_table.name):
-            sql.append(
-                f"SELECT {select_args} FROM {dataset_table_name} WHERE {dataset_table_name}.RID IN ({datasets})"
-            )
-        sql = " UNION ".join(sql) if len(sql) > 1 else sql[0]
-        return sql
+    def _dataset_table_view(self, table: str) -> CompoundSelect[Any]:
+        """Return a SQL command that will return all of the elements in the specified table that are associated with
+        dataset_rid"""
+        table_class = self.model.get_orm_class_by_name(table)
+        dataset_table_class = self.model.get_orm_class_by_name(self._dataset_table.name)
+        dataset_rids = [self.dataset_rid] + [c.dataset_rid for c in self.list_dataset_children(recurse=True)]
+
+        paths = [[t.name for t in p] for p in self.model._schema_to_paths() if p[-1].name == table]
+        sql_cmds = []
+        for path in paths:
+            path_sql = select(table_class)
+            last_class = self.model.get_orm_class_by_name(path[0])
+            for t in path[1:]:
+                t_class = self.model.get_orm_class_by_name(t)
+                path_sql = path_sql.join(self._find_relationship_attr(last_class, t_class))
+                last_class = t_class
+            path_sql = path_sql.where(dataset_table_class.RID.in_(dataset_rids))
+            sql_cmds.append(path_sql)
+        return union(*sql_cmds)
 
     def get_table(self, table: str) -> Generator[tuple, None, None]:
         """Retrieve the contents of the specified table. If schema is not provided as part of the table name,
@@ -146,9 +142,10 @@ class DatasetBag:
             A generator that yields tuples of column values.
 
         """
-        result = self.database.execute(self._dataset_table_view(table))
-        while row := result.fetchone():
-            yield row
+        with Session(self.engine) as session:
+            result = session.execute(self._dataset_table_view(table))
+            for row in result:
+                yield row
 
     def get_table_as_dataframe(self, table: str) -> pd.DataFrame:
         """Retrieve the contents of the specified table as a dataframe.
@@ -163,7 +160,7 @@
         Returns:
             A dataframe containing the contents of the specified table.
         """
-        return pd.read_sql(self._dataset_table_view(table), self.database)
+        return pd.read_sql(self._dataset_table_view(table), self.engine)
 
     def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
         """Retrieve the contents of the specified table as a dictionary.
@@ -176,15 +173,12 @@
             A generator producing dictionaries containing the contents of the specified table as name/value pairs.
         """
 
-        table_name = self.model.normalize_table_name(table)
-        schema, table = table_name.split(":")
-        with self.database as _dbase:
-            mapper = SQLMapper(self.model, table)
-            result = self.database.execute(self._dataset_table_view(table))
-            while row := result.fetchone():
-                yield mapper.transform_tuple(row)
+        with Session(self.engine) as session:
+            result = session.execute(self._dataset_table_view(table))
+            for row in result.mappings():
+                yield row
 
-    @validate_call
+    # @validate_call
     def list_dataset_members(self, recurse: bool = False) -> dict[str, list[dict[str, Any]]]:
         """Return a list of entities associated with a specific dataset.
 
@@ -198,39 +192,31 @@
         # Look at each of the element types that might be in the _dataset_table and get the list of rid for them from
         # the appropriate association table.
         members = defaultdict(list)
-        for assoc_table in self._dataset_table.find_associations():
-            member_fkey = assoc_table.other_fkeys.pop()
-            if member_fkey.pk_table.name == "Dataset" and member_fkey.foreign_key_columns[0].name != "Nested_Dataset":
-                # Sometimes find_assoc gets confused on Dataset_Dataset.
-                member_fkey = assoc_table.self_fkey
-
-            target_table = member_fkey.pk_table
-            member_table = assoc_table.table
-
-            if target_table.schema.name != self.model.domain_schema and not (
-                target_table == self._dataset_table or target_table.name == "File"
-            ):
+
+        dataset_class = self.model.get_orm_class_for_table(self._dataset_table)
+        for element_table in self.model.list_dataset_element_types():
+            element_class = self.model.get_orm_class_for_table(element_table)
+
+            assoc_class, dataset_rel, element_rel = self.model.get_orm_association_class(dataset_class, element_class)
+
+            element_table = inspect(element_class).mapped_table
+            if element_table.schema != self.model.domain_schema and element_table.name not in ["Dataset", "File"]:
                 # Look at domain tables and nested datasets.
                 continue
-            sql_target = self.model.normalize_table_name(target_table.name)
-            sql_member = self.model.normalize_table_name(member_table.name)
-
             # Get the names of the columns that we are going to need for linking
-            member_link = tuple(c.name for c in next(iter(member_fkey.column_map.items())))
-            with self.database as db:
-                col_names = [c[1] for c in db.execute(f'PRAGMA table_info("{sql_target}")').fetchall()]
-                select_cols = ",".join([f'"{sql_target}".{c}' for c in col_names])
+            with Session(self.engine) as session:
                 sql_cmd = (
-                    f'SELECT {select_cols} FROM "{sql_member}" '
-                    f'JOIN "{sql_target}" ON "{sql_member}".{member_link[0]} = "{sql_target}".{member_link[1]} '
-                    f'WHERE "{self.dataset_rid}" = "{sql_member}".Dataset;'
+                    select(element_class)
+                    .join(element_rel)
+                    .where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
                 )
-                mapper = SQLMapper(self.model, sql_target)
-                target_entities = [mapper.transform_tuple(e) for e in db.execute(sql_cmd).fetchall()]
-                members[target_table.name].extend(target_entities)
-                if recurse and (target_table.name == self._dataset_table.name):
+                # Get back the list of ORM entities and convert them to dictionaries.
+                element_entities = session.scalars(sql_cmd).all()
+                element_rows = [{c.key: getattr(obj, c.key) for c in obj.__table__.columns} for obj in element_entities]
+                members[element_table.name].extend(element_rows)
+                if recurse and (element_table.name == self._dataset_table.name):
                     # Get the members for all the nested datasets and add to the member list.
-                    nested_datasets = [d["RID"] for d in target_entities]
+                    nested_datasets = [d["RID"] for d in element_rows]
                     for ds in nested_datasets:
                         nested_dataset = self.model.get_dataset(ds)
                         for k, v in nested_dataset.list_dataset_members(recurse=recurse).items():
@@ -259,12 +245,10 @@
             Feature values.
         """
         feature = self.model.lookup_feature(table, feature_name)
-        feature_table = self.model.normalize_table_name(feature.feature_table.name)
-
-        with self.database as db:
-            col_names = [c[1] for c in db.execute(f'PRAGMA table_info("{feature_table}")').fetchall()]
-            sql_cmd = f'SELECT * FROM "{feature_table}"'
-            return cast(datapath._ResultSet, [dict(zip(col_names, r)) for r in db.execute(sql_cmd).fetchall()])
+        feature_class = self.model.get_orm_class_for_table(feature.feature_table)
+        with Session(self.engine) as session:
+            sql_cmd = select(feature_class)
+            return cast(datapath._ResultSet, [row for row in session.execute(sql_cmd).mappings()])
 
     def list_dataset_element_types(self) -> list[Table]:
         """
@@ -291,18 +275,18 @@
         Returns:
             List of child dataset bags.
         """
-        ds_table = self.model.normalize_table_name("Dataset")
-        nds_table = self.model.normalize_table_name("Dataset_Dataset")
-        dv_table = self.model.normalize_table_name("Dataset_Version")
-        with self.database as db:
+        ds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset")
+        nds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Dataset")
+        dv_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Version")
+
+        with Session(self.engine) as session:
             sql_cmd = (
-                f'SELECT "{nds_table}".Nested_Dataset, "{dv_table}".Version '
-                f'FROM "{nds_table}" JOIN "{dv_table}" JOIN "{ds_table}" on '
-                f'"{ds_table}".Version == "{dv_table}".RID AND '
-                f'"{nds_table}".Nested_Dataset == "{ds_table}".RID '
-                f'where "{nds_table}".Dataset == "{self.dataset_rid}"'
+                select(nds_table.Nested_Dataset, dv_table.Version)
+                .join_from(ds_table, nds_table, onclause=ds_table.RID == nds_table.Nested_Dataset)
+                .join_from(ds_table, dv_table, onclause=ds_table.Version == dv_table.RID)
+                .where(nds_table.Dataset == self.dataset_rid)
             )
-            nested = [DatasetBag(self.model, r[0]) for r in db.execute(sql_cmd).fetchall()]
+            nested = [DatasetBag(self.model, r[0]) for r in session.execute(sql_cmd).all()]
 
         result = copy(nested)
         if recurse:
@@ -336,20 +320,19 @@
             >>> term = ml.lookup_term("tissue_types", "epithelium")
         """
         # Get and validate vocabulary table reference
-        vocab_table = self.model.normalize_table_name(table)
         if not self.model.is_vocabulary(table):
             raise DerivaMLException(f"The table {table} is not a controlled vocabulary")
 
         # Search for term by name or synonym
-        for term in self.get_table_as_dict(vocab_table):
+        for term in self.get_table_as_dict(table):
             if term_name == term["Name"] or (term["Synonyms"] and term_name in term["Synonyms"]):
                 term["Synonyms"] = list(term["Synonyms"])
                 return VocabularyTerm.model_validate(term)
 
         # Term not found
-        raise DerivaMLInvalidTerm(vocab_table, term_name)
+        raise DerivaMLInvalidTerm(table, term_name)
 
-    def _denormalize(self, include_tables: list[str] | None) -> str:
+    def _denormalize(self, include_tables: list[str]) -> Select:
         """
         Generates an SQL statement for denormalizing the dataset based on the tables to include. Processes cycles in
         graph relationships, ensures proper join order, and generates selected columns for denormalization.
@@ -361,48 +344,57 @@
         Returns:
             str: SQL query string that represents the process of denormalization.
         """
-
-        def column_name(col: Column) -> str:
-            return f'"{self.model.normalize_table_name(col.table.name)}"."{col.name}"'
-
         # Skip over tables that we don't want to include in the denormalized dataset.
         # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
         # table.
 
-        join_tables, tables, denormalized_columns, dataset_rids, dataset_element_tables = (
+        def find_relationship(table, join_condition):
+            side1 = (join_condition[0].table.name, join_condition[0].name)
+            side2 = (join_condition[1].table.name, join_condition[1].name)
+
+            for relationship in inspect(table).relationships:
+                local_columns = list(relationship.local_columns)[0].table.name, list(relationship.local_columns)[0].name
+                remote_side = list(relationship.remote_side)[0].table.name, list(relationship.remote_side)[0].name
+                if local_columns == side1 and remote_side == side2 or local_columns == side2 and remote_side == side1:
+                    return relationship
+            return None
+
+        join_tables, denormalized_columns = (
             self.model._prepare_wide_table(self, self.dataset_rid, include_tables)
         )
 
-        select_args = [
-            # SQLlite will strip out the table name from the column in the select statement, so we need to add
-            # an explicit alias to the column name.
-            f'"{self.model.normalize_table_name(table_name)}"."{column_name}" AS "{table_name}.{column_name}"'
+        denormalized_columns = [
+            self.model.get_orm_class_by_name(table_name)
+            .__table__.columns[column_name]
+            .label(f"{table_name}.{column_name}")
             for table_name, column_name in denormalized_columns
         ]
-
-        # First table in the table list is the table specified in the method call.
-        normalized_join_tables = [self.model.normalize_table_name(t) for t in join_tables]
-        sql_statement = f'SELECT {",".join(select_args)} FROM "{normalized_join_tables[0]}"'
-        for t in normalized_join_tables[1:]:
-            on = tables[t]
-            sql_statement += f' LEFT JOIN "{t}" ON '
-            sql_statement += "OR ".join([f"{column_name(o[0])} = {column_name(o[1])}" for o in on])
-
-        # Select only rows from the datasets you wish to include.
-        dataset_rid_list = ",".join([f'"{self.dataset_rid}"'] + [f'"{b.dataset_rid}"' for b in dataset_rids])
-        sql_statement += f'WHERE "{self.model.normalize_table_name("Dataset")}"."RID" IN ({dataset_rid_list})'
-
-        # Only include rows that have actual values in them.
-        real_row = [f'"{self.model.normalize_table_name(t)}".RID IS NOT NULL ' for t in dataset_element_tables]
-        sql_statement += f" AND ({' OR '.join(real_row)})"
-        return sql_statement
-
-    def denormalize_as_dataframe(self, include_tables: list[str] | None = None) -> pd.DataFrame:
+        sql_statements = []
+        for key, (path, join_conditions) in join_tables.items():
+            sql_statement = select(*denormalized_columns).select_from(
+                self.model.get_orm_class_for_table(self._dataset_table)
+            )
+            for table_name in path[1:]:  # Skip over dataset table
+                table_class = self.model.get_orm_class_by_name(table_name)
+                on_clause = [
+                    getattr(table_class, r.key)
+                    for on_condition in join_conditions[table_name]
+                    if (r := find_relationship(table_class, on_condition))
+                ]
+                sql_statement = sql_statement.join(table_class, onclause=and_(*on_clause))
+            dataset_rid_list = [self.dataset_rid] + self.list_dataset_children(recurse=True)
+            dataset_class = self.model.get_orm_class_by_name(self._dataset_table.name)
+            sql_statement = sql_statement.where(dataset_class.RID.in_(dataset_rid_list))
+            sql_statements.append(sql_statement)
+        return union(*sql_statements)
+
+    def denormalize_as_dataframe(self, include_tables: list[str]) -> pd.DataFrame:
         """
         Denormalize the dataset and return the result as a dataframe.
 
-        This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
-        the dataset values into a single wide table. The result is returned as a dataframe.
+        This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
+        the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
+        for each row in the denormalized wide table.
 
         The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
         view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
@@ -412,28 +404,27 @@
         The resulting wide table will include a column for every table needed to complete the denormalization process.
 
         Args:
-            include_tables: List of table names to include in the denormalized dataset. If None, than the entire schema
-                is used.
+            include_tables: List of table names to include in the denormalized dataset.
 
         Returns:
             Dataframe containing the denormalized dataset.
         """
-        return pd.read_sql(self._denormalize(include_tables=include_tables), self.database)
+        return pd.read_sql(self._denormalize(include_tables=include_tables), self.engine)
 
-    def denormalize_as_dict(self, include_tables: list[str] | None = None) -> Generator[dict[str, Any], None, None]:
+    def denormalize_as_dict(self, include_tables: list[str]) -> Generator[RowMapping, None, None]:
         """
-        Denormalize the dataset and return the result as a set of dictionarys.
+        Denormalize the dataset and return the result as a set of dictionary's.
 
         This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
-        the dataset values into a single wide table. The result is returned as a generateor that returns a dictionary
-        for each row in the denormlized wide table.
+        the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
+        for each row in the denormalized wide table.
 
         The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
         view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
         additional tables are required to complete the denormalization process. If include_tables is not specified,
         all of the tables in the schema will be included.
 
-        The resulting wide table will include a column for every table needed to complete the denormalization process.
+        The resulting wide table will include a only those column for the tables listed in include_columns.
 
         Args:
             include_tables: List of table names to include in the denormalized dataset. If None, than the entire schema
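
denormalize_as_dataframe above hands a SQLAlchemy selectable plus the Engine straight to pandas. A standalone sketch of that pattern with a toy Core table (table and column names are illustrative):

    import pandas as pd
    from sqlalchemy import Column, MetaData, String, Table, create_engine, select

    engine = create_engine("sqlite://")
    metadata = MetaData()
    dataset = Table("Dataset", metadata, Column("RID", String, primary_key=True))
    metadata.create_all(engine)
    with engine.begin() as conn:
        conn.execute(dataset.insert(), [{"RID": "1-ABCD"}])

    # A Select plus an Engine is all pandas needs to build the wide-table dataframe.
    df = pd.read_sql(select(dataset), engine)
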
@@ -442,11 +433,13 @@
         Returns:
             A generator that returns a dictionary representation of each row in the denormalized dataset.
         """
-        with self.database as dbase:
-            cursor = dbase.execute(self._denormalize(include_tables=include_tables))
-            columns = [desc[0] for desc in cursor.description]
-            for row in cursor:
-                yield dict(zip(columns, row))
+        with Session(self.engine) as session:
+            cursor = session.execute(
+                self._denormalize(include_tables=include_tables)
+            )
+            yield from cursor.mappings()
+            for row in cursor.mappings():
+                yield row
 
 
 # Add annotations after definition to deal with forward reference issues in pydantic