deriva-ml 1.17.10__py3-none-any.whl → 1.17.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deriva_ml/__init__.py +43 -1
- deriva_ml/asset/__init__.py +17 -0
- deriva_ml/asset/asset.py +357 -0
- deriva_ml/asset/aux_classes.py +100 -0
- deriva_ml/bump_version.py +254 -11
- deriva_ml/catalog/__init__.py +21 -0
- deriva_ml/catalog/clone.py +1199 -0
- deriva_ml/catalog/localize.py +426 -0
- deriva_ml/core/__init__.py +29 -0
- deriva_ml/core/base.py +817 -1067
- deriva_ml/core/config.py +169 -21
- deriva_ml/core/constants.py +120 -19
- deriva_ml/core/definitions.py +123 -13
- deriva_ml/core/enums.py +47 -73
- deriva_ml/core/ermrest.py +226 -193
- deriva_ml/core/exceptions.py +297 -14
- deriva_ml/core/filespec.py +99 -28
- deriva_ml/core/logging_config.py +225 -0
- deriva_ml/core/mixins/__init__.py +42 -0
- deriva_ml/core/mixins/annotation.py +915 -0
- deriva_ml/core/mixins/asset.py +384 -0
- deriva_ml/core/mixins/dataset.py +237 -0
- deriva_ml/core/mixins/execution.py +408 -0
- deriva_ml/core/mixins/feature.py +365 -0
- deriva_ml/core/mixins/file.py +263 -0
- deriva_ml/core/mixins/path_builder.py +145 -0
- deriva_ml/core/mixins/rid_resolution.py +204 -0
- deriva_ml/core/mixins/vocabulary.py +400 -0
- deriva_ml/core/mixins/workflow.py +322 -0
- deriva_ml/core/validation.py +389 -0
- deriva_ml/dataset/__init__.py +2 -1
- deriva_ml/dataset/aux_classes.py +20 -4
- deriva_ml/dataset/catalog_graph.py +575 -0
- deriva_ml/dataset/dataset.py +1242 -1008
- deriva_ml/dataset/dataset_bag.py +1311 -182
- deriva_ml/dataset/history.py +27 -14
- deriva_ml/dataset/upload.py +225 -38
- deriva_ml/demo_catalog.py +126 -110
- deriva_ml/execution/__init__.py +46 -2
- deriva_ml/execution/base_config.py +639 -0
- deriva_ml/execution/execution.py +543 -242
- deriva_ml/execution/execution_configuration.py +26 -11
- deriva_ml/execution/execution_record.py +592 -0
- deriva_ml/execution/find_caller.py +298 -0
- deriva_ml/execution/model_protocol.py +175 -0
- deriva_ml/execution/multirun_config.py +153 -0
- deriva_ml/execution/runner.py +595 -0
- deriva_ml/execution/workflow.py +223 -34
- deriva_ml/experiment/__init__.py +8 -0
- deriva_ml/experiment/experiment.py +411 -0
- deriva_ml/feature.py +6 -1
- deriva_ml/install_kernel.py +143 -6
- deriva_ml/interfaces.py +862 -0
- deriva_ml/model/__init__.py +99 -0
- deriva_ml/model/annotations.py +1278 -0
- deriva_ml/model/catalog.py +286 -60
- deriva_ml/model/database.py +144 -649
- deriva_ml/model/deriva_ml_database.py +308 -0
- deriva_ml/model/handles.py +14 -0
- deriva_ml/run_model.py +319 -0
- deriva_ml/run_notebook.py +507 -38
- deriva_ml/schema/__init__.py +18 -2
- deriva_ml/schema/annotations.py +62 -33
- deriva_ml/schema/create_schema.py +169 -69
- deriva_ml/schema/validation.py +601 -0
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/METADATA +4 -4
- deriva_ml-1.17.11.dist-info/RECORD +77 -0
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/WHEEL +1 -1
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/entry_points.txt +1 -0
- deriva_ml/protocols/dataset.py +0 -19
- deriva_ml/test.py +0 -94
- deriva_ml-1.17.10.dist-info/RECORD +0 -45
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/top_level.txt +0 -0
deriva_ml/dataset/dataset_bag.py
CHANGED
@@ -1,13 +1,44 @@
-"""
-
+"""SQLite-backed dataset access for downloaded BDBags.
+
+This module provides the DatasetBag class, which allows querying and navigating
+downloaded dataset bags using SQLite. When a dataset is downloaded from a Deriva
+catalog, it is stored as a BDBag (Big Data Bag) containing:
+
+- CSV files with table data
+- Asset files (images, documents, etc.)
+- A schema.json describing the catalog structure
+- A fetch.txt manifest of referenced files
+
+The DatasetBag class provides a read-only interface to this data, mirroring
+the Dataset class API where possible. This allows code to work uniformly
+with both live catalog datasets and downloaded bags.
+
+Key concepts:
+- DatasetBag wraps a single dataset within a downloaded bag
+- A bag may contain multiple datasets (nested/hierarchical)
+- All operations are read-only (bags are immutable snapshots)
+- Queries use SQLite via SQLAlchemy ORM
+- Table-level access (get_table_as_dict, lookup_term) is on the catalog (DerivaMLDatabase)
+
+Typical usage:
+    >>> # Download a dataset from a catalog
+    >>> bag = ml.download_dataset_bag(dataset_spec)
+    >>> # List dataset members by type
+    >>> members = bag.list_dataset_members(recurse=True)
+    >>> for image in members.get("Image", []):
+    ...     print(image["Filename"])
 """
 
 from __future__ import annotations
 
 # Standard library imports
+import logging
+import shutil
 from collections import defaultdict
 from copy import copy
-from
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Self, cast
 
 import deriva.core.datapath as datapath
 
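The new module docstring describes the typical workflow for working with a downloaded bag. A minimal sketch of that workflow, assuming a bag obtained via ml.download_dataset_bag(dataset_spec) as in the docstring; the helper name summarize_bag is illustrative and not part of the package:

    from typing import Any

    def summarize_bag(bag: Any) -> dict[str, int]:
        """Print a quick summary of a downloaded DatasetBag and return member counts."""
        # Members are grouped by element type; recurse=True also walks nested datasets.
        members: dict[str, list[dict[str, Any]]] = bag.list_dataset_members(recurse=True)
        counts = {member_type: len(rows) for member_type, rows in members.items()}
        print(f"Dataset {bag.dataset_rid}")
        for member_type, n in sorted(counts.items()):
            print(f"  {member_type}: {n} records")
        # Nested datasets are themselves exposed as DatasetBag instances.
        for child in bag.list_dataset_children():
            print(f"  nested dataset: {child.dataset_rid} types={child.dataset_types}")
        return counts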
@@ -18,17 +49,17 @@ import pandas as pd
 from deriva.core.ermrest_model import Table
 
 # Deriva imports
-from
-from sqlalchemy import CompoundSelect, Engine, RowMapping, Select, and_, inspect, select, union
+from sqlalchemy import CompoundSelect, Engine, Select, and_, inspect, select, union
 from sqlalchemy.orm import RelationshipProperty, Session
 from sqlalchemy.orm.util import AliasedClass
 
-from deriva_ml.core.definitions import RID
-from deriva_ml.core.exceptions import DerivaMLException
-from deriva_ml.
+from deriva_ml.core.definitions import RID
+from deriva_ml.core.exceptions import DerivaMLException
+from deriva_ml.dataset.aux_classes import DatasetHistory, DatasetVersion
+from deriva_ml.feature import Feature, FeatureRecord
 
 if TYPE_CHECKING:
-    from deriva_ml.model.
+    from deriva_ml.model.deriva_ml_database import DerivaMLDatabase
 
 try:
     from icecream import ic
@@ -36,69 +67,235 @@ except ImportError: # Graceful fallback if IceCream isn't installed.
     ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a) # noqa
 
 
-
+@dataclass
+class FeatureValueRecord:
+    """A feature value record with execution provenance.
+
+    This class represents a single feature value assigned to an asset,
+    including the execution that created it. Used by restructure_assets
+    when a value_selector function needs to choose between multiple
+    feature values for the same asset.
+
+    The raw_record attribute contains the complete feature table row as
+    a dictionary, which can be used to access all columns including any
+    additional metadata or columns beyond the primary value.
+
+    Attributes:
+        target_rid: RID of the asset/entity this feature value applies to.
+        feature_name: Name of the feature.
+        value: The feature value (typically a vocabulary term name).
+        execution_rid: RID of the execution that created this feature value, if any.
+            Use this to distinguish between values from different executions.
+        raw_record: The complete raw record from the feature table as a dictionary.
+            Access all columns via dict keys, e.g., record.raw_record["MyColumn"].
+
+    Example:
+        Using a value_selector to choose the most recent feature value::
+
+            def select_by_execution(records: list[FeatureValueRecord]) -> FeatureValueRecord:
+                # Select value from most recent execution (assuming RIDs are sortable)
+                return max(records, key=lambda r: r.execution_rid or "")
+
+            bag.restructure_assets(
+                output_dir="./ml_data",
+                group_by=["Diagnosis"],
+                value_selector=select_by_execution,
+            )
+
+        Accessing raw record data::
+
+            def select_by_confidence(records: list[FeatureValueRecord]) -> FeatureValueRecord:
+                # Select value with highest confidence score from raw record
+                return max(records, key=lambda r: r.raw_record.get("Confidence", 0))
     """
-
-
-
+    target_rid: RID
+    feature_name: str
+    value: Any
+    execution_rid: RID | None = None
+    raw_record: dict[str, Any] = field(default_factory=dict)
+
+    def __repr__(self) -> str:
+        return (f"FeatureValueRecord(target_rid='{self.target_rid}', "
+                f"feature_name='{self.feature_name}', value='{self.value}', "
+                f"execution_rid='{self.execution_rid}')")
+
+
+class DatasetBag:
+    """Read-only interface to a downloaded dataset bag.
+
+    DatasetBag manages access to a materialized BDBag (Big Data Bag) that contains
+    a snapshot of dataset data from a Deriva catalog. It provides methods for:
 
-
-
+    - Listing dataset members and their attributes
+    - Navigating dataset relationships (parents, children)
+    - Accessing feature values
+    - Denormalizing data across related tables
 
-
+    A bag may contain multiple datasets when nested datasets are involved. Each
+    DatasetBag instance represents a single dataset within the bag - use
+    list_dataset_children() to navigate to nested datasets.
+
+    For catalog-level operations like querying arbitrary tables or looking up
+    vocabulary terms, use the DerivaMLDatabase class instead.
+
+    The class implements the DatasetLike protocol, providing the same read interface
+    as the Dataset class. This allows code to work with both live catalogs and
+    downloaded bags interchangeably.
 
     Attributes:
-        dataset_rid (RID):
-
-
-
-
-
+        dataset_rid (RID): The unique Resource Identifier for this dataset.
+        dataset_types (list[str]): List of vocabulary terms describing the dataset type.
+        description (str): Human-readable description of the dataset.
+        execution_rid (RID | None): RID of the execution associated with this dataset version, if any.
+        model (DatabaseModel): The DatabaseModel providing SQLite access to bag data.
+        engine (Engine): SQLAlchemy engine for database queries.
+        metadata (MetaData): SQLAlchemy metadata with table definitions.
+
+    Example:
+        >>> # Download a dataset
+        >>> bag = dataset.download_dataset_bag(version="1.0.0")
+        >>> # List members by type
+        >>> members = bag.list_dataset_members()
+        >>> for image in members.get("Image", []):
+        ...     print(f"File: {image['Filename']}")
+        >>> # Navigate to nested datasets
+        >>> for child in bag.list_dataset_children():
+        ...     print(f"Nested: {child.dataset_rid}")
     """
 
-    def __init__(
-
-
+    def __init__(
+        self,
+        catalog: "DerivaMLDatabase",
+        dataset_rid: RID | None = None,
+        dataset_types: str | list[str] | None = None,
+        description: str = "",
+        execution_rid: RID | None = None,
+    ):
+        """Initialize a DatasetBag instance for a dataset within a downloaded bag.
+
+        This mirrors the Dataset class initialization pattern, where both classes
+        take a catalog-like object as their first argument for consistency.
 
         Args:
-
-
+            catalog: The DerivaMLDatabase instance providing access to the bag's data.
+                This implements the DerivaMLCatalog protocol.
+            dataset_rid: The RID of the dataset to wrap. If None, uses the primary
+                dataset RID from the bag.
+            dataset_types: One or more dataset type terms. Can be a single string
+                or list of strings.
+            description: Human-readable description of the dataset.
+            execution_rid: RID of the execution associated with this dataset version.
+                If None, will be looked up from the Dataset_Version table.
+
+        Raises:
+            DerivaMLException: If no dataset_rid is provided and none can be
+                determined from the bag, or if the RID doesn't exist in the bag.
         """
-
+        # Store reference to the catalog and extract the underlying model
+        self._catalog = catalog
+        self.model = catalog.model
         self.engine = cast(Engine, self.model.engine)
         self.metadata = self.model.metadata
 
+        # Use provided RID or fall back to the bag's primary dataset
         self.dataset_rid = dataset_rid or self.model.dataset_rid
+        self.description = description
+        self.execution_rid = execution_rid or (
+            self.model._get_dataset_execution(self.dataset_rid) or {}
+        ).get("Execution")
+
+        # Normalize dataset_types to always be a list of strings for consistency
+        # with the Dataset class interface
+        if dataset_types is None:
+            self.dataset_types: list[str] = []
+        elif isinstance(dataset_types, str):
+            self.dataset_types: list[str] = [dataset_types]
+        else:
+            self.dataset_types: list[str] = list(dataset_types)
+
         if not self.dataset_rid:
             raise DerivaMLException("No dataset RID provided")
 
-
+        # Validate that this dataset exists in the bag
+        self.model.rid_lookup(self.dataset_rid)
 
-
+        # Cache the version and dataset table reference
+        self._current_version = self.model.dataset_version(self.dataset_rid)
         self._dataset_table = self.model.dataset_table
 
     def __repr__(self) -> str:
-
+        """Return a string representation of the DatasetBag for debugging."""
+        return (f"<deriva_ml.DatasetBag object at {hex(id(self))}: rid='{self.dataset_rid}', "
+                f"version='{self.current_version}', types={self.dataset_types}>")
+
+    @property
+    def current_version(self) -> DatasetVersion:
+        """Get the version of the dataset at the time the bag was downloaded.
+
+        For a DatasetBag, this is the version that was current when the bag was
+        created. Unlike the live Dataset class, this value is immutable since
+        bags are read-only snapshots.
+
+        Returns:
+            DatasetVersion: The semantic version (major.minor.patch) of this dataset.
+        """
+        return self._current_version
 
     def list_tables(self) -> list[str]:
-        """List
+        """List all tables available in the bag's SQLite database.
+
+        Returns the fully-qualified names of all tables (e.g., "domain.Image",
+        "deriva-ml.Dataset") that were exported in this bag.
 
         Returns:
-
+            list[str]: Table names in "schema.table" format, sorted alphabetically.
         """
         return self.model.list_tables()
 
+    def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
+        """Get table contents as dictionaries.
+
+        Convenience method that delegates to the underlying catalog. This provides
+        access to all rows in a table, not just those belonging to this dataset.
+        For dataset-filtered results, use list_dataset_members() instead.
+
+        Args:
+            table: Name of the table to retrieve (e.g., "Subject", "Image").
+
+        Yields:
+            dict: Dictionary for each row in the table.
+
+        Example:
+            >>> for subject in bag.get_table_as_dict("Subject"):
+            ...     print(subject["Name"])
+        """
+        return self._catalog.get_table_as_dict(table)
+
     @staticmethod
     def _find_relationship_attr(source, target):
-        """
-
-        that points to `target
-
+        """Find the SQLAlchemy relationship attribute connecting two ORM classes.
+
+        Searches for a relationship on `source` that points to `target`, which is
+        needed to construct proper JOIN clauses in SQL queries.
+
+        Args:
+            source: Source ORM class or AliasedClass.
+            target: Target ORM class or AliasedClass.
+
+        Returns:
+            InstrumentedAttribute: The relationship attribute on source pointing to target.
+
+        Raises:
+            LookupError: If no relationship exists between the two classes.
+
+        Note:
+            When multiple relationships exist, prefers MANYTOONE direction as this
+            is typically the more natural join direction for denormalization.
         """
         src_mapper = inspect(source).mapper
         tgt_mapper = inspect(target).mapper
 
-        #
+        # Collect all relationships on the source mapper that point to target
         candidates: list[RelationshipProperty] = [rel for rel in src_mapper.relationships if rel.mapper is tgt_mapper]
 
         if not candidates:
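The FeatureValueRecord dataclass added above carries execution provenance so a value_selector can disambiguate conflicting feature values. A small sketch of such a selector, assuming the import path shown in this diff; the "Confidence" column and the helper name prefer_execution are illustrative only:

    from deriva_ml.dataset.dataset_bag import FeatureValueRecord

    def prefer_execution(preferred_execution_rid: str):
        """Build a value_selector that prefers values produced by one execution."""

        def selector(records: list[FeatureValueRecord]) -> FeatureValueRecord:
            # Use the provenance carried on each record to pick a single value.
            for record in records:
                if record.execution_rid == preferred_execution_rid:
                    return record
            # Fall back to a hypothetical "Confidence" column in raw_record, then to the first record.
            return max(records, key=lambda r: r.raw_record.get("Confidence", 0) or 0)

        return selector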
@@ -108,86 +305,117 @@ class DatasetBag:
         candidates.sort(key=lambda r: r.direction.name != "MANYTOONE")
         rel = candidates[0]
 
-        #
+        # Return the bound attribute (handles AliasedClass properly)
         return getattr(source, rel.key) if isinstance(source, AliasedClass) else rel.class_attribute
 
     def _dataset_table_view(self, table: str) -> CompoundSelect[Any]:
-        """
-
+        """Build a SQL query for all rows in a table that belong to this dataset.
+
+        Creates a UNION of queries that traverse all possible paths from the
+        Dataset table to the target table, filtering by this dataset's RID
+        (and any nested dataset RIDs).
+
+        This is necessary because table data may be linked to datasets through
+        different relationship paths (e.g., Image might be linked directly to
+        Dataset or through an intermediate Subject table).
+
+        Args:
+            table: Name of the table to query.
+
+        Returns:
+            CompoundSelect: A SQLAlchemy UNION query selecting all matching rows.
+        """
         table_class = self.model.get_orm_class_by_name(table)
         dataset_table_class = self.model.get_orm_class_by_name(self._dataset_table.name)
+
+        # Include this dataset and all nested datasets in the query
         dataset_rids = [self.dataset_rid] + [c.dataset_rid for c in self.list_dataset_children(recurse=True)]
 
+        # Find all paths from Dataset to the target table
         paths = [[t.name for t in p] for p in self.model._schema_to_paths() if p[-1].name == table]
+
+        # Build a SELECT query for each path and UNION them together
         sql_cmds = []
         for path in paths:
            path_sql = select(table_class)
            last_class = self.model.get_orm_class_by_name(path[0])
+            # Join through each table in the path
            for t in path[1:]:
                t_class = self.model.get_orm_class_by_name(t)
                path_sql = path_sql.join(self._find_relationship_attr(last_class, t_class))
                last_class = t_class
+            # Filter to only rows belonging to our dataset(s)
            path_sql = path_sql.where(dataset_table_class.RID.in_(dataset_rids))
            sql_cmds.append(path_sql)
        return union(*sql_cmds)
 
-    def
-        """
-        the method will attempt to locate the schema for the table.
+    def dataset_history(self) -> list[DatasetHistory]:
+        """Retrieves the version history of a dataset.
 
-
-
+        Returns a chronological list of dataset versions, including their version numbers,
+        creation times, and associated metadata.
 
         Returns:
-
+            list[DatasetHistory]: List of history entries, each containing:
+                - dataset_version: Version number (major.minor.patch)
+                - minid: Minimal Viable Identifier
+                - snapshot: Catalog snapshot time
+                - dataset_rid: Dataset Resource Identifier
+                - version_rid: Version Resource Identifier
+                - description: Version description
+                - execution_rid: Associated execution RID
 
-
-
-        result = session.execute(self._dataset_table_view(table))
-        for row in result:
-            yield row
-
-    def get_table_as_dataframe(self, table: str) -> pd.DataFrame:
-        """Retrieve the contents of the specified table as a dataframe.
-
-
-        If schema is not provided as part of the table name,
-        the method will attempt to locate the schema for the table.
-
-        Args:
-            table: Table to retrieve data from.
-
-        Returns:
-            A dataframe containing the contents of the specified table.
-        """
-        return pd.read_sql(self._dataset_table_view(table), self.engine)
-
-    def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
-        """Retrieve the contents of the specified table as a dictionary.
-
-        Args:
-            table: Table to retrieve data from. f schema is not provided as part of the table name,
-            the method will attempt to locate the schema for the table.
+        Raises:
+            DerivaMLException: If dataset_rid is not a valid dataset RID.
 
-
-
+        Example:
+            >>> history = ml.dataset_history("1-abc123")
+            >>> for entry in history:
+            ...     print(f"Version {entry.dataset_version}: {entry.description}")
         """
+        # Query Dataset_Version table directly via the model
+        return [
+            DatasetHistory(
+                dataset_version=DatasetVersion.parse(v["Version"]),
+                minid=v["Minid"],
+                snapshot=v["Snapshot"],
+                dataset_rid=self.dataset_rid,
+                version_rid=v["RID"],
+                description=v["Description"],
+                execution_rid=v["Execution"],
+            )
+            for v in self.model._get_table_contents("Dataset_Version")
+            if v["Dataset"] == self.dataset_rid
+        ]
 
-
-
-
-
-
-
-
+    def list_dataset_members(
+        self,
+        recurse: bool = False,
+        limit: int | None = None,
+        _visited: set[RID] | None = None,
+        version: Any = None,
+        **kwargs: Any,
+    ) -> dict[str, list[dict[str, Any]]]:
         """Return a list of entities associated with a specific dataset.
 
         Args:
-
+            recurse: Whether to include members of nested datasets.
+            limit: Maximum number of members to return per type. None for no limit.
+            _visited: Internal parameter to track visited datasets and prevent infinite recursion.
+            version: Ignored (bags are immutable snapshots).
+            **kwargs: Additional arguments (ignored, for protocol compatibility).
 
         Returns:
-            Dictionary
+            Dictionary mapping member types to lists of member records.
         """
+        # Initialize visited set for recursion guard
+        if _visited is None:
+            _visited = set()
+
+        # Prevent infinite recursion by checking if we've already visited this dataset
+        if self.dataset_rid in _visited:
+            return {}
+        _visited.add(self.dataset_rid)
 
         # Look at each of the element types that might be in the _dataset_table and get the list of rid for them from
         # the appropriate association table.
@@ -200,16 +428,29 @@ class DatasetBag:
             assoc_class, dataset_rel, element_rel = self.model.get_orm_association_class(dataset_class, element_class)
 
             element_table = inspect(element_class).mapped_table
-            if
+            if not self.model.is_domain_schema(element_table.schema) and element_table.name not in ["Dataset", "File"]:
                 # Look at domain tables and nested datasets.
                 continue
+
             # Get the names of the columns that we are going to need for linking
             with Session(self.engine) as session:
-
-
-
-
-
+                # For Dataset_Dataset, use Nested_Dataset column to find nested datasets
+                # (similar to how the live catalog does it in Dataset.list_dataset_members)
+                if element_table.name == "Dataset":
+                    sql_cmd = (
+                        select(element_class)
+                        .join(assoc_class, element_class.RID == assoc_class.__table__.c["Nested_Dataset"])
+                        .where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
+                    )
+                else:
+                    # For other tables, use the original join via element_rel
+                    sql_cmd = (
+                        select(element_class)
+                        .join(element_rel)
+                        .where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
+                    )
+                if limit is not None:
+                    sql_cmd = sql_cmd.limit(limit)
                 # Get back the list of ORM entities and convert them to dictionaries.
                 element_entities = session.scalars(sql_cmd).all()
                 element_rows = [{c.key: getattr(obj, c.key) for c in obj.__table__.columns} for obj in element_entities]
@@ -218,8 +459,8 @@ class DatasetBag:
                 # Get the members for all the nested datasets and add to the member list.
                 nested_datasets = [d["RID"] for d in element_rows]
                 for ds in nested_datasets:
-                    nested_dataset = self.
-                    for k, v in nested_dataset.list_dataset_members(recurse=recurse).items():
+                    nested_dataset = self._catalog.lookup_dataset(ds)
+                    for k, v in nested_dataset.list_dataset_members(recurse=recurse, limit=limit, _visited=_visited).items():
                         members[k].extend(v)
         return dict(members)
 
@@ -234,25 +475,63 @@ class DatasetBag:
         """
         return self.model.find_features(table)
 
-    def list_feature_values(
-
+    def list_feature_values(
+        self, table: Table | str, feature_name: str
+    ) -> Iterable[FeatureRecord]:
+        """Retrieves all values for a feature as typed FeatureRecord instances.
+
+        Returns an iterator of dynamically-generated FeatureRecord objects for each
+        feature value. Each record is an instance of a Pydantic model specific to
+        this feature, with typed attributes for all columns including the Execution
+        that created the feature value.
 
         Args:
-            table: The table
-            feature_name: Name of the feature.
+            table: The table containing the feature, either as name or Table object.
+            feature_name: Name of the feature to retrieve values for.
 
         Returns:
-
+            Iterable[FeatureRecord]: An iterator of FeatureRecord instances.
+            Each instance has:
+                - Execution: RID of the execution that created this feature value
+                - Feature_Name: Name of the feature
+                - All feature-specific columns as typed attributes
+                - model_dump() method to convert back to a dictionary
+
+        Raises:
+            DerivaMLException: If the feature doesn't exist or cannot be accessed.
+
+        Example:
+            >>> # Get typed feature records
+            >>> for record in bag.list_feature_values("Image", "Quality"):
+            ...     print(f"Image {record.Image}: {record.ImageQuality}")
+            ...     print(f"Created by execution: {record.Execution}")
+
+            >>> # Convert records to dictionaries
+            >>> records = list(bag.list_feature_values("Image", "Quality"))
+            >>> dicts = [r.model_dump() for r in records]
         """
+        # Get table and feature
         feature = self.model.lookup_feature(table, feature_name)
-
+
+        # Get the dynamically-generated FeatureRecord subclass for this feature
+        record_class = feature.feature_record_class()
+
+        # Query raw values from SQLite
+        feature_table = self.model.find_table(feature.feature_table.name)
         with Session(self.engine) as session:
-            sql_cmd = select(
-
+            sql_cmd = select(feature_table)
+            result = session.execute(sql_cmd)
+            rows = [dict(row._mapping) for row in result]
 
-
-
-
+        # Convert to typed records
+        for raw_value in rows:
+            # Filter to only include fields that the record class expects
+            field_names = set(record_class.model_fields.keys())
+            filtered_data = {k: v for k, v in raw_value.items() if k in field_names}
+            yield record_class(**filtered_data)
+
+    def list_dataset_element_types(self) -> Iterable[Table]:
+        """List the types of elements that can be contained in datasets.
 
         This method analyzes the dataset and identifies the data types for all
         elements within it. It is useful for understanding the structure and
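The hunk above documents that list_feature_values() yields Pydantic-backed FeatureRecord instances with a model_dump() method. A minimal sketch of collecting those records for analysis; the helper name feature_values_frame and the use of pandas are illustrative, not part of the package:

    from typing import Any

    import pandas as pd

    def feature_values_frame(bag: Any, table: str, feature_name: str) -> pd.DataFrame:
        """Collect every value of one feature from a downloaded bag into a DataFrame."""
        # Each yielded record is a typed model; model_dump() returns a plain dict.
        rows = [record.model_dump() for record in bag.list_feature_values(table, feature_name)]
        return pd.DataFrame(rows)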
@@ -266,15 +545,33 @@ class DatasetBag:
         """
         return self.model.list_dataset_element_types()
 
-    def list_dataset_children(
+    def list_dataset_children(
+        self,
+        recurse: bool = False,
+        _visited: set[RID] | None = None,
+        version: Any = None,
+        **kwargs: Any,
+    ) -> list[Self]:
         """Get nested datasets.
 
         Args:
             recurse: Whether to include children of children.
+            _visited: Internal parameter to track visited datasets and prevent infinite recursion.
+            version: Ignored (bags are immutable snapshots).
+            **kwargs: Additional arguments (ignored, for protocol compatibility).
 
         Returns:
             List of child dataset bags.
         """
+        # Initialize visited set for recursion guard
+        if _visited is None:
+            _visited = set()
+
+        # Prevent infinite recursion by checking if we've already visited this dataset
+        if self.dataset_rid in _visited:
+            return []
+        _visited.add(self.dataset_rid)
+
         ds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset")
         nds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Dataset")
         dv_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Version")
@@ -286,63 +583,102 @@ class DatasetBag:
                 .join_from(ds_table, dv_table, onclause=ds_table.Version == dv_table.RID)
                 .where(nds_table.Dataset == self.dataset_rid)
             )
-            nested = [
+            nested = [self._catalog.lookup_dataset(r[0]) for r in session.execute(sql_cmd).all()]
 
         result = copy(nested)
         if recurse:
             for child in nested:
-                result.extend(child.list_dataset_children(recurse))
+                result.extend(child.list_dataset_children(recurse=recurse, _visited=_visited))
         return result
 
-
-
-
-
-
-
+    def list_dataset_parents(
+        self,
+        recurse: bool = False,
+        _visited: set[RID] | None = None,
+        version: Any = None,
+        **kwargs: Any,
+    ) -> list[Self]:
+        """Given a dataset_table RID, return a list of RIDs of the parent datasets if this is included in a
+        nested dataset.
 
         Args:
-
-
+            recurse: If True, recursively return all ancestor datasets.
+            _visited: Internal parameter to track visited datasets and prevent infinite recursion.
+            version: Ignored (bags are immutable snapshots).
+            **kwargs: Additional arguments (ignored, for protocol compatibility).
 
         Returns:
-
+            List of parent dataset bags.
+        """
+        # Initialize visited set for recursion guard
+        if _visited is None:
+            _visited = set()
 
-
-
+        # Prevent infinite recursion by checking if we've already visited this dataset
+        if self.dataset_rid in _visited:
+            return []
+        _visited.add(self.dataset_rid)
 
-
-        Look up by primary name:
-            >>> term = ml.lookup_term("tissue_types", "epithelial")
-            >>> print(term.description)
+        nds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Dataset")
 
-
-
-
-        # Get and validate vocabulary table reference
-        if not self.model.is_vocabulary(table):
-            raise DerivaMLException(f"The table {table} is not a controlled vocabulary")
+        with Session(self.engine) as session:
+            sql_cmd = select(nds_table.Dataset).where(nds_table.Nested_Dataset == self.dataset_rid)
+            parents = [self._catalog.lookup_dataset(r[0]) for r in session.execute(sql_cmd).all()]
 
-
-
-
-
-        return VocabularyTerm.model_validate(term)
+        if recurse:
+            for parent in parents.copy():
+                parents.extend(parent.list_dataset_parents(recurse=True, _visited=_visited))
+        return parents
 
-
-
+    def list_executions(self) -> list[RID]:
+        """List all execution RIDs associated with this dataset.
 
-
+        Returns all executions that used this dataset as input. This is
+        tracked through the Dataset_Execution association table.
+
+        Note:
+            Unlike the live Dataset class which returns Execution objects,
+            DatasetBag returns a list of execution RIDs since the bag is
+            an offline snapshot and cannot look up live execution objects.
+
+        Returns:
+            List of execution RIDs associated with this dataset.
+
+        Example:
+            >>> bag = ml.download_dataset_bag(dataset_spec)
+            >>> execution_rids = bag.list_executions()
+            >>> for rid in execution_rids:
+            ...     print(f"Associated execution: {rid}")
         """
-
-
+        de_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Execution")
+
+        with Session(self.engine) as session:
+            sql_cmd = select(de_table.Execution).where(de_table.Dataset == self.dataset_rid)
+            return [r[0] for r in session.execute(sql_cmd).all()]
+
+    def _denormalize(self, include_tables: list[str]) -> Select:
+        """Build a SQL query that joins multiple tables into a denormalized view.
+
+        This method creates a "wide table" by joining related tables together,
+        producing a single query that returns columns from all specified tables.
+        This is useful for machine learning pipelines that need flat data.
+
+        The method:
+        1. Analyzes the schema to find join paths between tables
+        2. Determines the correct join order based on foreign key relationships
+        3. Builds SELECT statements with properly aliased columns
+        4. Creates a UNION if multiple paths exist to the same tables
 
         Args:
-            include_tables
-
+            include_tables: List of table names to include in the output. Additional
+                tables may be included if they're needed to join the requested tables.
 
         Returns:
-
+            Select: A SQLAlchemy query that produces the denormalized result.
+
+        Note:
+            Column names in the result are prefixed with the table name to avoid
+            collisions (e.g., "Image.Filename", "Subject.RID").
         """
         # Skip over tables that we don't want to include in the denormalized dataset.
         # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
@@ -359,9 +695,7 @@ class DatasetBag:
                     return relationship
             return None
 
-        join_tables, denormalized_columns = (
-            self.model._prepare_wide_table(self, self.dataset_rid, include_tables)
-        )
+        join_tables, denormalized_columns = self.model._prepare_wide_table(self, self.dataset_rid, include_tables)
 
         denormalized_columns = [
             self.model.get_orm_class_by_name(table_name)
@@ -382,69 +716,864 @@ class DatasetBag:
|
|
|
382
716
|
if (r := find_relationship(table_class, on_condition))
|
|
383
717
|
]
|
|
384
718
|
sql_statement = sql_statement.join(table_class, onclause=and_(*on_clause))
|
|
385
|
-
dataset_rid_list = [self.dataset_rid] + self.list_dataset_children(recurse=True)
|
|
719
|
+
dataset_rid_list = [self.dataset_rid] + [c.dataset_rid for c in self.list_dataset_children(recurse=True)]
|
|
386
720
|
dataset_class = self.model.get_orm_class_by_name(self._dataset_table.name)
|
|
387
721
|
sql_statement = sql_statement.where(dataset_class.RID.in_(dataset_rid_list))
|
|
388
722
|
sql_statements.append(sql_statement)
|
|
389
723
|
return union(*sql_statements)
|
|
390
724
|
|
|
391
|
-
def
|
|
725
|
+
def _denormalize_from_members(
|
|
726
|
+
self,
|
|
727
|
+
include_tables: list[str],
|
|
728
|
+
) -> Generator[dict[str, Any], None, None]:
|
|
729
|
+
"""Denormalize dataset members by joining related tables.
|
|
730
|
+
|
|
731
|
+
This method creates a "wide table" view by joining related tables together,
|
|
732
|
+
using list_dataset_members() as the data source. This ensures consistency
|
|
733
|
+
with the catalog-based denormalize implementation. The result has outer join
|
|
734
|
+
semantics - tables without FK relationships are included with NULL values.
|
|
735
|
+
|
|
736
|
+
The method:
|
|
737
|
+
1. Gets the list of dataset members for each included table via list_dataset_members
|
|
738
|
+
2. For each member in the first table, follows foreign key relationships to
|
|
739
|
+
get related records from other tables
|
|
740
|
+
3. Tables without FK connections to the first table are included with NULLs
|
|
741
|
+
4. Includes nested dataset members recursively
|
|
742
|
+
|
|
743
|
+
Args:
|
|
744
|
+
include_tables: List of table names to include in the output.
|
|
745
|
+
|
|
746
|
+
Yields:
|
|
747
|
+
dict[str, Any]: Rows with column names prefixed by table name (e.g., "Image.Filename").
|
|
748
|
+
Unrelated tables have NULL values for their columns.
|
|
749
|
+
|
|
750
|
+
Note:
|
|
751
|
+
Column names in the result are prefixed with the table name to avoid
|
|
752
|
+
collisions (e.g., "Image.Filename", "Subject.RID").
|
|
392
753
|
"""
|
|
393
|
-
|
|
754
|
+
# Skip system columns in output
|
|
755
|
+
skip_columns = {"RCT", "RMT", "RCB", "RMB"}
|
|
756
|
+
|
|
757
|
+
# Get all members for the included tables (recursively includes nested datasets)
|
|
758
|
+
members = self.list_dataset_members(recurse=True)
|
|
759
|
+
|
|
760
|
+
# Build a lookup of columns for each table
|
|
761
|
+
table_columns: dict[str, list[str]] = {}
|
|
762
|
+
for table_name in include_tables:
|
|
763
|
+
table = self.model.name_to_table(table_name)
|
|
764
|
+
table_columns[table_name] = [
|
|
765
|
+
c.name for c in table.columns if c.name not in skip_columns
|
|
766
|
+
]
|
|
767
|
+
|
|
768
|
+
# Find the primary table (first non-empty table in include_tables)
|
|
769
|
+
primary_table = None
|
|
770
|
+
for table_name in include_tables:
|
|
771
|
+
if table_name in members and members[table_name]:
|
|
772
|
+
primary_table = table_name
|
|
773
|
+
break
|
|
774
|
+
|
|
775
|
+
if primary_table is None:
|
|
776
|
+
# No data at all
|
|
777
|
+
return
|
|
778
|
+
|
|
779
|
+
primary_table_obj = self.model.name_to_table(primary_table)
|
|
780
|
+
|
|
781
|
+
for member in members[primary_table]:
|
|
782
|
+
# Build the row with all columns from all tables
|
|
783
|
+
row: dict[str, Any] = {}
|
|
784
|
+
|
|
785
|
+
# Add primary table columns
|
|
786
|
+
for col_name in table_columns[primary_table]:
|
|
787
|
+
prefixed_name = f"{primary_table}.{col_name}"
|
|
788
|
+
row[prefixed_name] = member.get(col_name)
|
|
789
|
+
|
|
790
|
+
# For each other table, try to join or add NULL values
|
|
791
|
+
for other_table_name in include_tables:
|
|
792
|
+
if other_table_name == primary_table:
|
|
793
|
+
continue
|
|
794
|
+
|
|
795
|
+
other_table = self.model.name_to_table(other_table_name)
|
|
796
|
+
other_cols = table_columns[other_table_name]
|
|
797
|
+
|
|
798
|
+
# Initialize all columns to None (outer join behavior)
|
|
799
|
+
for col_name in other_cols:
|
|
800
|
+
prefixed_name = f"{other_table_name}.{col_name}"
|
|
801
|
+
row[prefixed_name] = None
|
|
802
|
+
|
|
803
|
+
# Try to find FK relationship and join
|
|
804
|
+
if other_table_name in members:
|
|
805
|
+
try:
|
|
806
|
+
relationship = self.model._table_relationship(
|
|
807
|
+
primary_table_obj, other_table
|
|
808
|
+
)
|
|
809
|
+
fk_col, pk_col = relationship
|
|
810
|
+
|
|
811
|
+
# Look up the related record
|
|
812
|
+
fk_value = member.get(fk_col.name)
|
|
813
|
+
if fk_value:
|
|
814
|
+
for other_member in members.get(other_table_name, []):
|
|
815
|
+
if other_member.get(pk_col.name) == fk_value:
|
|
816
|
+
for col_name in other_cols:
|
|
817
|
+
prefixed_name = f"{other_table_name}.{col_name}"
|
|
818
|
+
row[prefixed_name] = other_member.get(col_name)
|
|
819
|
+
break
|
|
820
|
+
except DerivaMLException:
|
|
821
|
+
# No FK relationship - columns remain NULL (outer join)
|
|
822
|
+
pass
|
|
823
|
+
|
|
824
|
+
yield row
|
|
825
|
+
|
|
826
|
+
def denormalize_as_dataframe(
|
|
827
|
+
self,
|
|
828
|
+
include_tables: list[str],
|
|
829
|
+
version: Any = None,
|
|
830
|
+
**kwargs: Any,
|
|
831
|
+
) -> pd.DataFrame:
|
|
832
|
+
"""Denormalize the dataset bag into a single wide table (DataFrame).
|
|
833
|
+
|
|
834
|
+
Denormalization transforms normalized relational data into a single "wide table"
|
|
835
|
+
(also called a "flat table" or "denormalized table") by joining related tables
|
|
836
|
+
together. This produces a DataFrame where each row contains all related information
|
|
837
|
+
from multiple source tables, with columns from each table combined side-by-side.
|
|
838
|
+
|
|
839
|
+
Wide tables are the standard input format for most machine learning frameworks,
|
|
840
|
+
which expect all features for a single observation to be in one row. This method
|
|
841
|
+
bridges the gap between normalized database schemas and ML-ready tabular data.
|
|
842
|
+
|
|
843
|
+
**How it works:**
|
|
844
|
+
|
|
845
|
+
Tables are joined based on their foreign key relationships stored in the bag's
|
|
846
|
+
schema. For example, if Image has a foreign key to Subject, denormalizing
|
|
847
|
+
["Subject", "Image"] produces rows where each image appears with its subject's
|
|
848
|
+
metadata.
|
|
849
|
+
|
|
850
|
+
**Column naming:**
|
|
851
|
+
|
|
852
|
+
Column names are prefixed with the source table name using dots to avoid
|
|
853
|
+
collisions (e.g., "Image.Filename", "Subject.RID"). This differs from the
|
|
854
|
+
live Dataset class which uses underscores.
|
|
855
|
+
|
|
856
|
+
Args:
|
|
857
|
+
include_tables: List of table names to include in the output. Tables
|
|
858
|
+
are joined based on their foreign key relationships.
|
|
859
|
+
Order doesn't matter - the join order is determined automatically.
|
|
860
|
+
version: Ignored (bags are immutable snapshots of a specific version).
|
|
861
|
+
**kwargs: Additional arguments (ignored, for protocol compatibility).
|
|
862
|
+
|
|
863
|
+
Returns:
|
|
864
|
+
pd.DataFrame: Wide table with columns from all included tables.
|
|
865
|
+
|
|
866
|
+
Example:
|
|
867
|
+
Create a training dataset from a downloaded bag::
|
|
868
|
+
|
|
869
|
+
>>> # Download and materialize the dataset
|
|
870
|
+
>>> bag = ml.download_dataset_bag(spec, materialize=True)
|
|
871
|
+
|
|
872
|
+
>>> # Denormalize into a wide table
|
|
873
|
+
>>> df = bag.denormalize_as_dataframe(["Image", "Diagnosis"])
|
|
874
|
+
>>> print(df.columns.tolist())
|
|
875
|
+
['Image.RID', 'Image.Filename', 'Image.URL', 'Diagnosis.RID',
|
|
876
|
+
'Diagnosis.Label', 'Diagnosis.Confidence']
|
|
877
|
+
|
|
878
|
+
>>> # Access local file paths for images
|
|
879
|
+
>>> for _, row in df.iterrows():
|
|
880
|
+
... local_path = bag.get_asset_path("Image", row["Image.RID"])
|
|
881
|
+
... label = row["Diagnosis.Label"]
|
|
882
|
+
... # Train on local_path with label
|
|
883
|
+
|
|
884
|
+
See Also:
|
|
885
|
+
denormalize_as_dict: Generator version for memory-efficient processing.
|
|
886
|
+
"""
|
|
887
|
+
rows = list(self._denormalize_from_members(include_tables=include_tables))
|
|
888
|
+
return pd.DataFrame(rows)
|
|
889
|
+
|
|
890
|
+
def denormalize_as_dict(
|
|
891
|
+
self,
|
|
892
|
+
include_tables: list[str],
|
|
893
|
+
version: Any = None,
|
|
894
|
+
**kwargs: Any,
|
|
895
|
+
) -> Generator[dict[str, Any], None, None]:
|
|
896
|
+
"""Denormalize the dataset bag and yield rows as dictionaries.
|
|
897
|
+
|
|
898
|
+
This is a memory-efficient alternative to denormalize_as_dataframe() that
|
|
899
|
+
yields one row at a time as a dictionary instead of loading all data into
|
|
900
|
+
a DataFrame. Use this when processing large datasets that may not fit in
|
|
901
|
+
memory, or when you want to process rows incrementally.
|
|
394
902
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
903
|
+
Like denormalize_as_dataframe(), this produces a "wide table" representation
|
|
904
|
+
where each yielded dictionary contains all columns from the joined tables.
|
|
905
|
+
See denormalize_as_dataframe() for detailed explanation of how denormalization
|
|
906
|
+
works.
|
|
398
907
|
|
|
399
|
-
|
|
400
|
-
view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
|
|
401
|
-
additional tables are required to complete the denormalization process. If include_tables is not specified,
|
|
402
|
-
all of the tables in the schema will be included.
|
|
908
|
+
**Column naming:**
|
|
403
909
|
|
|
404
|
-
|
|
910
|
+
Column names are prefixed with the source table name using dots to avoid
|
|
911
|
+
collisions (e.g., "Image.Filename", "Subject.RID"). This differs from the
|
|
912
|
+
live Dataset class which uses underscores.
|
|
405
913
|
|
|
406
914
|
Args:
|
|
407
|
-
include_tables: List of table names to include in the
|
|
915
|
+
include_tables: List of table names to include in the output.
|
|
916
|
+
Tables are joined based on their foreign key relationships.
|
|
917
|
+
version: Ignored (bags are immutable snapshots of a specific version).
|
|
918
|
+
**kwargs: Additional arguments (ignored, for protocol compatibility).
|
|
919
|
+
|
|
920
|
+
Yields:
|
|
921
|
+
dict[str, Any]: Dictionary representing one row of the wide table.
|
|
922
|
+
Keys are column names in "Table.Column" format.
|
|
923
|
+
|
|
924
|
+
Example:
|
|
925
|
+
Stream through a large dataset for training::
|
|
926
|
+
|
|
927
|
+
>>> bag = ml.download_dataset_bag(spec, materialize=True)
|
|
928
|
+
>>> for row in bag.denormalize_as_dict(["Image", "Diagnosis"]):
|
|
929
|
+
... # Get local file path for this image
|
|
930
|
+
... local_path = bag.get_asset_path("Image", row["Image.RID"])
|
|
931
|
+
... label = row["Diagnosis.Label"]
|
|
932
|
+
... # Process image and label...
|
|
933
|
+
|
|
934
|
+
Build a PyTorch dataset efficiently::
|
|
935
|
+
|
|
936
|
+
>>> class BagDataset(torch.utils.data.IterableDataset):
|
|
937
|
+
... def __init__(self, bag, tables):
|
|
938
|
+
... self.bag = bag
|
|
939
|
+
... self.tables = tables
|
|
940
|
+
... def __iter__(self):
|
|
941
|
+
... for row in self.bag.denormalize_as_dict(self.tables):
|
|
942
|
+
... img_path = self.bag.get_asset_path("Image", row["Image.RID"])
|
|
943
|
+
... yield load_image(img_path), row["Diagnosis.Label"]
|
|
944
|
+
|
|
945
|
+
See Also:
|
|
946
|
+
denormalize_as_dataframe: Returns all data as a pandas DataFrame.
|
|
947
|
+
"""
|
|
948
|
+
yield from self._denormalize_from_members(include_tables=include_tables)
|
|
949
|
+
|
|
950
|
+
|
|
951
|
+
# =========================================================================
|
|
952
|
+
# Asset Restructuring Methods
|
|
953
|
+
# =========================================================================
|
|
954
|
+
|
|
955
|
+
def _build_dataset_type_path_map(
|
|
956
|
+
self,
|
|
957
|
+
type_selector: Callable[[list[str]], str] | None = None,
|
|
958
|
+
) -> dict[RID, list[str]]:
|
|
959
|
+
"""Build a mapping from dataset RID to its type path in the hierarchy.
|
|
960
|
+
|
|
961
|
+
Recursively traverses nested datasets to create a mapping where each
|
|
962
|
+
dataset RID maps to its hierarchical type path (e.g., ["complete", "training"]).
|
|
963
|
+
|
|
964
|
+
Args:
|
|
965
|
+
type_selector: Function to select type when dataset has multiple types.
|
|
966
|
+
Receives list of type names, returns selected type name.
|
|
967
|
+
Defaults to selecting first type or "unknown" if no types.
|
|
408
968
|
|
|
409
969
|
Returns:
|
|
410
|
-
|
|
970
|
+
Dictionary mapping dataset RID to list of type names from root to leaf.
|
|
971
|
+
e.g., {"4-ABC": ["complete", "training"], "4-DEF": ["complete", "testing"]}
|
|
411
972
|
"""
|
|
412
|
-
|
|
973
|
+
if type_selector is None:
|
|
974
|
+
type_selector = lambda types: types[0] if types else "Testing"
|
|
975
|
+
|
|
976
|
+
type_paths: dict[RID, list[str]] = {}
|
|
977
|
+
|
|
978
|
+
def traverse(dataset: DatasetBag, parent_path: list[str], visited: set[RID]) -> None:
|
|
979
|
+
if dataset.dataset_rid in visited:
|
|
980
|
+
return
|
|
981
|
+
visited.add(dataset.dataset_rid)
|
|
982
|
+
|
|
983
|
+
current_type = type_selector(dataset.dataset_types)
|
|
984
|
+
current_path = parent_path + [current_type]
|
|
985
|
+
type_paths[dataset.dataset_rid] = current_path
|
|
986
|
+
|
|
987
|
+
for child in dataset.list_dataset_children():
|
|
988
|
+
traverse(child, current_path, visited)
|
|
989
|
+
|
|
990
|
+
traverse(self, [], set())
|
|
991
|
+
return type_paths
|
|
992
|
+
|
|
993
|
+
def _get_asset_dataset_mapping(self, asset_table: str) -> dict[RID, RID]:
|
|
994
|
+
"""Map asset RIDs to their containing dataset RID.
|
|
995
|
+
|
|
996
|
+
For each asset in the specified table, determines which dataset it belongs to.
|
|
997
|
+
This uses _dataset_table_view to find assets reachable through any FK path
|
|
998
|
+
from the dataset, not just directly associated assets.
|
|
413
999
|
|
|
414
|
-
|
|
1000
|
+
Assets are mapped to their most specific (leaf) dataset in the hierarchy.
|
|
1001
|
+
For example, if a Split dataset contains Training and Testing children,
|
|
1002
|
+
and images are members of Training, the images map to Training (not Split).
|
|
1003
|
+
|
|
1004
|
+
Args:
|
|
1005
|
+
asset_table: Name of the asset table (e.g., "Image")
|
|
1006
|
+
|
|
1007
|
+
Returns:
|
|
1008
|
+
Dictionary mapping asset RID to the dataset RID that contains it.
|
|
415
1009
|
"""
|
|
416
|
-
|
|
1010
|
+
asset_to_dataset: dict[RID, RID] = {}
|
|
1011
|
+
|
|
1012
|
+
def collect_from_dataset(dataset: DatasetBag, visited: set[RID]) -> None:
|
|
1013
|
+
if dataset.dataset_rid in visited:
|
|
1014
|
+
return
|
|
1015
|
+
visited.add(dataset.dataset_rid)
|
|
1016
|
+
|
|
1017
|
+
# Process children FIRST (depth-first) so leaf datasets get priority
|
|
1018
|
+
# This ensures assets are mapped to their most specific dataset
|
|
1019
|
+
for child in dataset.list_dataset_children():
|
|
1020
|
+
collect_from_dataset(child, visited)
|
|
1021
|
+
|
|
1022
|
+
# Then process this dataset's assets
|
|
1023
|
+
# Only set if not already mapped (child/leaf dataset wins)
|
|
1024
|
+
for asset in dataset._get_reachable_assets(asset_table):
|
|
1025
|
+
if asset["RID"] not in asset_to_dataset:
|
|
1026
|
+
asset_to_dataset[asset["RID"]] = dataset.dataset_rid
|
|
417
1027
|
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
for each row in the denormalized wide table.
|
|
1028
|
+
collect_from_dataset(self, set())
|
|
1029
|
+
return asset_to_dataset
|
|
421
1030
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
additional tables are required to complete the denormalization process. If include_tables is not specified,
|
|
425
|
-
all of the tables in the schema will be included.
|
|
1031
|
+
def _get_reachable_assets(self, asset_table: str) -> list[dict[str, Any]]:
|
|
1032
|
+
"""Get all assets reachable from this dataset through any FK path.
|
|
426
1033
|
|
|
427
|
-
|
|
1034
|
+
Unlike list_dataset_members which only returns directly associated entities,
|
|
1035
|
+
this method traverses foreign key relationships to find assets that are
|
|
1036
|
+
indirectly connected to the dataset. For example, if a dataset contains
|
|
1037
|
+
Subjects, and Subject -> Encounter -> Image, this method will find those
|
|
1038
|
+
Images even though they're not directly in the Dataset_Image association table.
|
|
428
1039
|
|
|
429
1040
|
Args:
|
|
430
|
-
|
|
431
|
-
is used.
|
|
1041
|
+
asset_table: Name of the asset table (e.g., "Image")
|
|
432
1042
|
|
|
433
1043
|
Returns:
|
|
434
|
-
|
|
1044
|
+
List of asset records as dictionaries.
|
|
435
1045
|
"""
|
|
1046
|
+
# Use the _dataset_table_view query which traverses all FK paths
|
|
1047
|
+
sql_query = self._dataset_table_view(asset_table)
|
|
1048
|
+
|
|
436
1049
|
with Session(self.engine) as session:
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
)
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
1050
|
+
result = session.execute(sql_query)
|
|
1051
|
+
# Convert rows to dictionaries
|
|
1052
|
+
rows = [dict(row._mapping) for row in result]
|
|
1053
|
+
|
|
1054
|
+
return rows
|
|
1055
|
+
|
|
+    def _load_feature_values_cache(
+        self,
+        asset_table: str,
+        group_keys: list[str],
+        enforce_vocabulary: bool = True,
+        value_selector: Callable[[list[FeatureValueRecord]], FeatureValueRecord] | None = None,
+    ) -> dict[str, dict[RID, Any]]:
+        """Load feature values into a cache for efficient lookup.
+
+        Pre-loads feature values for any group_keys that are feature names,
+        organizing them by target entity RID for fast lookup.
+
+        Args:
+            asset_table: The asset table name to find features for.
+            group_keys: List of potential feature names to cache. Supports two formats:
+                - "FeatureName": Uses the first term column (default behavior)
+                - "FeatureName.column_name": Uses the specified column from the feature table
+            enforce_vocabulary: If True (default), only allow features with
+                controlled vocabulary term columns and raise an error if an
+                asset has multiple values. If False, allow any feature type
+                and use the first value found when multiple exist.
+            value_selector: Optional function to select which feature value to use
+                when an asset has multiple values for the same feature. Receives a
+                list of FeatureValueRecord objects (each with execution_rid for
+                provenance) and returns the selected one. If not provided and
+                multiple values exist, raises DerivaMLException when
+                enforce_vocabulary=True or uses the first value when False.
+
+        Returns:
+            Dictionary mapping group_key -> {target_rid -> feature_value}
+            Only includes entries for keys that are actually features.
+
+        Raises:
+            DerivaMLException: If enforce_vocabulary is True and:
+                - A feature has no term columns (not vocabulary-based), or
+                - An asset has multiple different vocabulary term values for the same feature
+                  and no value_selector is provided.
+        """
+        from deriva_ml.core.exceptions import DerivaMLException
+
+        cache: dict[str, dict[RID, Any]] = {}
+        # Store all feature value records for later selection when there are multiples
+        records_cache: dict[str, dict[RID, list[FeatureValueRecord]]] = {}
+        logger = logging.getLogger("deriva_ml")
+
+        # Parse group_keys to extract feature names and optional column specifications
+        # Format: "FeatureName" or "FeatureName.column_name"
+        feature_column_map: dict[str, str | None] = {}  # group_key -> specific column or None
+        feature_names_to_check: set[str] = set()
+        for key in group_keys:
+            if "." in key:
+                parts = key.split(".", 1)
+                feature_name = parts[0]
+                column_name = parts[1]
+                feature_column_map[key] = column_name
+                feature_names_to_check.add(feature_name)
+            else:
+                feature_column_map[key] = None
+                feature_names_to_check.add(key)
+
+        def process_feature(feat: Any, table_name: str, group_key: str, specific_column: str | None) -> None:
+            """Process a single feature and add its values to the cache."""
+            term_cols = [c.name for c in feat.term_columns]
+            value_cols = [c.name for c in feat.value_columns]
+            all_cols = term_cols + value_cols
+
+            # Determine which column to use for the value
+            if specific_column:
+                # User specified a specific column
+                if specific_column not in all_cols:
+                    raise DerivaMLException(
+                        f"Column '{specific_column}' not found in feature '{feat.feature_name}'. "
+                        f"Available columns: {all_cols}"
+                    )
+                use_column = specific_column
+            elif term_cols:
+                # Use first term column (default behavior)
+                use_column = term_cols[0]
+            elif not enforce_vocabulary and value_cols:
+                # Fall back to value columns if allowed
+                use_column = value_cols[0]
+            else:
+                if enforce_vocabulary:
+                    raise DerivaMLException(
+                        f"Feature '{feat.feature_name}' on table '{table_name}' has no "
+                        f"controlled vocabulary term columns. Only vocabulary-based features "
+                        f"can be used for grouping when enforce_vocabulary=True. "
+                        f"Set enforce_vocabulary=False to allow non-vocabulary features."
+                    )
+                return
+
+            records_cache[group_key] = defaultdict(list)
+            feature_values = self.list_feature_values(table_name, feat.feature_name)
+
+            for fv in feature_values:
+                # Convert FeatureRecord to dict for easier access
+                fv_dict = fv.model_dump()
+                target_col = table_name
+                if target_col not in fv_dict:
+                    continue
+
+                target_rid = fv_dict[target_col]
+
+                # Get the value from the specified column
+                value = fv_dict.get(use_column) if use_column in fv_dict else None
+
+                if value is None:
+                    continue
+
+                # Create a FeatureValueRecord with execution provenance
+                record = FeatureValueRecord(
+                    target_rid=target_rid,
+                    feature_name=feat.feature_name,
+                    value=value,
+                    execution_rid=fv_dict.get("Execution"),
+                    raw_record=fv_dict,
+                )
+                records_cache[group_key][target_rid].append(record)
+
+        # Find all features on tables that this asset table references
+        asset_table_obj = self.model.name_to_table(asset_table)
+
+        # Check features on the asset table itself
+        for feature in self.find_features(asset_table):
+            if feature.feature_name in feature_names_to_check:
+                # Find all group_keys that reference this feature
+                for group_key, specific_col in feature_column_map.items():
+                    # Check if this group_key references this feature
+                    key_feature = group_key.split(".")[0] if "." in group_key else group_key
+                    if key_feature == feature.feature_name:
+                        try:
+                            process_feature(feature, asset_table, group_key, specific_col)
+                        except DerivaMLException:
+                            raise
+                        except Exception as e:
+                            logger.warning(f"Could not load feature {feature.feature_name}: {e}")
+
+        # Also check features on referenced tables (via foreign keys)
+        for fk in asset_table_obj.foreign_keys:
+            target_table = fk.pk_table
+            for feature in self.find_features(target_table):
+                if feature.feature_name in feature_names_to_check:
+                    # Find all group_keys that reference this feature
+                    for group_key, specific_col in feature_column_map.items():
+                        # Check if this group_key references this feature
+                        key_feature = group_key.split(".")[0] if "." in group_key else group_key
+                        if key_feature == feature.feature_name:
+                            try:
+                                process_feature(feature, target_table.name, group_key, specific_col)
+                            except DerivaMLException:
+                                raise
+                            except Exception as e:
+                                logger.warning(f"Could not load feature {feature.feature_name}: {e}")
+
+        # Now resolve multiple values using value_selector or error handling
+        for group_key, target_records in records_cache.items():
+            cache[group_key] = {}
+            for target_rid, records in target_records.items():
+                if len(records) == 1:
+                    # Single value - straightforward
+                    cache[group_key][target_rid] = records[0].value
+                elif len(records) > 1:
+                    # Multiple values - need to resolve
+                    unique_values = set(r.value for r in records)
+                    if len(unique_values) == 1:
+                        # All records have same value, use it
+                        cache[group_key][target_rid] = records[0].value
+                    elif value_selector:
+                        # Use provided selector function
+                        selected = value_selector(records)
+                        cache[group_key][target_rid] = selected.value
+                    elif enforce_vocabulary:
+                        # Multiple different values without selector - error
+                        values_str = ", ".join(f"'{r.value}' (exec: {r.execution_rid})" for r in records)
+                        raise DerivaMLException(
+                            f"Asset '{target_rid}' has multiple different values for "
+                            f"feature '{records[0].feature_name}': {values_str}. "
+                            f"Provide a value_selector function to choose between values, "
+                            f"or set enforce_vocabulary=False to use the first value."
+                        )
+                    else:
+                        # Not enforcing - use first value
+                        cache[group_key][target_rid] = records[0].value
+
+        return cache
+
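A sketch of the cache this helper returns, assuming `bag` is a DatasetBag, "Diagnosis" is a vocabulary-backed feature, and "Classification" is a feature whose Label column is selected explicitly (all names illustrative):

    cache = bag._load_feature_values_cache(
        asset_table="Image",
        group_keys=["Diagnosis", "Classification.Label", "Acquisition_Site"],
    )
    # Only keys that resolve to features appear in the result; a plain column
    # such as "Acquisition_Site" is simply absent here and is handled later by
    # _resolve_grouping_value as a direct column lookup.
    # cache ~ {
    #     "Diagnosis": {"2-A1": "Normal", "2-A2": "Abnormal"},
    #     "Classification.Label": {"2-A1": "Cat"},
    # }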
+    def _resolve_grouping_value(
+        self,
+        asset: dict[str, Any],
+        group_key: str,
+        feature_cache: dict[str, dict[RID, Any]],
+    ) -> str:
+        """Resolve a grouping value for an asset.
+
+        First checks if group_key is a direct column on the asset record,
+        then checks if it's a feature name in the feature cache.
+
+        Args:
+            asset: The asset record dictionary.
+            group_key: Column name or feature name to group by.
+            feature_cache: Pre-loaded feature values keyed by feature name -> target RID -> value.
+
+        Returns:
+            The resolved value as a string, or "Unknown" if not found or None.
+            Uses "Unknown" (capitalized) to match vocabulary term naming conventions.
+        """
+        # First check if it's a direct column on the asset table
+        if group_key in asset:
+            value = asset[group_key]
+            if value is not None:
+                return str(value)
+            return "Unknown"
+
+        # Check if it's a feature name
+        if group_key in feature_cache:
+            feature_values = feature_cache[group_key]
+            # Check each column in the asset that might be a FK to the feature target
+            for column_name, column_value in asset.items():
+                if column_value and column_value in feature_values:
+                    return str(feature_values[column_value])
+            # Also check if the asset's own RID is in the feature values
+            if asset.get("RID") in feature_values:
+                return str(feature_values[asset["RID"]])
+
+        return "Unknown"
+
|
+
def _detect_asset_table(self) -> str | None:
|
|
1283
|
+
"""Auto-detect the asset table from dataset members.
|
|
1284
|
+
|
|
1285
|
+
Searches for asset tables in the dataset members by examining
|
|
1286
|
+
the schema. Returns the first asset table found, or None if
|
|
1287
|
+
no asset tables are in the dataset.
|
|
1288
|
+
|
|
1289
|
+
Returns:
|
|
1290
|
+
Name of the detected asset table, or None if not found.
|
|
1291
|
+
"""
|
|
1292
|
+
members = self.list_dataset_members(recurse=True)
|
|
1293
|
+
for table_name in members:
|
|
1294
|
+
if table_name == "Dataset":
|
|
1295
|
+
continue
|
|
1296
|
+
# Check if this table is an asset table
|
|
1297
|
+
try:
|
|
1298
|
+
table = self.model.name_to_table(table_name)
|
|
1299
|
+
if self.model.is_asset(table):
|
|
1300
|
+
return table_name
|
|
1301
|
+
except (KeyError, AttributeError):
|
|
1302
|
+
continue
|
|
1303
|
+
return None
|
|
1304
|
+
|
|
1305
|
+
def _validate_dataset_types(self) -> list[str] | None:
|
|
1306
|
+
"""Validate that the dataset or its children have Training/Testing types.
|
|
1307
|
+
|
|
1308
|
+
Checks if this dataset is of type Training or Testing, or if it has
|
|
1309
|
+
nested children of those types. Returns the valid types found.
|
|
1310
|
+
|
|
1311
|
+
Returns:
|
|
1312
|
+
List of Training/Testing type names found, or None if validation fails.
|
|
1313
|
+
"""
|
|
1314
|
+
valid_types = {"Training", "Testing"}
|
|
1315
|
+
found_types: set[str] = set()
|
|
1316
|
+
|
|
1317
|
+
def check_dataset(ds: DatasetBag, visited: set[RID]) -> None:
|
|
1318
|
+
if ds.dataset_rid in visited:
|
|
1319
|
+
return
|
|
1320
|
+
visited.add(ds.dataset_rid)
|
|
1321
|
+
|
|
1322
|
+
for dtype in ds.dataset_types:
|
|
1323
|
+
if dtype in valid_types:
|
|
1324
|
+
found_types.add(dtype)
|
|
1325
|
+
|
|
1326
|
+
for child in ds.list_dataset_children():
|
|
1327
|
+
check_dataset(child, visited)
|
|
1328
|
+
|
|
1329
|
+
check_dataset(self, set())
|
|
1330
|
+
return list(found_types) if found_types else None
|
|
1331
|
+
|
|
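Both helpers above are private methods on DatasetBag; a brief sketch of probing them directly before calling restructure_assets (again assuming `bag` is a DatasetBag):

    table = bag._detect_asset_table()      # e.g. "Image", or None if no member table is an asset table
    types = bag._validate_dataset_types()  # e.g. ["Training", "Testing"], or None if neither type is present
    if table is None:
        print("Pass asset_table explicitly; auto-detection found nothing.")
    if types is None:
        print("No Training/Testing types found; untyped datasets are treated as Testing.")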
+    def restructure_assets(
+        self,
+        output_dir: Path | str,
+        asset_table: str | None = None,
+        group_by: list[str] | None = None,
+        use_symlinks: bool = True,
+        type_selector: Callable[[list[str]], str] | None = None,
+        type_to_dir_map: dict[str, str] | None = None,
+        enforce_vocabulary: bool = True,
+        value_selector: Callable[[list[FeatureValueRecord]], FeatureValueRecord] | None = None,
+    ) -> Path:
+        """Restructure downloaded assets into a directory hierarchy.
+
+        Creates a directory structure organizing assets by dataset types and
+        grouping values. This is useful for ML workflows that expect data
+        organized in conventional folder structures (e.g., PyTorch ImageFolder).
+
+        The dataset should be of type Training or Testing, or have nested
+        children of those types. The top-level directory name is determined
+        by the dataset type (e.g., "Training" -> "training").
+
+        **Finding assets through foreign key relationships:**
+
+        Assets are found by traversing all foreign key paths from the dataset,
+        not just direct associations. For example, if a dataset contains Subjects,
+        and the schema has Subject -> Encounter -> Image relationships, this method
+        will find all Images reachable through those paths even though they are
+        not directly in a Dataset_Image association table.
+
+        **Handling datasets without types (prediction scenarios):**
+
+        If a dataset has no type defined, it is treated as Testing. This is
+        common for prediction/inference scenarios where you want to apply a
+        trained model to new unlabeled data.
+
+        **Handling missing labels:**
+
+        If an asset doesn't have a value for a group_by key (e.g., no label
+        assigned), it is placed in an "Unknown" directory. This allows
+        restructure_assets to work with unlabeled data for prediction.
+
+        Args:
+            output_dir: Base directory for restructured assets.
+            asset_table: Name of the asset table (e.g., "Image"). If None,
+                auto-detects from dataset members. Raises DerivaMLException
+                if multiple asset tables are found and none is specified.
+            group_by: Names to group assets by. Each name creates a subdirectory
+                level after the dataset type path. Names can be:
+
+                - **Column names**: Direct columns on the asset table. The column
+                  value becomes the subdirectory name.
+                - **Feature names**: Features defined on the asset table (or tables
+                  it references via foreign keys). The feature's vocabulary term
+                  value becomes the subdirectory name.
+                - **Feature.column**: Specify a particular column from a multi-term
+                  feature (e.g., "Classification.Label" to use the Label column).
+
+                Column names are checked first, then feature names. If a value
+                is not found, "unknown" is used as the subdirectory name.
+
+            use_symlinks: If True (default), create symlinks to original files.
+                If False, copy files. Symlinks save disk space but require
+                the original bag to remain in place.
+            type_selector: Function to select type when dataset has multiple types.
+                Receives list of type names, returns selected type name.
+                Defaults to selecting first type or "unknown" if no types.
+            type_to_dir_map: Optional mapping from dataset type names to directory
+                names. Defaults to {"Training": "training", "Testing": "testing",
+                "Unknown": "unknown"}. Use this to customize directory names or
+                add new type mappings.
+            enforce_vocabulary: If True (default), only allow features that have
+                controlled vocabulary term columns, and raise an error if an asset
+                has multiple different values for the same feature without a
+                value_selector. This ensures clean, unambiguous directory structures.
+                If False, allow any feature type and use the first value found
+                when multiple values exist.
+            value_selector: Optional function to select which feature value to use
+                when an asset has multiple values for the same feature. Receives a
+                list of FeatureValueRecord objects (each containing target_rid,
+                feature_name, value, execution_rid, and raw_record) and returns
+                the selected FeatureValueRecord. Use execution_rid to distinguish
+                between values from different executions.
+
+        Returns:
+            Path to the output directory.
+
+        Raises:
+            DerivaMLException: If asset_table cannot be determined (multiple
+                asset tables exist without specification), if no valid dataset
+                types (Training/Testing) are found, or if enforce_vocabulary
+                is True and a feature has multiple values without value_selector.
+
+        Examples:
+            Basic restructuring with auto-detected asset table::
+
+                bag.restructure_assets(
+                    output_dir="./ml_data",
+                    group_by=["Diagnosis"],
+                )
+                # Creates:
+                # ./ml_data/training/Normal/image1.jpg
+                # ./ml_data/testing/Abnormal/image2.jpg
+
+            Custom type-to-directory mapping::
+
+                bag.restructure_assets(
+                    output_dir="./ml_data",
+                    group_by=["Diagnosis"],
+                    type_to_dir_map={"Training": "train", "Testing": "test"},
+                )
+                # Creates:
+                # ./ml_data/train/Normal/image1.jpg
+                # ./ml_data/test/Abnormal/image2.jpg
+
+            Select specific feature column for multi-term features::
+
+                bag.restructure_assets(
+                    output_dir="./ml_data",
+                    group_by=["Classification.Label"],  # Use Label column
+                )
+
+            Handle multiple feature values with a selector::
+
+                def select_latest(records: list[FeatureValueRecord]) -> FeatureValueRecord:
+                    # Select value from most recent execution
+                    return max(records, key=lambda r: r.execution_rid or "")
+
+                bag.restructure_assets(
+                    output_dir="./ml_data",
+                    group_by=["Diagnosis"],
+                    value_selector=select_latest,
+                )
+
+            Prediction scenario with unlabeled data::
+
+                # Dataset has no type - treated as Testing
+                # Assets have no labels - placed in Unknown directory
+                bag.restructure_assets(
+                    output_dir="./prediction_data",
+                    group_by=["Diagnosis"],
+                )
+                # Creates:
+                # ./prediction_data/testing/Unknown/image1.jpg
+                # ./prediction_data/testing/Unknown/image2.jpg
+        """
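The docstring above shows examples for most parameters but not for type_selector; a sketch of one possible policy, assuming a dataset tagged with several types where "Training" should take precedence:

    def prefer_training(types: list[str]) -> str:
        # Illustrative policy: favor "Training" when a dataset carries multiple types.
        return "Training" if "Training" in types else types[0]

    bag.restructure_assets(
        output_dir="./ml_data",
        group_by=["Diagnosis"],
        type_selector=prefer_training,
        use_symlinks=False,  # copy files so the layout survives if the bag is removed
    )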
+        logger = logging.getLogger("deriva_ml")
+        group_by = group_by or []
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Default type-to-directory mapping
+        if type_to_dir_map is None:
+            type_to_dir_map = {"Training": "training", "Testing": "testing", "Unknown": "unknown"}
+
+        # Auto-detect asset table if not provided
+        if asset_table is None:
+            asset_table = self._detect_asset_table()
+            if asset_table is None:
+                raise DerivaMLException(
+                    "Could not auto-detect asset table. No asset tables found in dataset members. "
+                    "Specify the asset_table parameter explicitly."
+                )
+            logger.info(f"Auto-detected asset table: {asset_table}")
+
+        # Step 1: Build dataset type path map with directory name mapping
+        def map_type_to_dir(types: list[str]) -> str:
+            """Map dataset types to directory name using type_to_dir_map.
+
+            If dataset has no types, treat it as Testing (prediction use case).
+            """
+            if not types:
+                # No types defined - treat as Testing for prediction scenarios
+                return type_to_dir_map.get("Testing", "testing")
+            if type_selector:
+                selected_type = type_selector(types)
+            else:
+                selected_type = types[0]
+            return type_to_dir_map.get(selected_type, selected_type.lower())
+
+        type_path_map = self._build_dataset_type_path_map(map_type_to_dir)
+
+        # Step 2: Get asset-to-dataset mapping
+        asset_dataset_map = self._get_asset_dataset_mapping(asset_table)
+
+        # Step 3: Load feature values cache for relevant features
+        feature_cache = self._load_feature_values_cache(
+            asset_table, group_by, enforce_vocabulary, value_selector
+        )
+
+        # Step 4: Get all assets reachable through FK paths
+        # This uses _get_reachable_assets which traverses FK relationships,
+        # so assets connected via Subject -> Encounter -> Image are found
+        # even if the dataset only contains Subjects directly.
+        assets = self._get_reachable_assets(asset_table)
+
+        if not assets:
+            logger.warning(f"No assets found in table '{asset_table}'")
+            return output_dir
+
+        # Step 5: Process each asset
+        for asset in assets:
+            # Get source file path
+            filename = asset.get("Filename")
+            if not filename:
+                logger.warning(f"Asset {asset.get('RID')} has no Filename")
+                continue
+
+            source_path = Path(filename)
+            if not source_path.exists():
+                logger.warning(f"Asset file not found: {source_path}")
+                continue
+
+            # Get dataset type path
+            dataset_rid = asset_dataset_map.get(asset["RID"])
+            type_path = type_path_map.get(dataset_rid, ["unknown"])
+
+            # Resolve grouping values
+            group_path = []
+            for key in group_by:
+                value = self._resolve_grouping_value(asset, key, feature_cache)
+                group_path.append(value)
+
+            # Build target directory
+            target_dir = output_dir.joinpath(*type_path, *group_path)
+            target_dir.mkdir(parents=True, exist_ok=True)
+
+            # Create link or copy
+            target_path = target_dir / source_path.name
+
+            # Handle existing files
+            if target_path.exists() or target_path.is_symlink():
+                target_path.unlink()
+
+            if use_symlinks:
+                try:
+                    target_path.symlink_to(source_path.resolve())
+                except OSError as e:
+                    # Fall back to copy on platforms that don't support symlinks
+                    logger.warning(f"Symlink failed, falling back to copy: {e}")
+                    shutil.copy2(source_path, target_path)
+            else:
+                shutil.copy2(source_path, target_path)

+        return output_dir

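Since the docstring positions the resulting layout for PyTorch-style loaders, a sketch of consuming it with torchvision's ImageFolder; torchvision is an assumption here, not a dependency declared by this package:

    from torchvision import datasets, transforms

    root = bag.restructure_assets(output_dir="./ml_data", group_by=["Diagnosis"])
    train_ds = datasets.ImageFolder(root / "training", transform=transforms.ToTensor())
    # Class names come from the group_by subdirectories, e.g. ["Abnormal", "Normal", "Unknown"].
    print(train_ds.classes)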
-# Add annotations after definition to deal with forward reference issues in pydantic

-
-
-    validate_return=True,
-)(DatasetBag.list_dataset_children)
+# Note: validate_call decorators with Self return types were removed because
+# Pydantic doesn't support typing.Self in validate_call contexts.