deriva-ml 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deriva_ml/.DS_Store +0 -0
- deriva_ml/__init__.py +0 -10
- deriva_ml/core/base.py +18 -6
- deriva_ml/dataset/__init__.py +2 -7
- deriva_ml/dataset/aux_classes.py +21 -11
- deriva_ml/dataset/dataset.py +5 -4
- deriva_ml/dataset/dataset_bag.py +144 -151
- deriva_ml/dataset/upload.py +6 -4
- deriva_ml/demo_catalog.py +16 -2
- deriva_ml/execution/__init__.py +2 -1
- deriva_ml/execution/execution.py +4 -2
- deriva_ml/execution/execution_configuration.py +28 -9
- deriva_ml/execution/workflow.py +8 -0
- deriva_ml/model/catalog.py +55 -50
- deriva_ml/model/database.py +455 -81
- deriva_ml/test.py +94 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/METADATA +9 -7
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/RECORD +22 -21
- deriva_ml/model/sql_mapper.py +0 -44
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/WHEEL +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/entry_points.txt +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/top_level.txt +0 -0
deriva_ml/dataset/dataset_bag.py
CHANGED
|
@@ -4,8 +4,6 @@ The module implements the sqllite interface to a set of directories representing
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
-
import sqlite3
|
|
8
|
-
|
|
9
7
|
# Standard library imports
|
|
10
8
|
from collections import defaultdict
|
|
11
9
|
from copy import copy
|
|
@@ -16,15 +14,18 @@ import deriva.core.datapath as datapath
|
|
|
16
14
|
# Third-party imports
|
|
17
15
|
import pandas as pd
|
|
18
16
|
|
|
17
|
+
# Local imports
|
|
18
|
+
from deriva.core.ermrest_model import Table
|
|
19
|
+
|
|
19
20
|
# Deriva imports
|
|
20
|
-
from deriva.core.ermrest_model import Column, Table
|
|
21
21
|
from pydantic import ConfigDict, validate_call
|
|
22
|
+
from sqlalchemy import CompoundSelect, Engine, RowMapping, Select, and_, inspect, select, union
|
|
23
|
+
from sqlalchemy.orm import RelationshipProperty, Session
|
|
24
|
+
from sqlalchemy.orm.util import AliasedClass
|
|
22
25
|
|
|
23
|
-
# Local imports
|
|
24
26
|
from deriva_ml.core.definitions import RID, VocabularyTerm
|
|
25
27
|
from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm
|
|
26
28
|
from deriva_ml.feature import Feature
|
|
27
|
-
from deriva_ml.model.sql_mapper import SQLMapper
|
|
28
29
|
|
|
29
30
|
if TYPE_CHECKING:
|
|
30
31
|
from deriva_ml.model.database import DatabaseModel
|
|
@@ -64,7 +65,8 @@ class DatasetBag:
|
|
|
64
65
|
dataset_rid: Optional RID for the dataset.
|
|
65
66
|
"""
|
|
66
67
|
self.model = database_model
|
|
67
|
-
self.
|
|
68
|
+
self.engine = cast(Engine, self.model.engine)
|
|
69
|
+
self.metadata = self.model.metadata
|
|
68
70
|
|
|
69
71
|
self.dataset_rid = dataset_rid or self.model.dataset_rid
|
|
70
72
|
if not self.dataset_rid:
|
|
@@ -86,54 +88,48 @@ class DatasetBag:
|
|
|
86
88
|
"""
|
|
87
89
|
return self.model.list_tables()
|
|
88
90
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
[f'"{table_name}"."{c[1]}"' for c in dbase.execute(f'PRAGMA table_info("{table_name}")').fetchall()]
|
|
99
|
-
)
|
|
91
|
+
@staticmethod
|
|
92
|
+
def _find_relationship_attr(source, target):
|
|
93
|
+
"""
|
|
94
|
+
Return the relationship attribute (InstrumentedAttribute) on `source`
|
|
95
|
+
that points to `target`. Works with classes or AliasedClass.
|
|
96
|
+
Raises LookupError if not found.
|
|
97
|
+
"""
|
|
98
|
+
src_mapper = inspect(source).mapper
|
|
99
|
+
tgt_mapper = inspect(target).mapper
|
|
100
100
|
|
|
101
|
-
#
|
|
102
|
-
|
|
103
|
-
[f'"{self.dataset_rid}"'] + [f'"{ds.dataset_rid}"' for ds in self.list_dataset_children(recurse=True)]
|
|
104
|
-
)
|
|
101
|
+
# collect relationships on the *class* mapper (not on alias)
|
|
102
|
+
candidates: list[RelationshipProperty] = [rel for rel in src_mapper.relationships if rel.mapper is tgt_mapper]
|
|
105
103
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
paths = [
|
|
109
|
-
(
|
|
110
|
-
[f'"{self.model.normalize_table_name(t.name)}"' for t in p],
|
|
111
|
-
[self.model._table_relationship(t1, t2) for t1, t2 in zip(p, p[1:])],
|
|
112
|
-
)
|
|
113
|
-
for p in self.model._schema_to_paths()
|
|
114
|
-
if p[-1].name == table
|
|
115
|
-
]
|
|
104
|
+
if not candidates:
|
|
105
|
+
raise LookupError(f"No relationship from {src_mapper.class_.__name__} → {tgt_mapper.class_.__name__}")
|
|
116
106
|
|
|
117
|
-
|
|
118
|
-
|
|
107
|
+
# Prefer MANYTOONE when multiple paths exist (often best for joins)
|
|
108
|
+
candidates.sort(key=lambda r: r.direction.name != "MANYTOONE")
|
|
109
|
+
rel = candidates[0]
|
|
119
110
|
|
|
120
|
-
|
|
121
|
-
|
|
111
|
+
# Bind to the actual source (alias or class)
|
|
112
|
+
return getattr(source, rel.key) if isinstance(source, AliasedClass) else rel.class_attribute
|
|
122
113
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
114
|
+
def _dataset_table_view(self, table: str) -> CompoundSelect[Any]:
|
|
115
|
+
"""Return a SQL command that will return all of the elements in the specified table that are associated with
|
|
116
|
+
dataset_rid"""
|
|
117
|
+
table_class = self.model.get_orm_class_by_name(table)
|
|
118
|
+
dataset_table_class = self.model.get_orm_class_by_name(self._dataset_table.name)
|
|
119
|
+
dataset_rids = [self.dataset_rid] + [c.dataset_rid for c in self.list_dataset_children(recurse=True)]
|
|
120
|
+
|
|
121
|
+
paths = [[t.name for t in p] for p in self.model._schema_to_paths() if p[-1].name == table]
|
|
122
|
+
sql_cmds = []
|
|
123
|
+
for path in paths:
|
|
124
|
+
path_sql = select(table_class)
|
|
125
|
+
last_class = self.model.get_orm_class_by_name(path[0])
|
|
126
|
+
for t in path[1:]:
|
|
127
|
+
t_class = self.model.get_orm_class_by_name(t)
|
|
128
|
+
path_sql = path_sql.join(self._find_relationship_attr(last_class, t_class))
|
|
129
|
+
last_class = t_class
|
|
130
|
+
path_sql = path_sql.where(dataset_table_class.RID.in_(dataset_rids))
|
|
131
|
+
sql_cmds.append(path_sql)
|
|
132
|
+
return union(*sql_cmds)
|
|
137
133
|
|
|
138
134
|
def get_table(self, table: str) -> Generator[tuple, None, None]:
|
|
139
135
|
"""Retrieve the contents of the specified table. If schema is not provided as part of the table name,
|
|
@@ -146,9 +142,10 @@ class DatasetBag:
|
|
|
146
142
|
A generator that yields tuples of column values.
|
|
147
143
|
|
|
148
144
|
"""
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
145
|
+
with Session(self.engine) as session:
|
|
146
|
+
result = session.execute(self._dataset_table_view(table))
|
|
147
|
+
for row in result:
|
|
148
|
+
yield row
|
|
152
149
|
|
|
153
150
|
def get_table_as_dataframe(self, table: str) -> pd.DataFrame:
|
|
154
151
|
"""Retrieve the contents of the specified table as a dataframe.
|
|
@@ -163,7 +160,7 @@ class DatasetBag:
|
|
|
163
160
|
Returns:
|
|
164
161
|
A dataframe containing the contents of the specified table.
|
|
165
162
|
"""
|
|
166
|
-
return pd.read_sql(self._dataset_table_view(table), self.
|
|
163
|
+
return pd.read_sql(self._dataset_table_view(table), self.engine)
|
|
167
164
|
|
|
168
165
|
def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
|
|
169
166
|
"""Retrieve the contents of the specified table as a dictionary.
|
|
@@ -176,15 +173,12 @@ class DatasetBag:
|
|
|
176
173
|
A generator producing dictionaries containing the contents of the specified table as name/value pairs.
|
|
177
174
|
"""
|
|
178
175
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
result = self.database.execute(self._dataset_table_view(table))
|
|
184
|
-
while row := result.fetchone():
|
|
185
|
-
yield mapper.transform_tuple(row)
|
|
176
|
+
with Session(self.engine) as session:
|
|
177
|
+
result = session.execute(self._dataset_table_view(table))
|
|
178
|
+
for row in result.mappings():
|
|
179
|
+
yield row
|
|
186
180
|
|
|
187
|
-
@validate_call
|
|
181
|
+
# @validate_call
|
|
188
182
|
def list_dataset_members(self, recurse: bool = False) -> dict[str, list[dict[str, Any]]]:
|
|
189
183
|
"""Return a list of entities associated with a specific dataset.
|
|
190
184
|
|
|
@@ -198,39 +192,31 @@ class DatasetBag:
|
|
|
198
192
|
# Look at each of the element types that might be in the _dataset_table and get the list of rid for them from
|
|
199
193
|
# the appropriate association table.
|
|
200
194
|
members = defaultdict(list)
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
if target_table.schema.name != self.model.domain_schema and not (
|
|
211
|
-
target_table == self._dataset_table or target_table.name == "File"
|
|
212
|
-
):
|
|
195
|
+
|
|
196
|
+
dataset_class = self.model.get_orm_class_for_table(self._dataset_table)
|
|
197
|
+
for element_table in self.model.list_dataset_element_types():
|
|
198
|
+
element_class = self.model.get_orm_class_for_table(element_table)
|
|
199
|
+
|
|
200
|
+
assoc_class, dataset_rel, element_rel = self.model.get_orm_association_class(dataset_class, element_class)
|
|
201
|
+
|
|
202
|
+
element_table = inspect(element_class).mapped_table
|
|
203
|
+
if element_table.schema != self.model.domain_schema and element_table.name not in ["Dataset", "File"]:
|
|
213
204
|
# Look at domain tables and nested datasets.
|
|
214
205
|
continue
|
|
215
|
-
sql_target = self.model.normalize_table_name(target_table.name)
|
|
216
|
-
sql_member = self.model.normalize_table_name(member_table.name)
|
|
217
|
-
|
|
218
206
|
# Get the names of the columns that we are going to need for linking
|
|
219
|
-
|
|
220
|
-
with self.database as db:
|
|
221
|
-
col_names = [c[1] for c in db.execute(f'PRAGMA table_info("{sql_target}")').fetchall()]
|
|
222
|
-
select_cols = ",".join([f'"{sql_target}".{c}' for c in col_names])
|
|
207
|
+
with Session(self.engine) as session:
|
|
223
208
|
sql_cmd = (
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
209
|
+
select(element_class)
|
|
210
|
+
.join(element_rel)
|
|
211
|
+
.where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
|
|
227
212
|
)
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
213
|
+
# Get back the list of ORM entities and convert them to dictionaries.
|
|
214
|
+
element_entities = session.scalars(sql_cmd).all()
|
|
215
|
+
element_rows = [{c.key: getattr(obj, c.key) for c in obj.__table__.columns} for obj in element_entities]
|
|
216
|
+
members[element_table.name].extend(element_rows)
|
|
217
|
+
if recurse and (element_table.name == self._dataset_table.name):
|
|
232
218
|
# Get the members for all the nested datasets and add to the member list.
|
|
233
|
-
nested_datasets = [d["RID"] for d in
|
|
219
|
+
nested_datasets = [d["RID"] for d in element_rows]
|
|
234
220
|
for ds in nested_datasets:
|
|
235
221
|
nested_dataset = self.model.get_dataset(ds)
|
|
236
222
|
for k, v in nested_dataset.list_dataset_members(recurse=recurse).items():
|
|
@@ -259,12 +245,10 @@ class DatasetBag:
|
|
|
259
245
|
Feature values.
|
|
260
246
|
"""
|
|
261
247
|
feature = self.model.lookup_feature(table, feature_name)
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
sql_cmd = f'SELECT * FROM "{feature_table}"'
|
|
267
|
-
return cast(datapath._ResultSet, [dict(zip(col_names, r)) for r in db.execute(sql_cmd).fetchall()])
|
|
248
|
+
feature_class = self.model.get_orm_class_for_table(feature.feature_table)
|
|
249
|
+
with Session(self.engine) as session:
|
|
250
|
+
sql_cmd = select(feature_class)
|
|
251
|
+
return cast(datapath._ResultSet, [row for row in session.execute(sql_cmd).mappings()])
|
|
268
252
|
|
|
269
253
|
def list_dataset_element_types(self) -> list[Table]:
|
|
270
254
|
"""
|
|
@@ -291,18 +275,18 @@ class DatasetBag:
|
|
|
291
275
|
Returns:
|
|
292
276
|
List of child dataset bags.
|
|
293
277
|
"""
|
|
294
|
-
ds_table = self.model.
|
|
295
|
-
nds_table = self.model.
|
|
296
|
-
dv_table = self.model.
|
|
297
|
-
|
|
278
|
+
ds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset")
|
|
279
|
+
nds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Dataset")
|
|
280
|
+
dv_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Version")
|
|
281
|
+
|
|
282
|
+
with Session(self.engine) as session:
|
|
298
283
|
sql_cmd = (
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
f'where "{nds_table}".Dataset == "{self.dataset_rid}"'
|
|
284
|
+
select(nds_table.Nested_Dataset, dv_table.Version)
|
|
285
|
+
.join_from(ds_table, nds_table, onclause=ds_table.RID == nds_table.Nested_Dataset)
|
|
286
|
+
.join_from(ds_table, dv_table, onclause=ds_table.Version == dv_table.RID)
|
|
287
|
+
.where(nds_table.Dataset == self.dataset_rid)
|
|
304
288
|
)
|
|
305
|
-
nested = [DatasetBag(self.model, r[0]) for r in
|
|
289
|
+
nested = [DatasetBag(self.model, r[0]) for r in session.execute(sql_cmd).all()]
|
|
306
290
|
|
|
307
291
|
result = copy(nested)
|
|
308
292
|
if recurse:
|
|
@@ -336,20 +320,19 @@ class DatasetBag:
|
|
|
336
320
|
>>> term = ml.lookup_term("tissue_types", "epithelium")
|
|
337
321
|
"""
|
|
338
322
|
# Get and validate vocabulary table reference
|
|
339
|
-
vocab_table = self.model.normalize_table_name(table)
|
|
340
323
|
if not self.model.is_vocabulary(table):
|
|
341
324
|
raise DerivaMLException(f"The table {table} is not a controlled vocabulary")
|
|
342
325
|
|
|
343
326
|
# Search for term by name or synonym
|
|
344
|
-
for term in self.get_table_as_dict(
|
|
327
|
+
for term in self.get_table_as_dict(table):
|
|
345
328
|
if term_name == term["Name"] or (term["Synonyms"] and term_name in term["Synonyms"]):
|
|
346
329
|
term["Synonyms"] = list(term["Synonyms"])
|
|
347
330
|
return VocabularyTerm.model_validate(term)
|
|
348
331
|
|
|
349
332
|
# Term not found
|
|
350
|
-
raise DerivaMLInvalidTerm(
|
|
333
|
+
raise DerivaMLInvalidTerm(table, term_name)
|
|
351
334
|
|
|
352
|
-
def _denormalize(self, include_tables: list[str]
|
|
335
|
+
def _denormalize(self, include_tables: list[str]) -> Select:
|
|
353
336
|
"""
|
|
354
337
|
Generates an SQL statement for denormalizing the dataset based on the tables to include. Processes cycles in
|
|
355
338
|
graph relationships, ensures proper join order, and generates selected columns for denormalization.
|
|
@@ -361,48 +344,57 @@ class DatasetBag:
|
|
|
361
344
|
Returns:
|
|
362
345
|
A SQLAlchemy selectable (a UNION of per-path SELECT statements) that represents the process of denormalization.
|
|
363
346
|
"""
|
|
364
|
-
|
|
365
|
-
def column_name(col: Column) -> str:
|
|
366
|
-
return f'"{self.model.normalize_table_name(col.table.name)}"."{col.name}"'
|
|
367
|
-
|
|
368
347
|
# Skip over tables that we don't want to include in the denormalized dataset.
|
|
369
348
|
# Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
|
|
370
349
|
# table.
|
|
371
350
|
|
|
372
|
-
|
|
351
|
+
def find_relationship(table, join_condition):
|
|
352
|
+
side1 = (join_condition[0].table.name, join_condition[0].name)
|
|
353
|
+
side2 = (join_condition[1].table.name, join_condition[1].name)
|
|
354
|
+
|
|
355
|
+
for relationship in inspect(table).relationships:
|
|
356
|
+
local_columns = list(relationship.local_columns)[0].table.name, list(relationship.local_columns)[0].name
|
|
357
|
+
remote_side = list(relationship.remote_side)[0].table.name, list(relationship.remote_side)[0].name
|
|
358
|
+
if local_columns == side1 and remote_side == side2 or local_columns == side2 and remote_side == side1:
|
|
359
|
+
return relationship
|
|
360
|
+
return None
|
|
361
|
+
|
|
362
|
+
join_tables, denormalized_columns = (
|
|
373
363
|
self.model._prepare_wide_table(self, self.dataset_rid, include_tables)
|
|
374
364
|
)
|
|
375
365
|
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
366
|
+
denormalized_columns = [
|
|
367
|
+
self.model.get_orm_class_by_name(table_name)
|
|
368
|
+
.__table__.columns[column_name]
|
|
369
|
+
.label(f"{table_name}.{column_name}")
|
|
380
370
|
for table_name, column_name in denormalized_columns
|
|
381
371
|
]
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
372
|
+
sql_statements = []
|
|
373
|
+
for key, (path, join_conditions) in join_tables.items():
|
|
374
|
+
sql_statement = select(*denormalized_columns).select_from(
|
|
375
|
+
self.model.get_orm_class_for_table(self._dataset_table)
|
|
376
|
+
)
|
|
377
|
+
for table_name in path[1:]: # Skip over dataset table
|
|
378
|
+
table_class = self.model.get_orm_class_by_name(table_name)
|
|
379
|
+
on_clause = [
|
|
380
|
+
getattr(table_class, r.key)
|
|
381
|
+
for on_condition in join_conditions[table_name]
|
|
382
|
+
if (r := find_relationship(table_class, on_condition))
|
|
383
|
+
]
|
|
384
|
+
sql_statement = sql_statement.join(table_class, onclause=and_(*on_clause))
|
|
385
|
+
dataset_rid_list = [self.dataset_rid] + self.list_dataset_children(recurse=True)
|
|
386
|
+
dataset_class = self.model.get_orm_class_by_name(self._dataset_table.name)
|
|
387
|
+
sql_statement = sql_statement.where(dataset_class.RID.in_(dataset_rid_list))
|
|
388
|
+
sql_statements.append(sql_statement)
|
|
389
|
+
return union(*sql_statements)
|
|
390
|
+
|
|
391
|
+
def denormalize_as_dataframe(self, include_tables: list[str]) -> pd.DataFrame:
|
|
401
392
|
"""
|
|
402
393
|
Denormalize the dataset and return the result as a dataframe.
|
|
403
394
|
|
|
404
|
-
|
|
405
|
-
the dataset values into a single wide table. The result is returned as a
|
|
395
|
+
This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
|
|
396
|
+
the dataset values into a single wide table. The result is returned as a dataframe that contains one row
|
|
397
|
+
for each row in the denormalized wide table.
|
|
406
398
|
|
|
407
399
|
The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
|
|
408
400
|
view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
|
|
@@ -412,28 +404,27 @@ class DatasetBag:
|
|
|
412
404
|
The resulting wide table will include a column for every table needed to complete the denormalization process.
|
|
413
405
|
|
|
414
406
|
Args:
|
|
415
|
-
include_tables: List of table names to include in the denormalized dataset.
|
|
416
|
-
is used.
|
|
407
|
+
include_tables: List of table names to include in the denormalized dataset.
|
|
417
408
|
|
|
418
409
|
Returns:
|
|
419
410
|
Dataframe containing the denormalized dataset.
|
|
420
411
|
"""
|
|
421
|
-
return pd.read_sql(self._denormalize(include_tables=include_tables), self.
|
|
412
|
+
return pd.read_sql(self._denormalize(include_tables=include_tables), self.engine)
|
|
422
413
|
|
|
423
|
-
def denormalize_as_dict(self, include_tables: list[str]
|
|
414
|
+
def denormalize_as_dict(self, include_tables: list[str]) -> Generator[RowMapping, None, None]:
|
|
424
415
|
"""
|
|
425
|
-
Denormalize the dataset and return the result as a set of
|
|
416
|
+
Denormalize the dataset and return the result as a set of dictionaries.
|
|
426
417
|
|
|
427
418
|
This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
|
|
428
|
-
the dataset values into a single wide table. The result is returned as a
|
|
429
|
-
for each row in the
|
|
419
|
+
the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
|
|
420
|
+
for each row in the denormalized wide table.
|
|
430
421
|
|
|
431
422
|
The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
|
|
432
423
|
view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
|
|
433
424
|
additional tables are required to complete the denormalization process. If include_tables is not specified,
|
|
434
425
|
all of the tables in the schema will be included.
|
|
435
426
|
|
|
436
|
-
The resulting wide table will include a column for
|
|
427
|
+
The resulting wide table will include only the columns for the tables listed in include_tables.
|
|
437
428
|
|
|
438
429
|
Args:
|
|
439
430
|
include_tables: List of table names to include in the denormalized dataset. If None, then the entire schema
|
|
@@ -442,11 +433,13 @@ class DatasetBag:
|
|
|
442
433
|
Returns:
|
|
443
434
|
A generator that returns a dictionary representation of each row in the denormalized dataset.
|
|
444
435
|
"""
|
|
445
|
-
with self.
|
|
446
|
-
cursor =
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
436
|
+
with Session(self.engine) as session:
|
|
437
|
+
cursor = session.execute(
|
|
438
|
+
self._denormalize(include_tables=include_tables)
|
|
439
|
+
)
|
|
440
|
+
yield from cursor.mappings()
|
|
441
|
+
for row in cursor.mappings():
|
|
442
|
+
yield row
|
|
450
443
|
|
|
451
444
|
|
|
452
445
|
# Add annotations after definition to deal with forward reference issues in pydantic
|
deriva_ml/dataset/upload.py
CHANGED
|
@@ -77,11 +77,11 @@ feature_value_regex = feature_table_dir_regex + f"{SEP}(?P=feature_name)[.](?P<e
|
|
|
77
77
|
feature_asset_dir_regex = feature_table_dir_regex + f"{SEP}asset{SEP}(?P<asset_table>[-\\w]+)"
|
|
78
78
|
feature_asset_regex = feature_asset_dir_regex + f"{SEP}(?P<file>[A-Za-z0-9_-]+)[.](?P<ext>[a-z0-9]*)$"
|
|
79
79
|
|
|
80
|
-
asset_path_regex = exec_dir_regex +
|
|
80
|
+
asset_path_regex = exec_dir_regex + rf"{SEP}asset{SEP}(?P<schema>[-\w]+){SEP}(?P<asset_table>[-\w]*)"
|
|
81
81
|
|
|
82
82
|
asset_file_regex = r"(?P<file>[-\w]+)[.](?P<ext>[a-z0-9]*)$"
|
|
83
83
|
|
|
84
|
-
table_regex = exec_dir_regex +
|
|
84
|
+
table_regex = exec_dir_regex + rf"{SEP}table{SEP}(?P<schema>[-\w]+){SEP}(?P<table>[-\w]+){SEP}(?P=table)[.](csv|json)$"
|
|
85
85
|
|
|
86
86
|
|
|
87
87
|
def is_feature_dir(path: Path) -> Optional[re.Match]:
|
|
@@ -190,7 +190,9 @@ def asset_table_upload_spec(model: DerivaModel, asset_table: str | Table):
|
|
|
190
190
|
metadata_columns = model.asset_metadata(asset_table)
|
|
191
191
|
asset_table = model.name_to_table(asset_table)
|
|
192
192
|
schema = model.name_to_table(asset_table).schema.name
|
|
193
|
-
|
|
193
|
+
|
|
194
|
+
# Be careful here, as a metadata value might be a string which can contain special characters.
|
|
195
|
+
metadata_path = "/".join([rf"(?P<{c}>[-:._ \w]+)" for c in metadata_columns])
|
|
194
196
|
asset_path = f"{exec_dir_regex}/asset/{schema}/{asset_table.name}/{metadata_path}/{asset_file_regex}"
|
|
195
197
|
asset_table = model.name_to_table(asset_table)
|
|
196
198
|
schema = model.name_to_table(asset_table).schema.name
|
|
@@ -417,7 +419,7 @@ def asset_file_path(
|
|
|
417
419
|
raise DerivaMLException(f"Metadata {metadata} does not match asset metadata {asset_metadata}")
|
|
418
420
|
|
|
419
421
|
for m in asset_metadata:
|
|
420
|
-
path = path / metadata.get(m, "None")
|
|
422
|
+
path = path / str(metadata.get(m, "None"))
|
|
421
423
|
path.mkdir(parents=True, exist_ok=True)
|
|
422
424
|
return path / file_name
|
|
423
425
|
|
deriva_ml/demo_catalog.py
CHANGED
|
@@ -5,6 +5,7 @@ import itertools
|
|
|
5
5
|
import logging
|
|
6
6
|
import string
|
|
7
7
|
from collections.abc import Iterator, Sequence
|
|
8
|
+
from datetime import datetime
|
|
8
9
|
from numbers import Integral
|
|
9
10
|
from pathlib import Path
|
|
10
11
|
from random import choice, randint, random
|
|
@@ -54,7 +55,13 @@ def populate_demo_catalog(ml_instance: DerivaML) -> None:
|
|
|
54
55
|
)
|
|
55
56
|
with execution.execute() as e:
|
|
56
57
|
for s in ss:
|
|
57
|
-
image_file = e.asset_file_path(
|
|
58
|
+
image_file = e.asset_file_path(
|
|
59
|
+
"Image",
|
|
60
|
+
f"test_{s['RID']}.txt",
|
|
61
|
+
Subject=s["RID"],
|
|
62
|
+
Acquisition_Time=datetime.now(),
|
|
63
|
+
Acquisition_Date=datetime.now().date(),
|
|
64
|
+
)
|
|
58
65
|
with image_file.open("w") as f:
|
|
59
66
|
f.write(f"Hello there {random()}\n")
|
|
60
67
|
execution.upload_execution_outputs()
|
|
@@ -343,7 +350,14 @@ def create_domain_schema(catalog: ErmrestCatalog, sname: str) -> None:
|
|
|
343
350
|
)
|
|
344
351
|
with TemporaryDirectory() as tmpdir:
|
|
345
352
|
ml_instance = DerivaML(hostname=catalog.deriva_server.server, catalog_id=catalog.catalog_id, working_dir=tmpdir)
|
|
346
|
-
ml_instance.create_asset(
|
|
353
|
+
ml_instance.create_asset(
|
|
354
|
+
"Image",
|
|
355
|
+
column_defs=[
|
|
356
|
+
Column.define("Acquisition_Time", builtin_types.timestamp),
|
|
357
|
+
Column.define("Acquisition_Date", builtin_types.date),
|
|
358
|
+
],
|
|
359
|
+
referenced_tables=[subject_table],
|
|
360
|
+
)
|
|
347
361
|
catalog_annotation(ml_instance.model)
|
|
348
362
|
|
|
349
363
|
|
deriva_ml/execution/__init__.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import TYPE_CHECKING
|
|
2
2
|
|
|
3
3
|
# Safe imports - no circular dependencies
|
|
4
|
-
from deriva_ml.execution.execution_configuration import ExecutionConfiguration
|
|
4
|
+
from deriva_ml.execution.execution_configuration import ExecutionConfiguration, AssetRIDConfig
|
|
5
5
|
from deriva_ml.execution.workflow import Workflow
|
|
6
6
|
|
|
7
7
|
if TYPE_CHECKING:
|
|
@@ -22,4 +22,5 @@ __all__ = [
|
|
|
22
22
|
"Execution", # Lazy-loaded
|
|
23
23
|
"ExecutionConfiguration",
|
|
24
24
|
"Workflow",
|
|
25
|
+
"AssetRIDConfig"
|
|
25
26
|
]
|
deriva_ml/execution/execution.py
CHANGED
|
@@ -583,7 +583,6 @@ class Execution:
|
|
|
583
583
|
asset_rid=status.result["RID"],
|
|
584
584
|
)
|
|
585
585
|
)
|
|
586
|
-
|
|
587
586
|
self._update_asset_execution_table(asset_map)
|
|
588
587
|
self.update_status(Status.running, "Updating features...")
|
|
589
588
|
|
|
@@ -805,7 +804,7 @@ class Execution:
|
|
|
805
804
|
self,
|
|
806
805
|
uploaded_assets: dict[str, list[AssetFilePath]],
|
|
807
806
|
asset_role: str = "Output",
|
|
808
|
-
):
|
|
807
|
+
) -> None:
|
|
809
808
|
"""Add entry to the association table connecting an asset to an execution RID
|
|
810
809
|
|
|
811
810
|
Args:
|
|
@@ -814,6 +813,9 @@ class Execution:
|
|
|
814
813
|
asset_role: A term or list of terms from the Asset_Role vocabulary.
|
|
815
814
|
"""
|
|
816
815
|
# Make sure the asset role is in the controlled vocabulary table.
|
|
816
|
+
if self._dry_run:
|
|
817
|
+
# Don't do any updates if we are doing a dry run.
|
|
818
|
+
return
|
|
817
819
|
self._ml_object.lookup_term(MLVocab.asset_role, asset_role)
|
|
818
820
|
|
|
819
821
|
pb = self._ml_object.pathBuilder
|
|
@@ -22,15 +22,17 @@ Typical usage example:
|
|
|
22
22
|
|
|
23
23
|
from __future__ import annotations
|
|
24
24
|
|
|
25
|
+
from dataclasses import dataclass
|
|
25
26
|
import json
|
|
26
27
|
import sys
|
|
27
28
|
from pathlib import Path
|
|
28
29
|
from typing import Any
|
|
29
30
|
|
|
31
|
+
from hydra_zen import builds
|
|
30
32
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
31
33
|
|
|
32
34
|
from deriva_ml.core.definitions import RID
|
|
33
|
-
from deriva_ml.dataset.aux_classes import
|
|
35
|
+
from deriva_ml.dataset.aux_classes import DatasetSpec
|
|
34
36
|
from deriva_ml.execution.workflow import Workflow
|
|
35
37
|
|
|
36
38
|
|
|
@@ -64,7 +66,7 @@ class ExecutionConfiguration(BaseModel):
|
|
|
64
66
|
... )
|
|
65
67
|
"""
|
|
66
68
|
|
|
67
|
-
datasets: list[DatasetSpec]
|
|
69
|
+
datasets: list[DatasetSpec] = []
|
|
68
70
|
assets: list[RID] = []
|
|
69
71
|
workflow: RID | Workflow
|
|
70
72
|
description: str = ""
|
|
@@ -72,13 +74,13 @@ class ExecutionConfiguration(BaseModel):
|
|
|
72
74
|
|
|
73
75
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
74
76
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
77
|
+
# @field_validator("datasets", mode="before")
|
|
78
|
+
# @classmethod
|
|
79
|
+
# def validate_datasets(cls, value: Any) -> Any:
|
|
80
|
+
# if isinstance(value, DatasetList):
|
|
81
|
+
# config_list: DatasetList = value
|
|
82
|
+
# value = config_list.datasets
|
|
83
|
+
# return value
|
|
82
84
|
|
|
83
85
|
@field_validator("workflow", mode="before")
|
|
84
86
|
@classmethod
|
|
@@ -137,3 +139,20 @@ class ExecutionConfiguration(BaseModel):
|
|
|
137
139
|
# hs = HatracStore("https", self.host_name, self.credential)
|
|
138
140
|
# hs.get_obj(path=configuration["URL"], destfilename=dest_file.name)
|
|
139
141
|
# return ExecutionConfiguration.load_configuration(Path(dest_file.name))
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@dataclass
|
|
145
|
+
class AssetRID(str):
|
|
146
|
+
rid: str
|
|
147
|
+
description: str = ""
|
|
148
|
+
|
|
149
|
+
def __new__(cls, rid: str, description: str = ""):
|
|
150
|
+
obj = super().__new__(cls, rid)
|
|
151
|
+
obj.description = description
|
|
152
|
+
return obj
|
|
153
|
+
|
|
154
|
+
AssetRIDConfig = builds(AssetRID, populate_full_signature=True)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
|