deriva-ml 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- deriva_ml/.DS_Store +0 -0
- deriva_ml/__init__.py +0 -10
- deriva_ml/core/base.py +18 -6
- deriva_ml/dataset/__init__.py +2 -7
- deriva_ml/dataset/aux_classes.py +21 -11
- deriva_ml/dataset/dataset.py +5 -4
- deriva_ml/dataset/dataset_bag.py +144 -151
- deriva_ml/dataset/upload.py +6 -4
- deriva_ml/demo_catalog.py +16 -2
- deriva_ml/execution/__init__.py +2 -1
- deriva_ml/execution/execution.py +4 -2
- deriva_ml/execution/execution_configuration.py +28 -9
- deriva_ml/execution/workflow.py +8 -0
- deriva_ml/model/catalog.py +55 -50
- deriva_ml/model/database.py +455 -81
- deriva_ml/test.py +94 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/METADATA +9 -7
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/RECORD +22 -21
- deriva_ml/model/sql_mapper.py +0 -44
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/WHEEL +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/entry_points.txt +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/top_level.txt +0 -0
deriva_ml/execution/workflow.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Any
 import requests
 from pydantic import BaseModel, PrivateAttr, model_validator
 from requests import RequestException
+from setuptools_scm import get_version
 
 from deriva_ml.core.definitions import RID
 from deriva_ml.core.exceptions import DerivaMLException
@@ -129,6 +130,13 @@ class Workflow(BaseModel):
         self.url, self.checksum = Workflow.get_url_and_checksum(path)
         self.git_root = Workflow._get_git_root(path)
 
+        self.version = get_version(
+            root=str(self.git_root or Path.cwd()),
+            search_parent_directories=True,
+            # Optional but recommended: provide a safe fallback when tags are absent
+            fallback_version="0.0",
+        )
+
         self._logger = logging.getLogger("deriva_ml")
         return self
 
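The change above derives Workflow.version at validation time from the git checkout's SCM metadata instead of relying on tagged package metadata. A minimal sketch of the same call pattern outside deriva-ml (the repository path and fallback string here are illustrative, not taken from the package):

# Minimal sketch (not deriva-ml code): resolve a version from git tags the way
# the validator above does. setuptools_scm must be installed, and the working
# directory (or an ancestor, via search_parent_directories) must be a git checkout.
from pathlib import Path
from setuptools_scm import get_version

repo_root = Path.cwd()  # stand-in for `self.git_root or Path.cwd()`
version = get_version(
    root=str(repo_root),
    search_parent_directories=True,  # walk upward until SCM metadata is found
    fallback_version="0.0",          # returned when no tags or SCM metadata exist
)
print(version)  # e.g. "1.17.0" on a tagged commit, or a .devN+g<hash> string otherwise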
deriva_ml/model/catalog.py
CHANGED
@@ -8,7 +8,7 @@ ML-specific functionality. It handles schema management, feature definitions, an
 from __future__ import annotations
 
 # Standard library imports
-from collections import Counter
+from collections import Counter, defaultdict
 from graphlib import CycleError, TopologicalSorter
 from typing import Any, Callable, Final, Iterable, NewType, TypeAlias
 
@@ -312,7 +312,10 @@ class DerivaModel:
 
         return [t for a in dataset_table.find_associations() if domain_table(t := a.other_fkeys.pop().pk_table)]
 
-    def _prepare_wide_table(self,
+    def _prepare_wide_table(self,
+                            dataset,
+                            dataset_rid: RID,
+                            include_tables: list[str]) -> tuple[dict[str, Any], list[tuple]]:
         """
         Generates details of a wide table from the model
 
@@ -327,7 +330,7 @@ class DerivaModel:
         # Skip over tables that we don't want to include in the denormalized dataset.
         # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
         # table.
-        include_tables = set(include_tables)
+        include_tables = set(include_tables)
         for t in include_tables:
             # Check to make sure the table is in the catalog.
             _ = self.name_to_table(t)
@@ -335,8 +338,11 @@
         table_paths = [
             path
             for path in self._schema_to_paths()
-            if
+            if path[-1].name in include_tables and include_tables.intersection({p.name for p in path})
         ]
+        paths_by_element = defaultdict(list)
+        for p in table_paths:
+            paths_by_element[p[2].name].append(p)
 
         # Get the names of all of the tables that can be dataset elements.
         dataset_element_tables = {
@@ -344,58 +350,57 @@
         }
 
         skip_columns = {"RCT", "RMT", "RCB", "RMB"}
-        [… 40 lines of the previous implementation, collapsed in the original diff view …]
+        element_tables = {}
+        for element_table, paths in paths_by_element.items():
+            graph = {}
+            for path in paths:
+                for left, right in zip(path[0:], path[1:]):
+                    graph.setdefault(left.name, set()).add(right.name)
+
+            # New lets remove any cycles that we may have in the graph.
+            # We will use a topological sort to find the order in which we need to join the tables.
+            # If we find a cycle, we will remove the table from the graph and splice in an additional ON clause.
+            # We will then repeat the process until there are no cycles.
+            graph_has_cycles = True
+            element_join_tables = []
+            element_join_conditions = {}
+            while graph_has_cycles:
+                try:
+                    ts = TopologicalSorter(graph)
+                    element_join_tables = list(reversed(list(ts.static_order())))
+                    graph_has_cycles = False
+                except CycleError as e:
+                    cycle_nodes = e.args[1]
+                    if len(cycle_nodes) > 3:
+                        raise DerivaMLException(f"Unexpected cycle found when normalizing dataset {cycle_nodes}")
+                    # Remove cycle from graph and splice in additional ON constraint.
+                    graph[cycle_nodes[1]].remove(cycle_nodes[0])
+
+            # The Dataset_Version table is a special case as it points to dataset and dataset to version.
+            if "Dataset_Version" in element_join_tables:
+                element_join_tables.remove("Dataset_Version")
+
+            for path in paths:
+                for left, right in zip(path[0:], path[1:]):
+                    if right.name == "Dataset_Version":
+                        # The Dataset_Version table is a special case as it points to dataset and dataset to version.
+                        continue
+                    if element_join_tables.index(right.name) < element_join_tables.index(left.name):
+                        continue
+                    table_relationship = self._table_relationship(left, right)
+                    element_join_conditions.setdefault(right.name, set()).add(
+                        (table_relationship[0], table_relationship[1])
+                    )
+            element_tables[element_table] = (element_join_tables, element_join_conditions)
         # Get the list of columns that will appear in the final denormalized dataset.
         denormalized_columns = [
             (table_name, c.name)
-            for table_name in
+            for table_name in include_tables
             if not self.is_association(table_name) # Don't include association columns in the denormalized view.'
             for c in self.name_to_table(table_name).columns
-            if c.name not in skip_columns
+            if (not include_tables or table_name in include_tables) and (c.name not in skip_columns)
         ]
-
-        # List of dataset ids to include in the denormalized view.
-        dataset_rids = dataset.list_dataset_children(recurse=True)
-        return join_tables, tables, denormalized_columns, dataset_rids, dataset_element_tables
+        return element_tables, denormalized_columns
 
     def _table_relationship(
         self,