deriva-ml 1.16.0__py3-none-any.whl → 1.17.1__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two versions.
- deriva_ml/.DS_Store +0 -0
- deriva_ml/__init__.py +0 -10
- deriva_ml/core/base.py +18 -6
- deriva_ml/dataset/__init__.py +2 -7
- deriva_ml/dataset/aux_classes.py +2 -10
- deriva_ml/dataset/dataset.py +5 -4
- deriva_ml/dataset/dataset_bag.py +144 -151
- deriva_ml/dataset/upload.py +6 -4
- deriva_ml/demo_catalog.py +16 -2
- deriva_ml/execution/__init__.py +2 -1
- deriva_ml/execution/execution.py +5 -3
- deriva_ml/execution/execution_configuration.py +28 -9
- deriva_ml/execution/workflow.py +8 -0
- deriva_ml/model/catalog.py +55 -50
- deriva_ml/model/database.py +455 -81
- deriva_ml/test.py +94 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.1.dist-info}/METADATA +9 -7
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.1.dist-info}/RECORD +22 -21
- deriva_ml/model/sql_mapper.py +0 -44
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.1.dist-info}/WHEEL +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.1.dist-info}/entry_points.txt +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.1.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.1.dist-info}/top_level.txt +0 -0
deriva_ml/.DS_Store
ADDED
Binary file
deriva_ml/__init__.py
CHANGED
@@ -25,9 +25,6 @@ from deriva_ml.core.exceptions import (
     DerivaMLInvalidTerm,
     DerivaMLTableTypeError,
 )
-from deriva_ml.dataset.aux_classes import DatasetConfig, DatasetConfigList, DatasetSpec, DatasetVersion
-
-from .execution import Execution, ExecutionConfiguration, Workflow
 
 # Type-checking only - avoid circular import at runtime
 if TYPE_CHECKING:
@@ -51,13 +48,6 @@ def __getattr__(name):
 __all__ = [
     "DerivaML", # Lazy-loaded
     "DerivaMLConfig",
-    "DatasetConfig",
-    "DatasetConfigList",
-    "DatasetSpec",
-    "DatasetVersion",
-    "Execution",
-    "ExecutionConfiguration",
-    "Workflow",
     # Exceptions
     "DerivaMLException",
     "DerivaMLInvalidTerm",
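Note: the dataset and execution helpers are no longer re-exported from the package root. A minimal import sketch under that assumption; it presumes deriva_ml.execution still exports these names, which this diff does not show:

    # sketch: import from the subpackages instead of the package root
    from deriva_ml import DerivaML, DerivaMLConfig
    from deriva_ml.dataset import DatasetSpec, DatasetVersion
    from deriva_ml.execution import Execution, ExecutionConfiguration, Workflow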
deriva_ml/core/base.py
CHANGED
@@ -19,7 +19,7 @@ import logging
 from datetime import datetime
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterable, List, cast, TYPE_CHECKING, Any
+from typing import Dict, Iterable, List, cast, TYPE_CHECKING, Any, Self
 from urllib.parse import urlsplit
 
 
@@ -28,13 +28,14 @@ import requests
 from pydantic import ConfigDict, validate_call
 
 # Deriva imports
-from deriva.core import DEFAULT_SESSION_CONFIG, format_exception, get_credential, urlquote
+from deriva.core import DEFAULT_SESSION_CONFIG, format_exception, get_credential, urlquote
 
 import deriva.core.datapath as datapath
 from deriva.core.datapath import DataPathException, _SchemaWrapper as SchemaWrapper
 from deriva.core.deriva_server import DerivaServer
 from deriva.core.ermrest_catalog import ResolveRidResult
 from deriva.core.ermrest_model import Key, Table
+from deriva.core.utils.core_utils import DEFAULT_LOGGER_OVERRIDES
 from deriva.core.utils.globus_auth_utils import GlobusNativeLogin
 
 from deriva_ml.core.exceptions import DerivaMLInvalidTerm
@@ -103,6 +104,10 @@ class DerivaML(Dataset):
     >>> ml.add_term('vocabulary_table', 'new_term', description='Description of term')
     """
 
+    @classmethod
+    def instantiate(cls, config: DerivaMLConfig) -> Self:
+        return cls(**config.model_dump())
+
     def __init__(
         self,
         hostname: str,
@@ -149,7 +154,6 @@ class DerivaML(Dataset):
             credentials=self.credential,
             session_config=self._get_session_config(),
         )
-
         try:
             if check_auth and server.get_authn_session():
                 pass
@@ -158,7 +162,6 @@ class DerivaML(Dataset):
                 "You are not authorized to access this catalog. "
                 "Please check your credentials and make sure you have logged in."
             )
-
         self.catalog = server.connect_ermrest(catalog_id)
         self.model = DerivaModel(self.catalog.getCatalogModel(), domain_schema=domain_schema)
 
@@ -176,9 +179,13 @@ class DerivaML(Dataset):
         # Set up logging
         self._logger = logging.getLogger("deriva_ml")
         self._logger.setLevel(logging_level)
+        self._logging_level = logging_level
+        self._deriva_logging_level = deriva_logging_level
 
         # Configure deriva logging level
-
+        logger_config = DEFAULT_LOGGER_OVERRIDES
+        # allow for reconfiguration of module-specific logging levels
+        [logging.getLogger(name).setLevel(level) for name, level in logger_config.items()]
         logging.getLogger("bagit").setLevel(deriva_logging_level)
         logging.getLogger("bdbag").setLevel(deriva_logging_level)
 
@@ -1081,7 +1088,12 @@ class DerivaML(Dataset):
         return self._download_dataset_bag(
             dataset=dataset,
             execution_rid=execution_rid,
-            snapshot_catalog=DerivaML(
+            snapshot_catalog=DerivaML(
+                self.host_name,
+                self._version_snapshot(dataset),
+                logging_level=self._logging_level,
+                deriva_logging_level=self._deriva_logging_level,
+            ),
         )
 
     def _update_status(self, new_status: Status, status_detail: str, execution_rid: RID):
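Note: the new instantiate() classmethod builds a DerivaML client from a DerivaMLConfig. A hypothetical usage sketch; the config field name shown is an assumption based on the __init__ signature:

    # hypothetical usage of DerivaML.instantiate()
    from deriva_ml import DerivaML, DerivaMLConfig

    config = DerivaMLConfig(hostname="deriva.example.org")  # assumed field name and value
    ml = DerivaML.instantiate(config)  # same as DerivaML(**config.model_dump())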
deriva_ml/dataset/__init__.py
CHANGED
@@ -1,16 +1,11 @@
-from
-
-from deriva_ml.core.definitions import RID
-
-from .aux_classes import DatasetConfig, DatasetConfigList, DatasetSpec, DatasetVersion, VersionPart
+from .aux_classes import DatasetSpec, DatasetSpecConfig, DatasetVersion, VersionPart
 from .dataset import Dataset
 from .dataset_bag import DatasetBag
 
 __all__ = [
     "Dataset",
     "DatasetSpec",
-    "
-    "DatasetConfigList",
+    "DatasetSpecConfig",
     "DatasetBag",
     "DatasetVersion",
     "VersionPart",
deriva_ml/dataset/aux_classes.py
CHANGED
@@ -212,18 +212,10 @@ class DatasetSpec(BaseModel):
         return version.to_dict()
 
 
+# Interface for hydra-zen
 @hydrated_dataclass(DatasetSpec)
-class
+class DatasetSpecConfig:
     rid: str
     version: str
     materialize: bool = True
     description: str = ""
-
-class DatasetList(BaseModel):
-    datasets: list[DatasetSpec]
-    description: str = ""
-
-@hydrated_dataclass(DatasetList)
-class DatasetConfigList:
-    datasets: list[DatasetConfig]
-    description: str = ""
deriva_ml/dataset/dataset.py
CHANGED
@@ -31,6 +31,7 @@ from graphlib import TopologicalSorter
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator
+from urllib.parse import urlparse
 
 import deriva.core.utils.hash_utils as hash_utils
 import requests
@@ -1040,7 +1041,6 @@ class Dataset:
                 envars={"RID": dataset.rid},
             )
             minid_page_url = exporter.export()[0]  # Get the MINID launch page
-
         except (
             DerivaDownloadError,
             DerivaDownloadConfigurationError,
@@ -1096,7 +1096,8 @@ class Dataset:
 
         # Check or create MINID
         minid_url = version_record.minid
-
+        # If we either don't have a MINID, or we have a MINID, but we don't want to use it, generate a new one.
+        if (not minid_url) or (not self._use_minid):
             if not create:
                 raise DerivaMLException(f"Minid for dataset {rid} doesn't exist")
             if self._use_minid:
@@ -1106,7 +1107,6 @@ class Dataset:
         # Return based on MINID usage
         if self._use_minid:
             return self._fetch_minid_metadata(minid_url, dataset.version)
-
         return DatasetMinid(
             dataset_version=dataset.version,
             RID=f"{rid}@{version_record.snapshot}",
@@ -1139,7 +1139,8 @@ class Dataset:
         with TemporaryDirectory() as tmp_dir:
             if self._use_minid:
                 # Get bag from S3
-
+                bag_path = Path(tmp_dir) / Path(urlparse(minid.bag_url).path).name
+                archive_path = fetch_single_file(minid.bag_url, output_path=bag_path)
             else:
                 exporter = DerivaExport(host=self._model.catalog.deriva_server.server, output_dir=tmp_dir)
                 archive_path = exporter.retrieve_file(minid.bag_url)
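Note: the bag filename is now derived from the MINID bag URL with urlparse. A small standalone illustration of that expression; the URL is a made-up example:

    from pathlib import Path
    from urllib.parse import urlparse

    bag_url = "https://example-bucket.s3.amazonaws.com/bags/Dataset_1-ABCD.zip"  # made-up URL
    bag_name = Path(urlparse(bag_url).path).name  # -> "Dataset_1-ABCD.zip"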
deriva_ml/dataset/dataset_bag.py
CHANGED
@@ -4,8 +4,6 @@ The module implements the sqllite interface to a set of directories representing
 
 from __future__ import annotations
 
-import sqlite3
-
 # Standard library imports
 from collections import defaultdict
 from copy import copy
@@ -16,15 +14,18 @@ import deriva.core.datapath as datapath
 # Third-party imports
 import pandas as pd
 
+# Local imports
+from deriva.core.ermrest_model import Table
+
 # Deriva imports
-from deriva.core.ermrest_model import Column, Table
 from pydantic import ConfigDict, validate_call
+from sqlalchemy import CompoundSelect, Engine, RowMapping, Select, and_, inspect, select, union
+from sqlalchemy.orm import RelationshipProperty, Session
+from sqlalchemy.orm.util import AliasedClass
 
-# Local imports
 from deriva_ml.core.definitions import RID, VocabularyTerm
 from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm
 from deriva_ml.feature import Feature
-from deriva_ml.model.sql_mapper import SQLMapper
 
 if TYPE_CHECKING:
     from deriva_ml.model.database import DatabaseModel
@@ -64,7 +65,8 @@ class DatasetBag:
             dataset_rid: Optional RID for the dataset.
         """
         self.model = database_model
-        self.
+        self.engine = cast(Engine, self.model.engine)
+        self.metadata = self.model.metadata
 
         self.dataset_rid = dataset_rid or self.model.dataset_rid
         if not self.dataset_rid:
@@ -86,54 +88,48 @@ class DatasetBag:
         """
         return self.model.list_tables()
 
-
-
-
-
-
-
-
-
-
-            [f'"{table_name}"."{c[1]}"' for c in dbase.execute(f'PRAGMA table_info("{table_name}")').fetchall()]
-        )
+    @staticmethod
+    def _find_relationship_attr(source, target):
+        """
+        Return the relationship attribute (InstrumentedAttribute) on `source`
+        that points to `target`. Works with classes or AliasedClass.
+        Raises LookupError if not found.
+        """
+        src_mapper = inspect(source).mapper
+        tgt_mapper = inspect(target).mapper
 
-        #
-
-            [f'"{self.dataset_rid}"'] + [f'"{ds.dataset_rid}"' for ds in self.list_dataset_children(recurse=True)]
-        )
+        # collect relationships on the *class* mapper (not on alias)
+        candidates: list[RelationshipProperty] = [rel for rel in src_mapper.relationships if rel.mapper is tgt_mapper]
 
-
-
-        paths = [
-            (
-                [f'"{self.model.normalize_table_name(t.name)}"' for t in p],
-                [self.model._table_relationship(t1, t2) for t1, t2 in zip(p, p[1:])],
-            )
-            for p in self.model._schema_to_paths()
-            if p[-1].name == table
-        ]
+        if not candidates:
+            raise LookupError(f"No relationship from {src_mapper.class_.__name__} → {tgt_mapper.class_.__name__}")
 
-
-
+        # Prefer MANYTOONE when multiple paths exist (often best for joins)
+        candidates.sort(key=lambda r: r.direction.name != "MANYTOONE")
+        rel = candidates[0]
 
-
-
+        # Bind to the actual source (alias or class)
+        return getattr(source, rel.key) if isinstance(source, AliasedClass) else rel.class_attribute
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def _dataset_table_view(self, table: str) -> CompoundSelect[Any]:
+        """Return a SQL command that will return all of the elements in the specified table that are associated with
+        dataset_rid"""
+        table_class = self.model.get_orm_class_by_name(table)
+        dataset_table_class = self.model.get_orm_class_by_name(self._dataset_table.name)
+        dataset_rids = [self.dataset_rid] + [c.dataset_rid for c in self.list_dataset_children(recurse=True)]
+
+        paths = [[t.name for t in p] for p in self.model._schema_to_paths() if p[-1].name == table]
+        sql_cmds = []
+        for path in paths:
+            path_sql = select(table_class)
+            last_class = self.model.get_orm_class_by_name(path[0])
+            for t in path[1:]:
+                t_class = self.model.get_orm_class_by_name(t)
+                path_sql = path_sql.join(self._find_relationship_attr(last_class, t_class))
+                last_class = t_class
+            path_sql = path_sql.where(dataset_table_class.RID.in_(dataset_rids))
+            sql_cmds.append(path_sql)
+        return union(*sql_cmds)
 
     def get_table(self, table: str) -> Generator[tuple, None, None]:
         """Retrieve the contents of the specified table. If schema is not provided as part of the table name,
@@ -146,9 +142,10 @@ class DatasetBag:
             A generator that yields tuples of column values.
 
         """
-
-
-
+        with Session(self.engine) as session:
+            result = session.execute(self._dataset_table_view(table))
+            for row in result:
+                yield row
 
     def get_table_as_dataframe(self, table: str) -> pd.DataFrame:
         """Retrieve the contents of the specified table as a dataframe.
@@ -163,7 +160,7 @@ class DatasetBag:
         Returns:
             A dataframe containing the contents of the specified table.
         """
-        return pd.read_sql(self._dataset_table_view(table), self.
+        return pd.read_sql(self._dataset_table_view(table), self.engine)
 
     def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
         """Retrieve the contents of the specified table as a dictionary.
@@ -176,15 +173,12 @@ class DatasetBag:
             A generator producing dictionaries containing the contents of the specified table as name/value pairs.
         """
 
-
-
-
-
-        result = self.database.execute(self._dataset_table_view(table))
-        while row := result.fetchone():
-            yield mapper.transform_tuple(row)
+        with Session(self.engine) as session:
+            result = session.execute(self._dataset_table_view(table))
+            for row in result.mappings():
+                yield row
 
-    @validate_call
+    # @validate_call
     def list_dataset_members(self, recurse: bool = False) -> dict[str, list[dict[str, Any]]]:
         """Return a list of entities associated with a specific dataset.
 
@@ -198,39 +192,31 @@ class DatasetBag:
         # Look at each of the element types that might be in the _dataset_table and get the list of rid for them from
        # the appropriate association table.
         members = defaultdict(list)
-
-
-
-
-
-
-
-
-
-            if target_table.schema.name != self.model.domain_schema and not (
-                target_table == self._dataset_table or target_table.name == "File"
-            ):
+
+        dataset_class = self.model.get_orm_class_for_table(self._dataset_table)
+        for element_table in self.model.list_dataset_element_types():
+            element_class = self.model.get_orm_class_for_table(element_table)
+
+            assoc_class, dataset_rel, element_rel = self.model.get_orm_association_class(dataset_class, element_class)
+
+            element_table = inspect(element_class).mapped_table
+            if element_table.schema != self.model.domain_schema and element_table.name not in ["Dataset", "File"]:
                 # Look at domain tables and nested datasets.
                 continue
-            sql_target = self.model.normalize_table_name(target_table.name)
-            sql_member = self.model.normalize_table_name(member_table.name)
-
             # Get the names of the columns that we are going to need for linking
-
-            with self.database as db:
-                col_names = [c[1] for c in db.execute(f'PRAGMA table_info("{sql_target}")').fetchall()]
-                select_cols = ",".join([f'"{sql_target}".{c}' for c in col_names])
+            with Session(self.engine) as session:
                 sql_cmd = (
-
-
-
+                    select(element_class)
+                    .join(element_rel)
+                    .where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
                 )
-
-
-
-
+                # Get back the list of ORM entities and convert them to dictionaries.
+                element_entities = session.scalars(sql_cmd).all()
+                element_rows = [{c.key: getattr(obj, c.key) for c in obj.__table__.columns} for obj in element_entities]
+                members[element_table.name].extend(element_rows)
+                if recurse and (element_table.name == self._dataset_table.name):
                     # Get the members for all the nested datasets and add to the member list.
-                    nested_datasets = [d["RID"] for d in
+                    nested_datasets = [d["RID"] for d in element_rows]
                     for ds in nested_datasets:
                         nested_dataset = self.model.get_dataset(ds)
                         for k, v in nested_dataset.list_dataset_members(recurse=recurse).items():
@@ -259,12 +245,10 @@ class DatasetBag:
             Feature values.
         """
         feature = self.model.lookup_feature(table, feature_name)
-
-
-
-
-            sql_cmd = f'SELECT * FROM "{feature_table}"'
-            return cast(datapath._ResultSet, [dict(zip(col_names, r)) for r in db.execute(sql_cmd).fetchall()])
+        feature_class = self.model.get_orm_class_for_table(feature.feature_table)
+        with Session(self.engine) as session:
+            sql_cmd = select(feature_class)
+            return cast(datapath._ResultSet, [row for row in session.execute(sql_cmd).mappings()])
 
     def list_dataset_element_types(self) -> list[Table]:
         """
@@ -291,18 +275,18 @@ class DatasetBag:
         Returns:
             List of child dataset bags.
         """
-        ds_table = self.model.
-        nds_table = self.model.
-        dv_table = self.model.
-
+        ds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset")
+        nds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Dataset")
+        dv_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Version")
+
+        with Session(self.engine) as session:
            sql_cmd = (
-
-
-
-
-                f'where "{nds_table}".Dataset == "{self.dataset_rid}"'
+                select(nds_table.Nested_Dataset, dv_table.Version)
+                .join_from(ds_table, nds_table, onclause=ds_table.RID == nds_table.Nested_Dataset)
+                .join_from(ds_table, dv_table, onclause=ds_table.Version == dv_table.RID)
+                .where(nds_table.Dataset == self.dataset_rid)
             )
-            nested = [DatasetBag(self.model, r[0]) for r in
+            nested = [DatasetBag(self.model, r[0]) for r in session.execute(sql_cmd).all()]
 
         result = copy(nested)
         if recurse:
@@ -336,20 +320,19 @@ class DatasetBag:
             >>> term = ml.lookup_term("tissue_types", "epithelium")
         """
         # Get and validate vocabulary table reference
-        vocab_table = self.model.normalize_table_name(table)
         if not self.model.is_vocabulary(table):
             raise DerivaMLException(f"The table {table} is not a controlled vocabulary")
 
         # Search for term by name or synonym
-        for term in self.get_table_as_dict(
+        for term in self.get_table_as_dict(table):
             if term_name == term["Name"] or (term["Synonyms"] and term_name in term["Synonyms"]):
                 term["Synonyms"] = list(term["Synonyms"])
                 return VocabularyTerm.model_validate(term)
 
         # Term not found
-        raise DerivaMLInvalidTerm(
+        raise DerivaMLInvalidTerm(table, term_name)
 
-    def _denormalize(self, include_tables: list[str]
+    def _denormalize(self, include_tables: list[str]) -> Select:
         """
         Generates an SQL statement for denormalizing the dataset based on the tables to include. Processes cycles in
         graph relationships, ensures proper join order, and generates selected columns for denormalization.
@@ -361,48 +344,57 @@ class DatasetBag:
         Returns:
             str: SQL query string that represents the process of denormalization.
         """
-
-        def column_name(col: Column) -> str:
-            return f'"{self.model.normalize_table_name(col.table.name)}"."{col.name}"'
-
         # Skip over tables that we don't want to include in the denormalized dataset.
         # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
         # table.
 
-
+        def find_relationship(table, join_condition):
+            side1 = (join_condition[0].table.name, join_condition[0].name)
+            side2 = (join_condition[1].table.name, join_condition[1].name)
+
+            for relationship in inspect(table).relationships:
+                local_columns = list(relationship.local_columns)[0].table.name, list(relationship.local_columns)[0].name
+                remote_side = list(relationship.remote_side)[0].table.name, list(relationship.remote_side)[0].name
+                if local_columns == side1 and remote_side == side2 or local_columns == side2 and remote_side == side1:
+                    return relationship
+            return None
+
+        join_tables, denormalized_columns = (
             self.model._prepare_wide_table(self, self.dataset_rid, include_tables)
         )
 
-
-
-
-
+        denormalized_columns = [
+            self.model.get_orm_class_by_name(table_name)
+            .__table__.columns[column_name]
+            .label(f"{table_name}.{column_name}")
             for table_name, column_name in denormalized_columns
         ]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        sql_statements = []
+        for key, (path, join_conditions) in join_tables.items():
+            sql_statement = select(*denormalized_columns).select_from(
+                self.model.get_orm_class_for_table(self._dataset_table)
+            )
+            for table_name in path[1:]:  # Skip over dataset table
+                table_class = self.model.get_orm_class_by_name(table_name)
+                on_clause = [
+                    getattr(table_class, r.key)
+                    for on_condition in join_conditions[table_name]
+                    if (r := find_relationship(table_class, on_condition))
+                ]
+                sql_statement = sql_statement.join(table_class, onclause=and_(*on_clause))
+            dataset_rid_list = [self.dataset_rid] + self.list_dataset_children(recurse=True)
+            dataset_class = self.model.get_orm_class_by_name(self._dataset_table.name)
+            sql_statement = sql_statement.where(dataset_class.RID.in_(dataset_rid_list))
+            sql_statements.append(sql_statement)
+        return union(*sql_statements)
+
     def denormalize_as_dataframe(self, include_tables: list[str]) -> pd.DataFrame:
         """
         Denormalize the dataset and return the result as a dataframe.
-
-        the dataset values into a single wide table. The result is returned as a
+        This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
+        the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
+        for each row in the denormalized wide table.
 
         The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
         view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
@@ -412,28 +404,27 @@ class DatasetBag:
         The resulting wide table will include a column for every table needed to complete the denormalization process.
 
         Args:
-            include_tables: List of table names to include in the denormalized dataset.
-                is used.
+            include_tables: List of table names to include in the denormalized dataset.
 
         Returns:
             Dataframe containing the denormalized dataset.
         """
-        return pd.read_sql(self._denormalize(include_tables=include_tables), self.
+        return pd.read_sql(self._denormalize(include_tables=include_tables), self.engine)
 
-    def denormalize_as_dict(self, include_tables: list[str]
+    def denormalize_as_dict(self, include_tables: list[str]) -> Generator[RowMapping, None, None]:
         """
-        Denormalize the dataset and return the result as a set of
+        Denormalize the dataset and return the result as a set of dictionary's.
 
         This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
-        the dataset values into a single wide table. The result is returned as a
-        for each row in the
+        the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
+        for each row in the denormalized wide table.
 
         The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
         view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
         additional tables are required to complete the denormalization process. If include_tables is not specified,
         all of the tables in the schema will be included.
 
-        The resulting wide table will include a column for
+        The resulting wide table will include a only those column for the tables listed in include_columns.
 
         Args:
             include_tables: List of table names to include in the denormalized dataset. If None, than the entire schema
@@ -442,11 +433,13 @@ class DatasetBag:
         Returns:
             A generator that returns a dictionary representation of each row in the denormalized dataset.
         """
-        with self.
-            cursor =
-
-
-
+        with Session(self.engine) as session:
+            cursor = session.execute(
+                self._denormalize(include_tables=include_tables)
+            )
+            yield from cursor.mappings()
+            for row in cursor.mappings():
+                yield row
 
 
|