deriva-ml 1.17.9__py3-none-any.whl → 1.17.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. deriva_ml/__init__.py +43 -1
  2. deriva_ml/asset/__init__.py +17 -0
  3. deriva_ml/asset/asset.py +357 -0
  4. deriva_ml/asset/aux_classes.py +100 -0
  5. deriva_ml/bump_version.py +254 -11
  6. deriva_ml/catalog/__init__.py +21 -0
  7. deriva_ml/catalog/clone.py +1199 -0
  8. deriva_ml/catalog/localize.py +426 -0
  9. deriva_ml/core/__init__.py +29 -0
  10. deriva_ml/core/base.py +817 -1067
  11. deriva_ml/core/config.py +169 -21
  12. deriva_ml/core/constants.py +120 -19
  13. deriva_ml/core/definitions.py +123 -13
  14. deriva_ml/core/enums.py +47 -73
  15. deriva_ml/core/ermrest.py +226 -193
  16. deriva_ml/core/exceptions.py +297 -14
  17. deriva_ml/core/filespec.py +99 -28
  18. deriva_ml/core/logging_config.py +225 -0
  19. deriva_ml/core/mixins/__init__.py +42 -0
  20. deriva_ml/core/mixins/annotation.py +915 -0
  21. deriva_ml/core/mixins/asset.py +384 -0
  22. deriva_ml/core/mixins/dataset.py +237 -0
  23. deriva_ml/core/mixins/execution.py +408 -0
  24. deriva_ml/core/mixins/feature.py +365 -0
  25. deriva_ml/core/mixins/file.py +263 -0
  26. deriva_ml/core/mixins/path_builder.py +145 -0
  27. deriva_ml/core/mixins/rid_resolution.py +204 -0
  28. deriva_ml/core/mixins/vocabulary.py +400 -0
  29. deriva_ml/core/mixins/workflow.py +322 -0
  30. deriva_ml/core/validation.py +389 -0
  31. deriva_ml/dataset/__init__.py +2 -1
  32. deriva_ml/dataset/aux_classes.py +20 -4
  33. deriva_ml/dataset/catalog_graph.py +575 -0
  34. deriva_ml/dataset/dataset.py +1242 -1008
  35. deriva_ml/dataset/dataset_bag.py +1311 -182
  36. deriva_ml/dataset/history.py +27 -14
  37. deriva_ml/dataset/upload.py +225 -38
  38. deriva_ml/demo_catalog.py +186 -105
  39. deriva_ml/execution/__init__.py +46 -2
  40. deriva_ml/execution/base_config.py +639 -0
  41. deriva_ml/execution/execution.py +545 -244
  42. deriva_ml/execution/execution_configuration.py +26 -11
  43. deriva_ml/execution/execution_record.py +592 -0
  44. deriva_ml/execution/find_caller.py +298 -0
  45. deriva_ml/execution/model_protocol.py +175 -0
  46. deriva_ml/execution/multirun_config.py +153 -0
  47. deriva_ml/execution/runner.py +595 -0
  48. deriva_ml/execution/workflow.py +224 -35
  49. deriva_ml/experiment/__init__.py +8 -0
  50. deriva_ml/experiment/experiment.py +411 -0
  51. deriva_ml/feature.py +6 -1
  52. deriva_ml/install_kernel.py +143 -6
  53. deriva_ml/interfaces.py +862 -0
  54. deriva_ml/model/__init__.py +99 -0
  55. deriva_ml/model/annotations.py +1278 -0
  56. deriva_ml/model/catalog.py +286 -60
  57. deriva_ml/model/database.py +144 -649
  58. deriva_ml/model/deriva_ml_database.py +308 -0
  59. deriva_ml/model/handles.py +14 -0
  60. deriva_ml/run_model.py +319 -0
  61. deriva_ml/run_notebook.py +507 -38
  62. deriva_ml/schema/__init__.py +18 -2
  63. deriva_ml/schema/annotations.py +62 -33
  64. deriva_ml/schema/create_schema.py +169 -69
  65. deriva_ml/schema/validation.py +601 -0
  66. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/METADATA +4 -5
  67. deriva_ml-1.17.11.dist-info/RECORD +77 -0
  68. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/WHEEL +1 -1
  69. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/entry_points.txt +2 -0
  70. deriva_ml/protocols/dataset.py +0 -19
  71. deriva_ml/test.py +0 -94
  72. deriva_ml-1.17.9.dist-info/RECORD +0 -45
  73. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/licenses/LICENSE +0 -0
  74. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,44 @@
1
- """
2
- The module implements the sqllite interface to a set of directories representing a dataset bag.
1
+ """SQLite-backed dataset access for downloaded BDBags.
2
+
3
+ This module provides the DatasetBag class, which allows querying and navigating
4
+ downloaded dataset bags using SQLite. When a dataset is downloaded from a Deriva
5
+ catalog, it is stored as a BDBag (Big Data Bag) containing:
6
+
7
+ - CSV files with table data
8
+ - Asset files (images, documents, etc.)
9
+ - A schema.json describing the catalog structure
10
+ - A fetch.txt manifest of referenced files
11
+
12
+ The DatasetBag class provides a read-only interface to this data, mirroring
13
+ the Dataset class API where possible. This allows code to work uniformly
14
+ with both live catalog datasets and downloaded bags.
15
+
16
+ Key concepts:
17
+ - DatasetBag wraps a single dataset within a downloaded bag
18
+ - A bag may contain multiple datasets (nested/hierarchical)
19
+ - All operations are read-only (bags are immutable snapshots)
20
+ - Queries use SQLite via SQLAlchemy ORM
21
+ - Table-level access (get_table_as_dict, lookup_term) is on the catalog (DerivaMLDatabase)
22
+
23
+ Typical usage:
24
+ >>> # Download a dataset from a catalog
25
+ >>> bag = ml.download_dataset_bag(dataset_spec)
26
+ >>> # List dataset members by type
27
+ >>> members = bag.list_dataset_members(recurse=True)
28
+ >>> for image in members.get("Image", []):
29
+ ... print(image["Filename"])
3
30
  """
4
31
 
5
32
  from __future__ import annotations
6
33
 
7
34
  # Standard library imports
35
+ import logging
36
+ import shutil
8
37
  from collections import defaultdict
9
38
  from copy import copy
10
- from typing import TYPE_CHECKING, Any, Generator, Iterable, cast
39
+ from dataclasses import dataclass, field
40
+ from pathlib import Path
41
+ from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Self, cast
11
42
 
12
43
  import deriva.core.datapath as datapath
13
44
 
@@ -18,17 +49,17 @@ import pandas as pd
18
49
  from deriva.core.ermrest_model import Table
19
50
 
20
51
  # Deriva imports
21
- from pydantic import ConfigDict, validate_call
22
- from sqlalchemy import CompoundSelect, Engine, RowMapping, Select, and_, inspect, select, union
52
+ from sqlalchemy import CompoundSelect, Engine, Select, and_, inspect, select, union
23
53
  from sqlalchemy.orm import RelationshipProperty, Session
24
54
  from sqlalchemy.orm.util import AliasedClass
25
55
 
26
- from deriva_ml.core.definitions import RID, VocabularyTerm
27
- from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm
28
- from deriva_ml.feature import Feature
56
+ from deriva_ml.core.definitions import RID
57
+ from deriva_ml.core.exceptions import DerivaMLException
58
+ from deriva_ml.dataset.aux_classes import DatasetHistory, DatasetVersion
59
+ from deriva_ml.feature import Feature, FeatureRecord
29
60
 
30
61
  if TYPE_CHECKING:
31
- from deriva_ml.model.database import DatabaseModel
62
+ from deriva_ml.model.deriva_ml_database import DerivaMLDatabase
32
63
 
33
64
  try:
34
65
  from icecream import ic
@@ -36,69 +67,235 @@ except ImportError: # Graceful fallback if IceCream isn't installed.
36
67
  ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a) # noqa
37
68
 
38
69
 
39
- class DatasetBag:
70
+ @dataclass
71
+ class FeatureValueRecord:
72
+ """A feature value record with execution provenance.
73
+
74
+ This class represents a single feature value assigned to an asset,
75
+ including the execution that created it. Used by restructure_assets
76
+ when a value_selector function needs to choose between multiple
77
+ feature values for the same asset.
78
+
79
+ The raw_record attribute contains the complete feature table row as
80
+ a dictionary, which can be used to access all columns including any
81
+ additional metadata or columns beyond the primary value.
82
+
83
+ Attributes:
84
+ target_rid: RID of the asset/entity this feature value applies to.
85
+ feature_name: Name of the feature.
86
+ value: The feature value (typically a vocabulary term name).
87
+ execution_rid: RID of the execution that created this feature value, if any.
88
+ Use this to distinguish between values from different executions.
89
+ raw_record: The complete raw record from the feature table as a dictionary.
90
+ Access all columns via dict keys, e.g., record.raw_record["MyColumn"].
91
+
92
+ Example:
93
+ Using a value_selector to choose the most recent feature value::
94
+
95
+ def select_by_execution(records: list[FeatureValueRecord]) -> FeatureValueRecord:
96
+ # Select value from most recent execution (assuming RIDs are sortable)
97
+ return max(records, key=lambda r: r.execution_rid or "")
98
+
99
+ bag.restructure_assets(
100
+ output_dir="./ml_data",
101
+ group_by=["Diagnosis"],
102
+ value_selector=select_by_execution,
103
+ )
104
+
105
+ Accessing raw record data::
106
+
107
+ def select_by_confidence(records: list[FeatureValueRecord]) -> FeatureValueRecord:
108
+ # Select value with highest confidence score from raw record
109
+ return max(records, key=lambda r: r.raw_record.get("Confidence", 0))
40
110
  """
41
- DatasetBag is a class that manages a materialized bag. It is created from a locally materialized
42
- BDBag for a dataset_table, which is created either by DerivaML.create_execution, or directly by
43
- calling DerivaML.download_dataset.
111
+ target_rid: RID
112
+ feature_name: str
113
+ value: Any
114
+ execution_rid: RID | None = None
115
+ raw_record: dict[str, Any] = field(default_factory=dict)
116
+
117
+ def __repr__(self) -> str:
118
+ return (f"FeatureValueRecord(target_rid='{self.target_rid}', "
119
+ f"feature_name='{self.feature_name}', value='{self.value}', "
120
+ f"execution_rid='{self.execution_rid}')")
121
+
122
+
123
+ class DatasetBag:
124
+ """Read-only interface to a downloaded dataset bag.
125
+
126
+ DatasetBag manages access to a materialized BDBag (Big Data Bag) that contains
127
+ a snapshot of dataset data from a Deriva catalog. It provides methods for:
44
128
 
45
- A general a bag may contain multiple datasets, if the dataset is nested. The DatasetBag is used to
46
- represent only one of the datasets in the bag.
129
+ - Listing dataset members and their attributes
130
+ - Navigating dataset relationships (parents, children)
131
+ - Accessing feature values
132
+ - Denormalizing data across related tables
47
133
 
48
- All the metadata associated with the dataset is stored in a SQLLite database that can be queried using SQL.
134
+ A bag may contain multiple datasets when nested datasets are involved. Each
135
+ DatasetBag instance represents a single dataset within the bag - use
136
+ list_dataset_children() to navigate to nested datasets.
137
+
138
+ For catalog-level operations like querying arbitrary tables or looking up
139
+ vocabulary terms, use the DerivaMLDatabase class instead.
140
+
141
+ The class implements the DatasetLike protocol, providing the same read interface
142
+ as the Dataset class. This allows code to work with both live catalogs and
143
+ downloaded bags interchangeably.
49
144
 
50
145
  Attributes:
51
- dataset_rid (RID): RID for the specified dataset
52
- version: The version of the dataset
53
- model (DatabaseModel): The Database model that has all the catalog metadata associated with this dataset.
54
- database:
55
- dbase (sqlite3.Connection): connection to the sqlite database holding table values
56
- domain_schema (str): Name of the domain schema
146
+ dataset_rid (RID): The unique Resource Identifier for this dataset.
147
+ dataset_types (list[str]): List of vocabulary terms describing the dataset type.
148
+ description (str): Human-readable description of the dataset.
149
+ execution_rid (RID | None): RID of the execution associated with this dataset version, if any.
150
+ model (DatabaseModel): The DatabaseModel providing SQLite access to bag data.
151
+ engine (Engine): SQLAlchemy engine for database queries.
152
+ metadata (MetaData): SQLAlchemy metadata with table definitions.
153
+
154
+ Example:
155
+ >>> # Download a dataset
156
+ >>> bag = dataset.download_dataset_bag(version="1.0.0")
157
+ >>> # List members by type
158
+ >>> members = bag.list_dataset_members()
159
+ >>> for image in members.get("Image", []):
160
+ ... print(f"File: {image['Filename']}")
161
+ >>> # Navigate to nested datasets
162
+ >>> for child in bag.list_dataset_children():
163
+ ... print(f"Nested: {child.dataset_rid}")
57
164
  """
58
165
 
59
- def __init__(self, database_model: DatabaseModel, dataset_rid: RID | None = None) -> None:
60
- """
61
- Initialize a DatasetBag instance.
166
+ def __init__(
167
+ self,
168
+ catalog: "DerivaMLDatabase",
169
+ dataset_rid: RID | None = None,
170
+ dataset_types: str | list[str] | None = None,
171
+ description: str = "",
172
+ execution_rid: RID | None = None,
173
+ ):
174
+ """Initialize a DatasetBag instance for a dataset within a downloaded bag.
175
+
176
+ This mirrors the Dataset class initialization pattern, where both classes
177
+ take a catalog-like object as their first argument for consistency.
62
178
 
63
179
  Args:
64
- database_model: Database version of the bag.
65
- dataset_rid: Optional RID for the dataset.
180
+ catalog: The DerivaMLDatabase instance providing access to the bag's data.
181
+ This implements the DerivaMLCatalog protocol.
182
+ dataset_rid: The RID of the dataset to wrap. If None, uses the primary
183
+ dataset RID from the bag.
184
+ dataset_types: One or more dataset type terms. Can be a single string
185
+ or list of strings.
186
+ description: Human-readable description of the dataset.
187
+ execution_rid: RID of the execution associated with this dataset version.
188
+ If None, will be looked up from the Dataset_Version table.
189
+
190
+ Raises:
191
+ DerivaMLException: If no dataset_rid is provided and none can be
192
+ determined from the bag, or if the RID doesn't exist in the bag.
66
193
  """
67
- self.model = database_model
194
+ # Store reference to the catalog and extract the underlying model
195
+ self._catalog = catalog
196
+ self.model = catalog.model
68
197
  self.engine = cast(Engine, self.model.engine)
69
198
  self.metadata = self.model.metadata
70
199
 
200
+ # Use provided RID or fall back to the bag's primary dataset
71
201
  self.dataset_rid = dataset_rid or self.model.dataset_rid
202
+ self.description = description
203
+ self.execution_rid = execution_rid or (
204
+ self.model._get_dataset_execution(self.dataset_rid) or {}
205
+ ).get("Execution")
206
+
207
+ # Normalize dataset_types to always be a list of strings for consistency
208
+ # with the Dataset class interface
209
+ if dataset_types is None:
210
+ self.dataset_types: list[str] = []
211
+ elif isinstance(dataset_types, str):
212
+ self.dataset_types: list[str] = [dataset_types]
213
+ else:
214
+ self.dataset_types: list[str] = list(dataset_types)
215
+
72
216
  if not self.dataset_rid:
73
217
  raise DerivaMLException("No dataset RID provided")
74
218
 
75
- self.model.rid_lookup(self.dataset_rid) # Check to make sure that this dataset is in the bag.
219
+ # Validate that this dataset exists in the bag
220
+ self.model.rid_lookup(self.dataset_rid)
76
221
 
77
- self.version = self.model.dataset_version(self.dataset_rid)
222
+ # Cache the version and dataset table reference
223
+ self._current_version = self.model.dataset_version(self.dataset_rid)
78
224
  self._dataset_table = self.model.dataset_table
79
225
 
80
226
  def __repr__(self) -> str:
81
- return f"<deriva_ml.DatasetBag object {self.dataset_rid} at {hex(id(self))}>"
227
+ """Return a string representation of the DatasetBag for debugging."""
228
+ return (f"<deriva_ml.DatasetBag object at {hex(id(self))}: rid='{self.dataset_rid}', "
229
+ f"version='{self.current_version}', types={self.dataset_types}>")
230
+
231
+ @property
232
+ def current_version(self) -> DatasetVersion:
233
+ """Get the version of the dataset at the time the bag was downloaded.
234
+
235
+ For a DatasetBag, this is the version that was current when the bag was
236
+ created. Unlike the live Dataset class, this value is immutable since
237
+ bags are read-only snapshots.
238
+
239
+ Returns:
240
+ DatasetVersion: The semantic version (major.minor.patch) of this dataset.
241
+ """
242
+ return self._current_version
82
243
 
83
244
  def list_tables(self) -> list[str]:
84
- """List the names of the tables in the catalog
245
+ """List all tables available in the bag's SQLite database.
246
+
247
+ Returns the fully-qualified names of all tables (e.g., "domain.Image",
248
+ "deriva-ml.Dataset") that were exported in this bag.
85
249
 
86
250
  Returns:
87
- A list of table names. These names are all qualified with the Deriva schema name.
251
+ list[str]: Table names in "schema.table" format, sorted alphabetically.
88
252
  """
89
253
  return self.model.list_tables()
90
254
 
255
+ def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
256
+ """Get table contents as dictionaries.
257
+
258
+ Convenience method that delegates to the underlying catalog. This provides
259
+ access to all rows in a table, not just those belonging to this dataset.
260
+ For dataset-filtered results, use list_dataset_members() instead.
261
+
262
+ Args:
263
+ table: Name of the table to retrieve (e.g., "Subject", "Image").
264
+
265
+ Yields:
266
+ dict: Dictionary for each row in the table.
267
+
268
+ Example:
269
+ >>> for subject in bag.get_table_as_dict("Subject"):
270
+ ... print(subject["Name"])
271
+ """
272
+ return self._catalog.get_table_as_dict(table)
273
+
91
274
  @staticmethod
92
275
  def _find_relationship_attr(source, target):
93
- """
94
- Return the relationship attribute (InstrumentedAttribute) on `source`
95
- that points to `target`. Works with classes or AliasedClass.
96
- Raises LookupError if not found.
276
+ """Find the SQLAlchemy relationship attribute connecting two ORM classes.
277
+
278
+ Searches for a relationship on `source` that points to `target`, which is
279
+ needed to construct proper JOIN clauses in SQL queries.
280
+
281
+ Args:
282
+ source: Source ORM class or AliasedClass.
283
+ target: Target ORM class or AliasedClass.
284
+
285
+ Returns:
286
+ InstrumentedAttribute: The relationship attribute on source pointing to target.
287
+
288
+ Raises:
289
+ LookupError: If no relationship exists between the two classes.
290
+
291
+ Note:
292
+ When multiple relationships exist, prefers MANYTOONE direction as this
293
+ is typically the more natural join direction for denormalization.
97
294
  """
98
295
  src_mapper = inspect(source).mapper
99
296
  tgt_mapper = inspect(target).mapper
100
297
 
101
- # collect relationships on the *class* mapper (not on alias)
298
+ # Collect all relationships on the source mapper that point to target
102
299
  candidates: list[RelationshipProperty] = [rel for rel in src_mapper.relationships if rel.mapper is tgt_mapper]
103
300
 
104
301
  if not candidates:
@@ -108,86 +305,117 @@ class DatasetBag:
108
305
  candidates.sort(key=lambda r: r.direction.name != "MANYTOONE")
109
306
  rel = candidates[0]
110
307
 
111
- # Bind to the actual source (alias or class)
308
+ # Return the bound attribute (handles AliasedClass properly)
112
309
  return getattr(source, rel.key) if isinstance(source, AliasedClass) else rel.class_attribute
113
310
 
114
311
  def _dataset_table_view(self, table: str) -> CompoundSelect[Any]:
115
- """Return a SQL command that will return all of the elements in the specified table that are associated with
116
- dataset_rid"""
312
+ """Build a SQL query for all rows in a table that belong to this dataset.
313
+
314
+ Creates a UNION of queries that traverse all possible paths from the
315
+ Dataset table to the target table, filtering by this dataset's RID
316
+ (and any nested dataset RIDs).
317
+
318
+ This is necessary because table data may be linked to datasets through
319
+ different relationship paths (e.g., Image might be linked directly to
320
+ Dataset or through an intermediate Subject table).
321
+
322
+ Args:
323
+ table: Name of the table to query.
324
+
325
+ Returns:
326
+ CompoundSelect: A SQLAlchemy UNION query selecting all matching rows.
327
+ """
117
328
  table_class = self.model.get_orm_class_by_name(table)
118
329
  dataset_table_class = self.model.get_orm_class_by_name(self._dataset_table.name)
330
+
331
+ # Include this dataset and all nested datasets in the query
119
332
  dataset_rids = [self.dataset_rid] + [c.dataset_rid for c in self.list_dataset_children(recurse=True)]
120
333
 
334
+ # Find all paths from Dataset to the target table
121
335
  paths = [[t.name for t in p] for p in self.model._schema_to_paths() if p[-1].name == table]
336
+
337
+ # Build a SELECT query for each path and UNION them together
122
338
  sql_cmds = []
123
339
  for path in paths:
124
340
  path_sql = select(table_class)
125
341
  last_class = self.model.get_orm_class_by_name(path[0])
342
+ # Join through each table in the path
126
343
  for t in path[1:]:
127
344
  t_class = self.model.get_orm_class_by_name(t)
128
345
  path_sql = path_sql.join(self._find_relationship_attr(last_class, t_class))
129
346
  last_class = t_class
347
+ # Filter to only rows belonging to our dataset(s)
130
348
  path_sql = path_sql.where(dataset_table_class.RID.in_(dataset_rids))
131
349
  sql_cmds.append(path_sql)
132
350
  return union(*sql_cmds)
133
351
 
134
- def get_table(self, table: str) -> Generator[tuple, None, None]:
135
- """Retrieve the contents of the specified table. If schema is not provided as part of the table name,
136
- the method will attempt to locate the schema for the table.
352
+ def dataset_history(self) -> list[DatasetHistory]:
353
+ """Retrieves the version history of a dataset.
137
354
 
138
- Args:
139
- table: return: A generator that yields tuples of column values.
355
+ Returns a chronological list of dataset versions, including their version numbers,
356
+ creation times, and associated metadata.
140
357
 
141
358
  Returns:
142
- A generator that yields tuples of column values.
359
+ list[DatasetHistory]: List of history entries, each containing:
360
+ - dataset_version: Version number (major.minor.patch)
361
+ - minid: Minimal Viable Identifier
362
+ - snapshot: Catalog snapshot time
363
+ - dataset_rid: Dataset Resource Identifier
364
+ - version_rid: Version Resource Identifier
365
+ - description: Version description
366
+ - execution_rid: Associated execution RID
143
367
 
144
- """
145
- with Session(self.engine) as session:
146
- result = session.execute(self._dataset_table_view(table))
147
- for row in result:
148
- yield row
149
-
150
- def get_table_as_dataframe(self, table: str) -> pd.DataFrame:
151
- """Retrieve the contents of the specified table as a dataframe.
152
-
153
-
154
- If schema is not provided as part of the table name,
155
- the method will attempt to locate the schema for the table.
156
-
157
- Args:
158
- table: Table to retrieve data from.
159
-
160
- Returns:
161
- A dataframe containing the contents of the specified table.
162
- """
163
- return pd.read_sql(self._dataset_table_view(table), self.engine)
164
-
165
- def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
166
- """Retrieve the contents of the specified table as a dictionary.
167
-
168
- Args:
169
- table: Table to retrieve data from. f schema is not provided as part of the table name,
170
- the method will attempt to locate the schema for the table.
368
+ Raises:
369
+ DerivaMLException: If dataset_rid is not a valid dataset RID.
171
370
 
172
- Returns:
173
- A generator producing dictionaries containing the contents of the specified table as name/value pairs.
371
+ Example:
372
+ >>> history = bag.dataset_history()
373
+ >>> for entry in history:
374
+ ... print(f"Version {entry.dataset_version}: {entry.description}")
174
375
  """
376
+ # Query Dataset_Version table directly via the model
377
+ return [
378
+ DatasetHistory(
379
+ dataset_version=DatasetVersion.parse(v["Version"]),
380
+ minid=v["Minid"],
381
+ snapshot=v["Snapshot"],
382
+ dataset_rid=self.dataset_rid,
383
+ version_rid=v["RID"],
384
+ description=v["Description"],
385
+ execution_rid=v["Execution"],
386
+ )
387
+ for v in self.model._get_table_contents("Dataset_Version")
388
+ if v["Dataset"] == self.dataset_rid
389
+ ]
175
390
 
176
- with Session(self.engine) as session:
177
- result = session.execute(self._dataset_table_view(table))
178
- for row in result.mappings():
179
- yield row
180
-
181
- # @validate_call
182
- def list_dataset_members(self, recurse: bool = False) -> dict[str, list[dict[str, Any]]]:
391
+ def list_dataset_members(
392
+ self,
393
+ recurse: bool = False,
394
+ limit: int | None = None,
395
+ _visited: set[RID] | None = None,
396
+ version: Any = None,
397
+ **kwargs: Any,
398
+ ) -> dict[str, list[dict[str, Any]]]:
183
399
  """Return a list of entities associated with a specific dataset.
184
400
 
185
401
  Args:
186
- recurse: Whether to include nested datasets.
402
+ recurse: Whether to include members of nested datasets.
403
+ limit: Maximum number of members to return per type. None for no limit.
404
+ _visited: Internal parameter to track visited datasets and prevent infinite recursion.
405
+ version: Ignored (bags are immutable snapshots).
406
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
187
407
 
188
408
  Returns:
189
- Dictionary of entities associated with the dataset.
409
+ Dictionary mapping member types to lists of member records.
190
410
  """
411
+ # Initialize visited set for recursion guard
412
+ if _visited is None:
413
+ _visited = set()
414
+
415
+ # Prevent infinite recursion by checking if we've already visited this dataset
416
+ if self.dataset_rid in _visited:
417
+ return {}
418
+ _visited.add(self.dataset_rid)
191
419
 
192
420
  # Look at each of the element types that might be in the _dataset_table and get the list of rid for them from
193
421
  # the appropriate association table.
@@ -200,16 +428,29 @@ class DatasetBag:
200
428
  assoc_class, dataset_rel, element_rel = self.model.get_orm_association_class(dataset_class, element_class)
201
429
 
202
430
  element_table = inspect(element_class).mapped_table
203
- if element_table.schema != self.model.domain_schema and element_table.name not in ["Dataset", "File"]:
431
+ if not self.model.is_domain_schema(element_table.schema) and element_table.name not in ["Dataset", "File"]:
204
432
  # Look at domain tables and nested datasets.
205
433
  continue
434
+
206
435
  # Get the names of the columns that we are going to need for linking
207
436
  with Session(self.engine) as session:
208
- sql_cmd = (
209
- select(element_class)
210
- .join(element_rel)
211
- .where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
212
- )
437
+ # For Dataset_Dataset, use Nested_Dataset column to find nested datasets
438
+ # (similar to how the live catalog does it in Dataset.list_dataset_members)
439
+ if element_table.name == "Dataset":
440
+ sql_cmd = (
441
+ select(element_class)
442
+ .join(assoc_class, element_class.RID == assoc_class.__table__.c["Nested_Dataset"])
443
+ .where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
444
+ )
445
+ else:
446
+ # For other tables, use the original join via element_rel
447
+ sql_cmd = (
448
+ select(element_class)
449
+ .join(element_rel)
450
+ .where(self.dataset_rid == assoc_class.__table__.c["Dataset"])
451
+ )
452
+ if limit is not None:
453
+ sql_cmd = sql_cmd.limit(limit)
213
454
  # Get back the list of ORM entities and convert them to dictionaries.
214
455
  element_entities = session.scalars(sql_cmd).all()
215
456
  element_rows = [{c.key: getattr(obj, c.key) for c in obj.__table__.columns} for obj in element_entities]
@@ -218,8 +459,8 @@ class DatasetBag:
218
459
  # Get the members for all the nested datasets and add to the member list.
219
460
  nested_datasets = [d["RID"] for d in element_rows]
220
461
  for ds in nested_datasets:
221
- nested_dataset = self.model.get_dataset(ds)
222
- for k, v in nested_dataset.list_dataset_members(recurse=recurse).items():
462
+ nested_dataset = self._catalog.lookup_dataset(ds)
463
+ for k, v in nested_dataset.list_dataset_members(recurse=recurse, limit=limit, _visited=_visited).items():
223
464
  members[k].extend(v)
224
465
  return dict(members)
225
466
 
@@ -234,25 +475,63 @@ class DatasetBag:
234
475
  """
235
476
  return self.model.find_features(table)
236
477
 
237
- def list_feature_values(self, table: Table | str, feature_name: str) -> datapath._ResultSet:
238
- """Return feature values for a table.
478
+ def list_feature_values(
479
+ self, table: Table | str, feature_name: str
480
+ ) -> Iterable[FeatureRecord]:
481
+ """Retrieves all values for a feature as typed FeatureRecord instances.
482
+
483
+ Returns an iterator of dynamically-generated FeatureRecord objects for each
484
+ feature value. Each record is an instance of a Pydantic model specific to
485
+ this feature, with typed attributes for all columns including the Execution
486
+ that created the feature value.
239
487
 
240
488
  Args:
241
- table: The table to get feature values for.
242
- feature_name: Name of the feature.
489
+ table: The table containing the feature, either as name or Table object.
490
+ feature_name: Name of the feature to retrieve values for.
243
491
 
244
492
  Returns:
245
- Feature values.
493
+ Iterable[FeatureRecord]: An iterator of FeatureRecord instances.
494
+ Each instance has:
495
+ - Execution: RID of the execution that created this feature value
496
+ - Feature_Name: Name of the feature
497
+ - All feature-specific columns as typed attributes
498
+ - model_dump() method to convert back to a dictionary
499
+
500
+ Raises:
501
+ DerivaMLException: If the feature doesn't exist or cannot be accessed.
502
+
503
+ Example:
504
+ >>> # Get typed feature records
505
+ >>> for record in bag.list_feature_values("Image", "Quality"):
506
+ ... print(f"Image {record.Image}: {record.ImageQuality}")
507
+ ... print(f"Created by execution: {record.Execution}")
508
+
509
+ >>> # Convert records to dictionaries
510
+ >>> records = list(bag.list_feature_values("Image", "Quality"))
511
+ >>> dicts = [r.model_dump() for r in records]
246
512
  """
513
+ # Get table and feature
247
514
  feature = self.model.lookup_feature(table, feature_name)
248
- feature_class = self.model.get_orm_class_for_table(feature.feature_table)
515
+
516
+ # Get the dynamically-generated FeatureRecord subclass for this feature
517
+ record_class = feature.feature_record_class()
518
+
519
+ # Query raw values from SQLite
520
+ feature_table = self.model.find_table(feature.feature_table.name)
249
521
  with Session(self.engine) as session:
250
- sql_cmd = select(feature_class)
251
- return cast(datapath._ResultSet, [row for row in session.execute(sql_cmd).mappings()])
522
+ sql_cmd = select(feature_table)
523
+ result = session.execute(sql_cmd)
524
+ rows = [dict(row._mapping) for row in result]
252
525
 
253
- def list_dataset_element_types(self) -> list[Table]:
254
- """
255
- Lists the data types of elements contained within a dataset.
526
+ # Convert to typed records
527
+ for raw_value in rows:
528
+ # Filter to only include fields that the record class expects
529
+ field_names = set(record_class.model_fields.keys())
530
+ filtered_data = {k: v for k, v in raw_value.items() if k in field_names}
531
+ yield record_class(**filtered_data)
532
+
533
+ def list_dataset_element_types(self) -> Iterable[Table]:
534
+ """List the types of elements that can be contained in datasets.
256
535
 
257
536
  This method analyzes the dataset and identifies the data types for all
258
537
  elements within it. It is useful for understanding the structure and
@@ -266,15 +545,33 @@ class DatasetBag:
266
545
  """
267
546
  return self.model.list_dataset_element_types()
268
547
 
269
- def list_dataset_children(self, recurse: bool = False) -> list[DatasetBag]:
548
+ def list_dataset_children(
549
+ self,
550
+ recurse: bool = False,
551
+ _visited: set[RID] | None = None,
552
+ version: Any = None,
553
+ **kwargs: Any,
554
+ ) -> list[Self]:
270
555
  """Get nested datasets.
271
556
 
272
557
  Args:
273
558
  recurse: Whether to include children of children.
559
+ _visited: Internal parameter to track visited datasets and prevent infinite recursion.
560
+ version: Ignored (bags are immutable snapshots).
561
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
274
562
 
275
563
  Returns:
276
564
  List of child dataset bags.
277
565
  """
566
+ # Initialize visited set for recursion guard
567
+ if _visited is None:
568
+ _visited = set()
569
+
570
+ # Prevent infinite recursion by checking if we've already visited this dataset
571
+ if self.dataset_rid in _visited:
572
+ return []
573
+ _visited.add(self.dataset_rid)
574
+
278
575
  ds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset")
279
576
  nds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Dataset")
280
577
  dv_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Version")
@@ -286,63 +583,102 @@ class DatasetBag:
286
583
  .join_from(ds_table, dv_table, onclause=ds_table.Version == dv_table.RID)
287
584
  .where(nds_table.Dataset == self.dataset_rid)
288
585
  )
289
- nested = [DatasetBag(self.model, r[0]) for r in session.execute(sql_cmd).all()]
586
+ nested = [self._catalog.lookup_dataset(r[0]) for r in session.execute(sql_cmd).all()]
290
587
 
291
588
  result = copy(nested)
292
589
  if recurse:
293
590
  for child in nested:
294
- result.extend(child.list_dataset_children(recurse))
591
+ result.extend(child.list_dataset_children(recurse=recurse, _visited=_visited))
295
592
  return result
296
593
 
297
- @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
298
- def lookup_term(self, table: str | Table, term_name: str) -> VocabularyTerm:
299
- """Finds a term in a vocabulary table.
300
-
301
- Searches for a term in the specified vocabulary table, matching either the primary name
302
- or any of its synonyms.
594
+ def list_dataset_parents(
595
+ self,
596
+ recurse: bool = False,
597
+ _visited: set[RID] | None = None,
598
+ version: Any = None,
599
+ **kwargs: Any,
600
+ ) -> list[Self]:
601
+ """Given a dataset_table RID, return a list of RIDs of the parent datasets if this is included in a
602
+ member, if any.
303
603
 
304
604
  Args:
305
- table: Vocabulary table to search in (name or Table object).
306
- term_name: Name or synonym of the term to find.
605
+ recurse: If True, recursively return all ancestor datasets.
606
+ _visited: Internal parameter to track visited datasets and prevent infinite recursion.
607
+ version: Ignored (bags are immutable snapshots).
608
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
307
609
 
308
610
  Returns:
309
- VocabularyTerm: The matching vocabulary term.
611
+ List of parent dataset bags.
612
+ """
613
+ # Initialize visited set for recursion guard
614
+ if _visited is None:
615
+ _visited = set()
310
616
 
311
- Raises:
312
- DerivaMLVocabularyException: If the table is not a vocabulary table, or term is not found.
617
+ # Prevent infinite recursion by checking if we've already visited this dataset
618
+ if self.dataset_rid in _visited:
619
+ return []
620
+ _visited.add(self.dataset_rid)
313
621
 
314
- Examples:
315
- Look up by primary name:
316
- >>> term = ml.lookup_term("tissue_types", "epithelial")
317
- >>> print(term.description)
622
+ nds_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Dataset")
318
623
 
319
- Look up by synonym:
320
- >>> term = ml.lookup_term("tissue_types", "epithelium")
321
- """
322
- # Get and validate vocabulary table reference
323
- if not self.model.is_vocabulary(table):
324
- raise DerivaMLException(f"The table {table} is not a controlled vocabulary")
624
+ with Session(self.engine) as session:
625
+ sql_cmd = select(nds_table.Dataset).where(nds_table.Nested_Dataset == self.dataset_rid)
626
+ parents = [self._catalog.lookup_dataset(r[0]) for r in session.execute(sql_cmd).all()]
325
627
 
326
- # Search for term by name or synonym
327
- for term in self.get_table_as_dict(table):
328
- if term_name == term["Name"] or (term["Synonyms"] and term_name in term["Synonyms"]):
329
- term["Synonyms"] = list(term["Synonyms"])
330
- return VocabularyTerm.model_validate(term)
628
+ if recurse:
629
+ for parent in parents.copy():
630
+ parents.extend(parent.list_dataset_parents(recurse=True, _visited=_visited))
631
+ return parents
331
632
 
332
- # Term not found
333
- raise DerivaMLInvalidTerm(table, term_name)
633
+ def list_executions(self) -> list[RID]:
634
+ """List all execution RIDs associated with this dataset.
334
635
 
335
- def _denormalize(self, include_tables: list[str]) -> Select:
636
+ Returns all executions that used this dataset as input. This is
637
+ tracked through the Dataset_Execution association table.
638
+
639
+ Note:
640
+ Unlike the live Dataset class which returns Execution objects,
641
+ DatasetBag returns a list of execution RIDs since the bag is
642
+ an offline snapshot and cannot look up live execution objects.
643
+
644
+ Returns:
645
+ List of execution RIDs associated with this dataset.
646
+
647
+ Example:
648
+ >>> bag = ml.download_dataset_bag(dataset_spec)
649
+ >>> execution_rids = bag.list_executions()
650
+ >>> for rid in execution_rids:
651
+ ... print(f"Associated execution: {rid}")
336
652
  """
337
- Generates an SQL statement for denormalizing the dataset based on the tables to include. Processes cycles in
338
- graph relationships, ensures proper join order, and generates selected columns for denormalization.
653
+ de_table = self.model.get_orm_class_by_name(f"{self.model.ml_schema}.Dataset_Execution")
654
+
655
+ with Session(self.engine) as session:
656
+ sql_cmd = select(de_table.Execution).where(de_table.Dataset == self.dataset_rid)
657
+ return [r[0] for r in session.execute(sql_cmd).all()]
658
+
659
+ def _denormalize(self, include_tables: list[str]) -> Select:
660
+ """Build a SQL query that joins multiple tables into a denormalized view.
661
+
662
+ This method creates a "wide table" by joining related tables together,
663
+ producing a single query that returns columns from all specified tables.
664
+ This is useful for machine learning pipelines that need flat data.
665
+
666
+ The method:
667
+ 1. Analyzes the schema to find join paths between tables
668
+ 2. Determines the correct join order based on foreign key relationships
669
+ 3. Builds SELECT statements with properly aliased columns
670
+ 4. Creates a UNION if multiple paths exist to the same tables
339
671
 
340
672
  Args:
341
- include_tables (list[str] | None): List of table names to include in the denormalized dataset. If None,
342
- all tables from the dataset will be included.
673
+ include_tables: List of table names to include in the output. Additional
674
+ tables may be included if they're needed to join the requested tables.
343
675
 
344
676
  Returns:
345
- str: SQL query string that represents the process of denormalization.
677
+ Select: A SQLAlchemy query that produces the denormalized result.
678
+
679
+ Note:
680
+ Column names in the result are prefixed with the table name to avoid
681
+ collisions (e.g., "Image.Filename", "Subject.RID").
346
682
  """
347
683
  # Skip over tables that we don't want to include in the denormalized dataset.
348
684
  # Also, strip off the Dataset/Dataset_X part of the path so we don't include dataset columns in the denormalized
@@ -359,9 +695,7 @@ class DatasetBag:
359
695
  return relationship
360
696
  return None
361
697
 
362
- join_tables, denormalized_columns = (
363
- self.model._prepare_wide_table(self, self.dataset_rid, include_tables)
364
- )
698
+ join_tables, denormalized_columns = self.model._prepare_wide_table(self, self.dataset_rid, include_tables)
365
699
 
366
700
  denormalized_columns = [
367
701
  self.model.get_orm_class_by_name(table_name)
@@ -382,69 +716,864 @@ class DatasetBag:
382
716
  if (r := find_relationship(table_class, on_condition))
383
717
  ]
384
718
  sql_statement = sql_statement.join(table_class, onclause=and_(*on_clause))
385
- dataset_rid_list = [self.dataset_rid] + self.list_dataset_children(recurse=True)
719
+ dataset_rid_list = [self.dataset_rid] + [c.dataset_rid for c in self.list_dataset_children(recurse=True)]
386
720
  dataset_class = self.model.get_orm_class_by_name(self._dataset_table.name)
387
721
  sql_statement = sql_statement.where(dataset_class.RID.in_(dataset_rid_list))
388
722
  sql_statements.append(sql_statement)
389
723
  return union(*sql_statements)
390
724
 
391
- def denormalize_as_dataframe(self, include_tables: list[str]) -> pd.DataFrame:
725
+ def _denormalize_from_members(
726
+ self,
727
+ include_tables: list[str],
728
+ ) -> Generator[dict[str, Any], None, None]:
729
+ """Denormalize dataset members by joining related tables.
730
+
731
+ This method creates a "wide table" view by joining related tables together,
732
+ using list_dataset_members() as the data source. This ensures consistency
733
+ with the catalog-based denormalize implementation. The result has outer join
734
+ semantics - tables without FK relationships are included with NULL values.
735
+
736
+ The method:
737
+ 1. Gets the list of dataset members for each included table via list_dataset_members
738
+ 2. For each member in the first table, follows foreign key relationships to
739
+ get related records from other tables
740
+ 3. Tables without FK connections to the first table are included with NULLs
741
+ 4. Includes nested dataset members recursively
742
+
743
+ Args:
744
+ include_tables: List of table names to include in the output.
745
+
746
+ Yields:
747
+ dict[str, Any]: Rows with column names prefixed by table name (e.g., "Image.Filename").
748
+ Unrelated tables have NULL values for their columns.
749
+
750
+ Note:
751
+ Column names in the result are prefixed with the table name to avoid
752
+ collisions (e.g., "Image.Filename", "Subject.RID").
392
753
  """
393
- Denormalize the dataset and return the result as a dataframe.
754
+ # Skip system columns in output
755
+ skip_columns = {"RCT", "RMT", "RCB", "RMB"}
756
+
757
+ # Get all members for the included tables (recursively includes nested datasets)
758
+ members = self.list_dataset_members(recurse=True)
759
+
760
+ # Build a lookup of columns for each table
761
+ table_columns: dict[str, list[str]] = {}
762
+ for table_name in include_tables:
763
+ table = self.model.name_to_table(table_name)
764
+ table_columns[table_name] = [
765
+ c.name for c in table.columns if c.name not in skip_columns
766
+ ]
767
+
768
+ # Find the primary table (first non-empty table in include_tables)
769
+ primary_table = None
770
+ for table_name in include_tables:
771
+ if table_name in members and members[table_name]:
772
+ primary_table = table_name
773
+ break
774
+
775
+ if primary_table is None:
776
+ # No data at all
777
+ return
778
+
779
+ primary_table_obj = self.model.name_to_table(primary_table)
780
+
781
+ for member in members[primary_table]:
782
+ # Build the row with all columns from all tables
783
+ row: dict[str, Any] = {}
784
+
785
+ # Add primary table columns
786
+ for col_name in table_columns[primary_table]:
787
+ prefixed_name = f"{primary_table}.{col_name}"
788
+ row[prefixed_name] = member.get(col_name)
789
+
790
+ # For each other table, try to join or add NULL values
791
+ for other_table_name in include_tables:
792
+ if other_table_name == primary_table:
793
+ continue
794
+
795
+ other_table = self.model.name_to_table(other_table_name)
796
+ other_cols = table_columns[other_table_name]
797
+
798
+ # Initialize all columns to None (outer join behavior)
799
+ for col_name in other_cols:
800
+ prefixed_name = f"{other_table_name}.{col_name}"
801
+ row[prefixed_name] = None
802
+
803
+ # Try to find FK relationship and join
804
+ if other_table_name in members:
805
+ try:
806
+ relationship = self.model._table_relationship(
807
+ primary_table_obj, other_table
808
+ )
809
+ fk_col, pk_col = relationship
810
+
811
+ # Look up the related record
812
+ fk_value = member.get(fk_col.name)
813
+ if fk_value:
814
+ for other_member in members.get(other_table_name, []):
815
+ if other_member.get(pk_col.name) == fk_value:
816
+ for col_name in other_cols:
817
+ prefixed_name = f"{other_table_name}.{col_name}"
818
+ row[prefixed_name] = other_member.get(col_name)
819
+ break
820
+ except DerivaMLException:
821
+ # No FK relationship - columns remain NULL (outer join)
822
+ pass
823
+
824
+ yield row
825
+
826
+ def denormalize_as_dataframe(
827
+ self,
828
+ include_tables: list[str],
829
+ version: Any = None,
830
+ **kwargs: Any,
831
+ ) -> pd.DataFrame:
832
+ """Denormalize the dataset bag into a single wide table (DataFrame).
833
+
834
+ Denormalization transforms normalized relational data into a single "wide table"
835
+ (also called a "flat table" or "denormalized table") by joining related tables
836
+ together. This produces a DataFrame where each row contains all related information
837
+ from multiple source tables, with columns from each table combined side-by-side.
838
+
839
+ Wide tables are the standard input format for most machine learning frameworks,
840
+ which expect all features for a single observation to be in one row. This method
841
+ bridges the gap between normalized database schemas and ML-ready tabular data.
842
+
843
+ **How it works:**
844
+
845
+ Tables are joined based on their foreign key relationships stored in the bag's
846
+ schema. For example, if Image has a foreign key to Subject, denormalizing
847
+ ["Subject", "Image"] produces rows where each image appears with its subject's
848
+ metadata.
849
+
850
+ **Column naming:**
851
+
852
+ Column names are prefixed with the source table name using dots to avoid
853
+ collisions (e.g., "Image.Filename", "Subject.RID"). This differs from the
854
+ live Dataset class which uses underscores.
855
+
856
+ Args:
857
+ include_tables: List of table names to include in the output. Tables
858
+ are joined based on their foreign key relationships.
859
+ Order doesn't matter - the join order is determined automatically.
860
+ version: Ignored (bags are immutable snapshots of a specific version).
861
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
862
+
863
+ Returns:
864
+ pd.DataFrame: Wide table with columns from all included tables.
865
+
866
+ Example:
867
+ Create a training dataset from a downloaded bag::
868
+
869
+ >>> # Download and materialize the dataset
870
+ >>> bag = ml.download_dataset_bag(spec, materialize=True)
871
+
872
+ >>> # Denormalize into a wide table
873
+ >>> df = bag.denormalize_as_dataframe(["Image", "Diagnosis"])
874
+ >>> print(df.columns.tolist())
875
+ ['Image.RID', 'Image.Filename', 'Image.URL', 'Diagnosis.RID',
876
+ 'Diagnosis.Label', 'Diagnosis.Confidence']
877
+
878
+ >>> # Access local file paths for images
879
+ >>> for _, row in df.iterrows():
880
+ ... local_path = bag.get_asset_path("Image", row["Image.RID"])
881
+ ... label = row["Diagnosis.Label"]
882
+ ... # Train on local_path with label
883
+
884
+ See Also:
885
+ denormalize_as_dict: Generator version for memory-efficient processing.
886
+ """
887
+ rows = list(self._denormalize_from_members(include_tables=include_tables))
888
+ return pd.DataFrame(rows)
889
+
890
+ def denormalize_as_dict(
891
+ self,
892
+ include_tables: list[str],
893
+ version: Any = None,
894
+ **kwargs: Any,
895
+ ) -> Generator[dict[str, Any], None, None]:
896
+ """Denormalize the dataset bag and yield rows as dictionaries.
897
+
898
+ This is a memory-efficient alternative to denormalize_as_dataframe() that
899
+ yields one row at a time as a dictionary instead of loading all data into
900
+ a DataFrame. Use this when processing large datasets that may not fit in
901
+ memory, or when you want to process rows incrementally.
394
902
 
395
- This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
396
- the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
397
- for each row in the denormalized wide table.
903
+ Like denormalize_as_dataframe(), this produces a "wide table" representation
904
+ where each yielded dictionary contains all columns from the joined tables.
905
+ See denormalize_as_dataframe() for detailed explanation of how denormalization
906
+ works.
398
907
 
399
- The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
400
- view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
401
- additional tables are required to complete the denormalization process. If include_tables is not specified,
402
- all of the tables in the schema will be included.
908
+ **Column naming:**
403
909
 
404
- The resulting wide table will include a column for every table needed to complete the denormalization process.
910
+ Column names are prefixed with the source table name using dots to avoid
911
+ collisions (e.g., "Image.Filename", "Subject.RID"). This differs from the
912
+ live Dataset class which uses underscores.
405
913
 
406
914
  Args:
407
- include_tables: List of table names to include in the denormalized dataset.
915
+ include_tables: List of table names to include in the output.
916
+ Tables are joined based on their foreign key relationships.
917
+ version: Ignored (bags are immutable snapshots of a specific version).
918
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
919
+
920
+ Yields:
921
+ dict[str, Any]: Dictionary representing one row of the wide table.
922
+ Keys are column names in "Table.Column" format.
923
+
924
+ Example:
925
+ Stream through a large dataset for training::
926
+
927
+ >>> bag = ml.download_dataset_bag(spec, materialize=True)
928
+ >>> for row in bag.denormalize_as_dict(["Image", "Diagnosis"]):
929
+ ... # Get local file path for this image
930
+ ... local_path = bag.get_asset_path("Image", row["Image.RID"])
931
+ ... label = row["Diagnosis.Label"]
932
+ ... # Process image and label...
933
+
934
+ Build a PyTorch dataset efficiently::
935
+
936
+ >>> class BagDataset(torch.utils.data.IterableDataset):
937
+ ... def __init__(self, bag, tables):
938
+ ... self.bag = bag
939
+ ... self.tables = tables
940
+ ... def __iter__(self):
941
+ ... for row in self.bag.denormalize_as_dict(self.tables):
942
+ ... img_path = self.bag.get_asset_path("Image", row["Image.RID"])
943
+ ... yield load_image(img_path), row["Diagnosis.Label"]
944
+
945
+ See Also:
946
+ denormalize_as_dataframe: Returns all data as a pandas DataFrame.
947
+ """
948
+ yield from self._denormalize_from_members(include_tables=include_tables)
949
+
950
+
951
+ # =========================================================================
952
+ # Asset Restructuring Methods
953
+ # =========================================================================
954
+
955
+ def _build_dataset_type_path_map(
956
+ self,
957
+ type_selector: Callable[[list[str]], str] | None = None,
958
+ ) -> dict[RID, list[str]]:
959
+ """Build a mapping from dataset RID to its type path in the hierarchy.
960
+
961
+ Recursively traverses nested datasets to create a mapping where each
962
+ dataset RID maps to its hierarchical type path (e.g., ["complete", "training"]).
963
+
964
+ Args:
965
+ type_selector: Function to select type when dataset has multiple types.
966
+ Receives list of type names, returns selected type name.
967
+ Defaults to selecting the first type, or "Testing" if the dataset has no types.
408
968
 
409
969
  Returns:
410
- Dataframe containing the denormalized dataset.
970
+ Dictionary mapping dataset RID to list of type names from root to leaf.
971
+ e.g., {"4-ABC": ["complete", "training"], "4-DEF": ["complete", "testing"]}
411
972
  """
412
- return pd.read_sql(self._denormalize(include_tables=include_tables), self.engine)
973
+ if type_selector is None:
974
+ type_selector = lambda types: types[0] if types else "Testing"
975
+
976
+ type_paths: dict[RID, list[str]] = {}
977
+
978
+ def traverse(dataset: DatasetBag, parent_path: list[str], visited: set[RID]) -> None:
979
+ if dataset.dataset_rid in visited:
980
+ return
981
+ visited.add(dataset.dataset_rid)
982
+
983
+ current_type = type_selector(dataset.dataset_types)
984
+ current_path = parent_path + [current_type]
985
+ type_paths[dataset.dataset_rid] = current_path
986
+
987
+ for child in dataset.list_dataset_children():
988
+ traverse(child, current_path, visited)
989
+
990
+ traverse(self, [], set())
991
+ return type_paths
992
+
993
+ def _get_asset_dataset_mapping(self, asset_table: str) -> dict[RID, RID]:
994
+ """Map asset RIDs to their containing dataset RID.
995
+
996
+ For each asset in the specified table, determines which dataset it belongs to.
997
+ This uses _dataset_table_view to find assets reachable through any FK path
998
+ from the dataset, not just directly associated assets.
413
999
 
414
- def denormalize_as_dict(self, include_tables: list[str]) -> Generator[RowMapping, None, None]:
1000
+ Assets are mapped to their most specific (leaf) dataset in the hierarchy.
1001
+ For example, if a Split dataset contains Training and Testing children,
1002
+ and images are members of Training, the images map to Training (not Split).
1003
+
1004
+ Args:
1005
+ asset_table: Name of the asset table (e.g., "Image")
1006
+
1007
+ Returns:
1008
+ Dictionary mapping asset RID to the dataset RID that contains it.
415
1009
  """
416
- Denormalize the dataset and return the result as a set of dictionary's.
1010
+ asset_to_dataset: dict[RID, RID] = {}
1011
+
1012
+ def collect_from_dataset(dataset: DatasetBag, visited: set[RID]) -> None:
1013
+ if dataset.dataset_rid in visited:
1014
+ return
1015
+ visited.add(dataset.dataset_rid)
1016
+
1017
+ # Process children FIRST (depth-first) so leaf datasets get priority
1018
+ # This ensures assets are mapped to their most specific dataset
1019
+ for child in dataset.list_dataset_children():
1020
+ collect_from_dataset(child, visited)
1021
+
1022
+ # Then process this dataset's assets
1023
+ # Only set if not already mapped (child/leaf dataset wins)
1024
+ for asset in dataset._get_reachable_assets(asset_table):
1025
+ if asset["RID"] not in asset_to_dataset:
1026
+ asset_to_dataset[asset["RID"]] = dataset.dataset_rid
417
1027
 
418
- This routine will examine the domain schema for the dataset, determine which tables to include and denormalize
419
- the dataset values into a single wide table. The result is returned as a generator that returns a dictionary
420
- for each row in the denormalized wide table.
1028
+ collect_from_dataset(self, set())
1029
+ return asset_to_dataset
421
1030
 
422
- The optional argument include_tables can be used to specify a subset of tables to include in the denormalized
423
- view. The tables in this argument can appear anywhere in the dataset schema. The method will determine which
424
- additional tables are required to complete the denormalization process. If include_tables is not specified,
425
- all of the tables in the schema will be included.
1031
+ def _get_reachable_assets(self, asset_table: str) -> list[dict[str, Any]]:
1032
+ """Get all assets reachable from this dataset through any FK path.
426
1033
 
427
- The resulting wide table will include a only those column for the tables listed in include_columns.
1034
+ Unlike list_dataset_members which only returns directly associated entities,
1035
+ this method traverses foreign key relationships to find assets that are
1036
+ indirectly connected to the dataset. For example, if a dataset contains
1037
+ Subjects, and Subject -> Encounter -> Image, this method will find those
1038
+ Images even though they're not directly in the Dataset_Image association table.
428
1039
 
429
1040
  Args:
430
- include_tables: List of table names to include in the denormalized dataset. If None, than the entire schema
431
- is used.
1041
+ asset_table: Name of the asset table (e.g., "Image")
432
1042
 
433
1043
  Returns:
434
- A generator that returns a dictionary representation of each row in the denormalized dataset.
1044
+ List of asset records as dictionaries.
435
1045
  """
1046
+ # Use the _dataset_table_view query which traverses all FK paths
1047
+ sql_query = self._dataset_table_view(asset_table)
1048
+
436
1049
  with Session(self.engine) as session:
437
- cursor = session.execute(
438
- self._denormalize(include_tables=include_tables)
439
- )
440
- yield from cursor.mappings()
441
- for row in cursor.mappings():
442
- yield row
1050
+ result = session.execute(sql_query)
1051
+ # Convert rows to dictionaries
1052
+ rows = [dict(row._mapping) for row in result]
1053
+
1054
+ return rows
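A short sketch of how the FK traversal above could be exercised, assuming the Subject -> Encounter -> Image schema described in the docstring and a bag instance named `bag`::

    # Returns Image rows reachable from the dataset, including ones linked
    # only indirectly (Dataset -> Subject -> Encounter -> Image).
    images = bag._get_reachable_assets("Image")
    for row in images:
        print(row["RID"], row.get("Filename"))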
1055
+
1056
+ def _load_feature_values_cache(
1057
+ self,
1058
+ asset_table: str,
1059
+ group_keys: list[str],
1060
+ enforce_vocabulary: bool = True,
1061
+ value_selector: Callable[[list[FeatureValueRecord]], FeatureValueRecord] | None = None,
1062
+ ) -> dict[str, dict[RID, Any]]:
1063
+ """Load feature values into a cache for efficient lookup.
1064
+
1065
+ Pre-loads feature values for any group_keys that are feature names,
1066
+ organizing them by target entity RID for fast lookup.
1067
+
1068
+ Args:
1069
+ asset_table: The asset table name to find features for.
1070
+ group_keys: List of potential feature names to cache. Supports two formats:
1071
+ - "FeatureName": Uses the first term column (default behavior)
1072
+ - "FeatureName.column_name": Uses the specified column from the feature table
1073
+ enforce_vocabulary: If True (default), only allow features with
1074
+ controlled vocabulary term columns and raise an error if an
1075
+ asset has multiple values. If False, allow any feature type
1076
+ and use the first value found when multiple exist.
1077
+ value_selector: Optional function to select which feature value to use
1078
+ when an asset has multiple values for the same feature. Receives a
1079
+ list of FeatureValueRecord objects (each with execution_rid for
1080
+ provenance) and returns the selected one. If not provided and
1081
+ multiple values exist, raises DerivaMLException when
1082
+ enforce_vocabulary=True or uses the first value when False.
1083
+
1084
+ Returns:
1085
+ Dictionary mapping group_key -> {target_rid -> feature_value}
1086
+ Only includes entries for keys that are actually features.
1087
+
1088
+ Raises:
1089
+ DerivaMLException: If enforce_vocabulary is True and:
1090
+ - A feature has no term columns (not vocabulary-based), or
1091
+ - An asset has multiple different vocabulary term values for the same feature
1092
+ and no value_selector is provided.
1093
+ """
1094
+ from deriva_ml.core.exceptions import DerivaMLException
1095
+
1096
+ cache: dict[str, dict[RID, Any]] = {}
1097
+ # Store all feature value records for later selection when there are multiples
1098
+ records_cache: dict[str, dict[RID, list[FeatureValueRecord]]] = {}
1099
+ logger = logging.getLogger("deriva_ml")
1100
+
1101
+ # Parse group_keys to extract feature names and optional column specifications
1102
+ # Format: "FeatureName" or "FeatureName.column_name"
1103
+ feature_column_map: dict[str, str | None] = {} # group_key -> specific column or None
1104
+ feature_names_to_check: set[str] = set()
1105
+ for key in group_keys:
1106
+ if "." in key:
1107
+ parts = key.split(".", 1)
1108
+ feature_name = parts[0]
1109
+ column_name = parts[1]
1110
+ feature_column_map[key] = column_name
1111
+ feature_names_to_check.add(feature_name)
1112
+ else:
1113
+ feature_column_map[key] = None
1114
+ feature_names_to_check.add(key)
1115
+
1116
+ def process_feature(feat: Any, table_name: str, group_key: str, specific_column: str | None) -> None:
1117
+ """Process a single feature and add its values to the cache."""
1118
+ term_cols = [c.name for c in feat.term_columns]
1119
+ value_cols = [c.name for c in feat.value_columns]
1120
+ all_cols = term_cols + value_cols
1121
+
1122
+ # Determine which column to use for the value
1123
+ if specific_column:
1124
+ # User specified a specific column
1125
+ if specific_column not in all_cols:
1126
+ raise DerivaMLException(
1127
+ f"Column '{specific_column}' not found in feature '{feat.feature_name}'. "
1128
+ f"Available columns: {all_cols}"
1129
+ )
1130
+ use_column = specific_column
1131
+ elif term_cols:
1132
+ # Use first term column (default behavior)
1133
+ use_column = term_cols[0]
1134
+ elif not enforce_vocabulary and value_cols:
1135
+ # Fall back to value columns if allowed
1136
+ use_column = value_cols[0]
1137
+ else:
1138
+ if enforce_vocabulary:
1139
+ raise DerivaMLException(
1140
+ f"Feature '{feat.feature_name}' on table '{table_name}' has no "
1141
+ f"controlled vocabulary term columns. Only vocabulary-based features "
1142
+ f"can be used for grouping when enforce_vocabulary=True. "
1143
+ f"Set enforce_vocabulary=False to allow non-vocabulary features."
1144
+ )
1145
+ return
1146
+
1147
+ records_cache[group_key] = defaultdict(list)
1148
+ feature_values = self.list_feature_values(table_name, feat.feature_name)
1149
+
1150
+ for fv in feature_values:
1151
+ # Convert FeatureRecord to dict for easier access
1152
+ fv_dict = fv.model_dump()
1153
+ target_col = table_name
1154
+ if target_col not in fv_dict:
1155
+ continue
1156
+
1157
+ target_rid = fv_dict[target_col]
1158
+
1159
+ # Get the value from the specified column
1160
+ value = fv_dict.get(use_column)
1161
+
1162
+ if value is None:
1163
+ continue
1164
+
1165
+ # Create a FeatureValueRecord with execution provenance
1166
+ record = FeatureValueRecord(
1167
+ target_rid=target_rid,
1168
+ feature_name=feat.feature_name,
1169
+ value=value,
1170
+ execution_rid=fv_dict.get("Execution"),
1171
+ raw_record=fv_dict,
1172
+ )
1173
+ records_cache[group_key][target_rid].append(record)
1174
+
1175
+ # Find all features on tables that this asset table references
1176
+ asset_table_obj = self.model.name_to_table(asset_table)
1177
+
1178
+ # Check features on the asset table itself
1179
+ for feature in self.find_features(asset_table):
1180
+ if feature.feature_name in feature_names_to_check:
1181
+ # Find all group_keys that reference this feature
1182
+ for group_key, specific_col in feature_column_map.items():
1183
+ # Check if this group_key references this feature
1184
+ key_feature = group_key.split(".")[0] if "." in group_key else group_key
1185
+ if key_feature == feature.feature_name:
1186
+ try:
1187
+ process_feature(feature, asset_table, group_key, specific_col)
1188
+ except DerivaMLException:
1189
+ raise
1190
+ except Exception as e:
1191
+ logger.warning(f"Could not load feature {feature.feature_name}: {e}")
1192
+
1193
+ # Also check features on referenced tables (via foreign keys)
1194
+ for fk in asset_table_obj.foreign_keys:
1195
+ target_table = fk.pk_table
1196
+ for feature in self.find_features(target_table):
1197
+ if feature.feature_name in feature_names_to_check:
1198
+ # Find all group_keys that reference this feature
1199
+ for group_key, specific_col in feature_column_map.items():
1200
+ # Check if this group_key references this feature
1201
+ key_feature = group_key.split(".")[0] if "." in group_key else group_key
1202
+ if key_feature == feature.feature_name:
1203
+ try:
1204
+ process_feature(feature, target_table.name, group_key, specific_col)
1205
+ except DerivaMLException:
1206
+ raise
1207
+ except Exception as e:
1208
+ logger.warning(f"Could not load feature {feature.feature_name}: {e}")
1209
+
1210
+ # Now resolve multiple values using value_selector or error handling
1211
+ for group_key, target_records in records_cache.items():
1212
+ cache[group_key] = {}
1213
+ for target_rid, records in target_records.items():
1214
+ if len(records) == 1:
1215
+ # Single value - straightforward
1216
+ cache[group_key][target_rid] = records[0].value
1217
+ elif len(records) > 1:
1218
+ # Multiple values - need to resolve
1219
+ unique_values = set(r.value for r in records)
1220
+ if len(unique_values) == 1:
1221
+ # All records have same value, use it
1222
+ cache[group_key][target_rid] = records[0].value
1223
+ elif value_selector:
1224
+ # Use provided selector function
1225
+ selected = value_selector(records)
1226
+ cache[group_key][target_rid] = selected.value
1227
+ elif enforce_vocabulary:
1228
+ # Multiple different values without selector - error
1229
+ values_str = ", ".join(f"'{r.value}' (exec: {r.execution_rid})" for r in records)
1230
+ raise DerivaMLException(
1231
+ f"Asset '{target_rid}' has multiple different values for "
1232
+ f"feature '{records[0].feature_name}': {values_str}. "
1233
+ f"Provide a value_selector function to choose between values, "
1234
+ f"or set enforce_vocabulary=False to use the first value."
1235
+ )
1236
+ else:
1237
+ # Not enforcing - use first value
1238
+ cache[group_key][target_rid] = records[0].value
1239
+
1240
+ return cache
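A sketch of the two group-key forms and a value selector, mirroring the selector pattern shown later in the restructure_assets examples; "Image", "Diagnosis", and "Classification.Label" are hypothetical names::

    def prefer_latest(records: list[FeatureValueRecord]) -> FeatureValueRecord:
        # Pick the value recorded by the most recent execution.
        return max(records, key=lambda r: r.execution_rid or "")

    cache = bag._load_feature_values_cache(
        asset_table="Image",
        group_keys=["Diagnosis", "Classification.Label"],  # plain and Feature.column forms
        value_selector=prefer_latest,
    )
    # cache["Diagnosis"] maps each target RID to its Diagnosis term.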
1241
+
1242
+ def _resolve_grouping_value(
1243
+ self,
1244
+ asset: dict[str, Any],
1245
+ group_key: str,
1246
+ feature_cache: dict[str, dict[RID, Any]],
1247
+ ) -> str:
1248
+ """Resolve a grouping value for an asset.
1249
+
1250
+ First checks if group_key is a direct column on the asset record,
1251
+ then checks if it's a feature name in the feature cache.
1252
+
1253
+ Args:
1254
+ asset: The asset record dictionary.
1255
+ group_key: Column name or feature name to group by.
1256
+ feature_cache: Pre-loaded feature values keyed by feature name -> target RID -> value.
1257
+
1258
+ Returns:
1259
+ The resolved value as a string, or "Unknown" if not found or None.
1260
+ Uses "Unknown" (capitalized) to match vocabulary term naming conventions.
1261
+ """
1262
+ # First check if it's a direct column on the asset table
1263
+ if group_key in asset:
1264
+ value = asset[group_key]
1265
+ if value is not None:
1266
+ return str(value)
1267
+ return "Unknown"
1268
+
1269
+ # Check if it's a feature name
1270
+ if group_key in feature_cache:
1271
+ feature_values = feature_cache[group_key]
1272
+ # Check each column in the asset that might be a FK to the feature target
1273
+ for column_name, column_value in asset.items():
1274
+ if column_value and column_value in feature_values:
1275
+ return str(feature_values[column_value])
1276
+ # Also check if the asset's own RID is in the feature values
1277
+ if asset.get("RID") in feature_values:
1278
+ return str(feature_values[asset["RID"]])
1279
+
1280
+ return "Unknown"
1281
+
1282
+ def _detect_asset_table(self) -> str | None:
1283
+ """Auto-detect the asset table from dataset members.
1284
+
1285
+ Searches for asset tables in the dataset members by examining
1286
+ the schema. Returns the first asset table found, or None if
1287
+ no asset tables are in the dataset.
1288
+
1289
+ Returns:
1290
+ Name of the detected asset table, or None if not found.
1291
+ """
1292
+ members = self.list_dataset_members(recurse=True)
1293
+ for table_name in members:
1294
+ if table_name == "Dataset":
1295
+ continue
1296
+ # Check if this table is an asset table
1297
+ try:
1298
+ table = self.model.name_to_table(table_name)
1299
+ if self.model.is_asset(table):
1300
+ return table_name
1301
+ except (KeyError, AttributeError):
1302
+ continue
1303
+ return None
1304
+
1305
+ def _validate_dataset_types(self) -> list[str] | None:
1306
+ """Validate that the dataset or its children have Training/Testing types.
1307
+
1308
+ Checks if this dataset is of type Training or Testing, or if it has
1309
+ nested children of those types. Returns the valid types found.
1310
+
1311
+ Returns:
1312
+ List of Training/Testing type names found, or None if validation fails.
1313
+ """
1314
+ valid_types = {"Training", "Testing"}
1315
+ found_types: set[str] = set()
1316
+
1317
+ def check_dataset(ds: DatasetBag, visited: set[RID]) -> None:
1318
+ if ds.dataset_rid in visited:
1319
+ return
1320
+ visited.add(ds.dataset_rid)
1321
+
1322
+ for dtype in ds.dataset_types:
1323
+ if dtype in valid_types:
1324
+ found_types.add(dtype)
1325
+
1326
+ for child in ds.list_dataset_children():
1327
+ check_dataset(child, visited)
1328
+
1329
+ check_dataset(self, set())
1330
+ return list(found_types) if found_types else None
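A brief sketch combining the two helpers above as a pre-flight check before restructuring (hypothetical usage)::

    table = bag._detect_asset_table()            # e.g. "Image", or None if none found
    split_types = bag._validate_dataset_types()  # e.g. ["Training", "Testing"], or None
    if table is None:
        raise ValueError("no asset table found; pass asset_table explicitly")
    if split_types is None:
        print("no Training/Testing types; data will be treated as Testing")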
1331
+
1332
+ def restructure_assets(
1333
+ self,
1334
+ output_dir: Path | str,
1335
+ asset_table: str | None = None,
1336
+ group_by: list[str] | None = None,
1337
+ use_symlinks: bool = True,
1338
+ type_selector: Callable[[list[str]], str] | None = None,
1339
+ type_to_dir_map: dict[str, str] | None = None,
1340
+ enforce_vocabulary: bool = True,
1341
+ value_selector: Callable[[list[FeatureValueRecord]], FeatureValueRecord] | None = None,
1342
+ ) -> Path:
1343
+ """Restructure downloaded assets into a directory hierarchy.
1344
+
1345
+ Creates a directory structure organizing assets by dataset types and
1346
+ grouping values. This is useful for ML workflows that expect data
1347
+ organized in conventional folder structures (e.g., PyTorch ImageFolder).
1348
+
1349
+ The dataset should be of type Training or Testing, or have nested
1350
+ children of those types. The top-level directory name is determined
1351
+ by the dataset type (e.g., "Training" -> "training").
1352
+
1353
+ **Finding assets through foreign key relationships:**
1354
+
1355
+ Assets are found by traversing all foreign key paths from the dataset,
1356
+ not just direct associations. For example, if a dataset contains Subjects,
1357
+ and the schema has Subject -> Encounter -> Image relationships, this method
1358
+ will find all Images reachable through those paths even though they are
1359
+ not directly in a Dataset_Image association table.
1360
+
1361
+ **Handling datasets without types (prediction scenarios):**
1362
+
1363
+ If a dataset has no type defined, it is treated as Testing. This is
1364
+ common for prediction/inference scenarios where you want to apply a
1365
+ trained model to new unlabeled data.
1366
+
1367
+ **Handling missing labels:**
1368
+
1369
+ If an asset doesn't have a value for a group_by key (e.g., no label
1370
+ assigned), it is placed in an "Unknown" directory. This allows
1371
+ restructure_assets to work with unlabeled data for prediction.
1372
+
1373
+ Args:
1374
+ output_dir: Base directory for restructured assets.
1375
+ asset_table: Name of the asset table (e.g., "Image"). If None,
1376
+ auto-detects from dataset members (the first asset table found is
1377
+ used). Raises DerivaMLException if none can be detected.
1378
+ group_by: Names to group assets by. Each name creates a subdirectory
1379
+ level after the dataset type path. Names can be:
1380
+
1381
+ - **Column names**: Direct columns on the asset table. The column
1382
+ value becomes the subdirectory name.
1383
+ - **Feature names**: Features defined on the asset table (or tables
1384
+ it references via foreign keys). The feature's vocabulary term
1385
+ value becomes the subdirectory name.
1386
+ - **Feature.column**: Specify a particular column from a multi-term
1387
+ feature (e.g., "Classification.Label" to use the Label column).
1388
+
1389
+ Column names are checked first, then feature names. If a value
1390
+ is not found, "unknown" is used as the subdirectory name.
1391
+
1392
+ use_symlinks: If True (default), create symlinks to original files.
1393
+ If False, copy files. Symlinks save disk space but require
1394
+ the original bag to remain in place.
1395
+ type_selector: Function to select type when dataset has multiple types.
1396
+ Receives list of type names, returns selected type name.
1397
+ Defaults to selecting the first type; datasets with no types are treated as Testing.
1398
+ type_to_dir_map: Optional mapping from dataset type names to directory
1399
+ names. Defaults to {"Training": "training", "Testing": "testing",
1400
+ "Unknown": "unknown"}. Use this to customize directory names or
1401
+ add new type mappings.
1402
+ enforce_vocabulary: If True (default), only allow features that have
1403
+ controlled vocabulary term columns, and raise an error if an asset
1404
+ has multiple different values for the same feature without a
1405
+ value_selector. This ensures clean, unambiguous directory structures.
1406
+ If False, allow any feature type and use the first value found
1407
+ when multiple values exist.
1408
+ value_selector: Optional function to select which feature value to use
1409
+ when an asset has multiple values for the same feature. Receives a
1410
+ list of FeatureValueRecord objects (each containing target_rid,
1411
+ feature_name, value, execution_rid, and raw_record) and returns
1412
+ the selected FeatureValueRecord. Use execution_rid to distinguish
1413
+ between values from different executions.
1414
+
1415
+ Returns:
1416
+ Path to the output directory.
1417
+
1418
+ Raises:
1419
+ DerivaMLException: If asset_table cannot be determined (multiple
1420
+ asset tables exist without specification), if no valid dataset
1421
+ types (Training/Testing) are found, or if enforce_vocabulary
1422
+ is True and a feature has multiple values without value_selector.
1423
+
1424
+ Examples:
1425
+ Basic restructuring with auto-detected asset table::
1426
+
1427
+ bag.restructure_assets(
1428
+ output_dir="./ml_data",
1429
+ group_by=["Diagnosis"],
1430
+ )
1431
+ # Creates:
1432
+ # ./ml_data/training/Normal/image1.jpg
1433
+ # ./ml_data/testing/Abnormal/image2.jpg
1434
+
1435
+ Custom type-to-directory mapping::
1436
+
1437
+ bag.restructure_assets(
1438
+ output_dir="./ml_data",
1439
+ group_by=["Diagnosis"],
1440
+ type_to_dir_map={"Training": "train", "Testing": "test"},
1441
+ )
1442
+ # Creates:
1443
+ # ./ml_data/train/Normal/image1.jpg
1444
+ # ./ml_data/test/Abnormal/image2.jpg
1445
+
1446
+ Select specific feature column for multi-term features::
1447
+
1448
+ bag.restructure_assets(
1449
+ output_dir="./ml_data",
1450
+ group_by=["Classification.Label"], # Use Label column
1451
+ )
1452
+
1453
+ Handle multiple feature values with a selector::
1454
+
1455
+ def select_latest(records: list[FeatureValueRecord]) -> FeatureValueRecord:
1456
+ # Select value from most recent execution
1457
+ return max(records, key=lambda r: r.execution_rid or "")
1458
+
1459
+ bag.restructure_assets(
1460
+ output_dir="./ml_data",
1461
+ group_by=["Diagnosis"],
1462
+ value_selector=select_latest,
1463
+ )
1464
+
1465
+ Prediction scenario with unlabeled data::
1466
+
1467
+ # Dataset has no type - treated as Testing
1468
+ # Assets have no labels - placed in Unknown directory
1469
+ bag.restructure_assets(
1470
+ output_dir="./prediction_data",
1471
+ group_by=["Diagnosis"],
1472
+ )
1473
+ # Creates:
1474
+ # ./prediction_data/testing/Unknown/image1.jpg
1475
+ # ./prediction_data/testing/Unknown/image2.jpg
1476
+ """
1477
+ logger = logging.getLogger("deriva_ml")
1478
+ group_by = group_by or []
1479
+ output_dir = Path(output_dir)
1480
+ output_dir.mkdir(parents=True, exist_ok=True)
1481
+
1482
+ # Default type-to-directory mapping
1483
+ if type_to_dir_map is None:
1484
+ type_to_dir_map = {"Training": "training", "Testing": "testing", "Unknown": "unknown"}
1485
+
1486
+ # Auto-detect asset table if not provided
1487
+ if asset_table is None:
1488
+ asset_table = self._detect_asset_table()
1489
+ if asset_table is None:
1490
+ raise DerivaMLException(
1491
+ "Could not auto-detect asset table. No asset tables found in dataset members. "
1492
+ "Specify the asset_table parameter explicitly."
1493
+ )
1494
+ logger.info(f"Auto-detected asset table: {asset_table}")
1495
+
1496
+ # Step 1: Build dataset type path map with directory name mapping
1497
+ def map_type_to_dir(types: list[str]) -> str:
1498
+ """Map dataset types to directory name using type_to_dir_map.
1499
+
1500
+ If dataset has no types, treat it as Testing (prediction use case).
1501
+ """
1502
+ if not types:
1503
+ # No types defined - treat as Testing for prediction scenarios
1504
+ return type_to_dir_map.get("Testing", "testing")
1505
+ if type_selector:
1506
+ selected_type = type_selector(types)
1507
+ else:
1508
+ selected_type = types[0]
1509
+ return type_to_dir_map.get(selected_type, selected_type.lower())
1510
+
1511
+ type_path_map = self._build_dataset_type_path_map(map_type_to_dir)
1512
+
1513
+ # Step 2: Get asset-to-dataset mapping
1514
+ asset_dataset_map = self._get_asset_dataset_mapping(asset_table)
1515
+
1516
+ # Step 3: Load feature values cache for relevant features
1517
+ feature_cache = self._load_feature_values_cache(
1518
+ asset_table, group_by, enforce_vocabulary, value_selector
1519
+ )
1520
+
1521
+ # Step 4: Get all assets reachable through FK paths
1522
+ # This uses _get_reachable_assets which traverses FK relationships,
1523
+ # so assets connected via Subject -> Encounter -> Image are found
1524
+ # even if the dataset only contains Subjects directly.
1525
+ assets = self._get_reachable_assets(asset_table)
1526
+
1527
+ if not assets:
1528
+ logger.warning(f"No assets found in table '{asset_table}'")
1529
+ return output_dir
1530
+
1531
+ # Step 5: Process each asset
1532
+ for asset in assets:
1533
+ # Get source file path
1534
+ filename = asset.get("Filename")
1535
+ if not filename:
1536
+ logger.warning(f"Asset {asset.get('RID')} has no Filename")
1537
+ continue
1538
+
1539
+ source_path = Path(filename)
1540
+ if not source_path.exists():
1541
+ logger.warning(f"Asset file not found: {source_path}")
1542
+ continue
1543
+
1544
+ # Get dataset type path
1545
+ dataset_rid = asset_dataset_map.get(asset["RID"])
1546
+ type_path = type_path_map.get(dataset_rid, ["unknown"])
1547
+
1548
+ # Resolve grouping values
1549
+ group_path = []
1550
+ for key in group_by:
1551
+ value = self._resolve_grouping_value(asset, key, feature_cache)
1552
+ group_path.append(value)
1553
+
1554
+ # Build target directory
1555
+ target_dir = output_dir.joinpath(*type_path, *group_path)
1556
+ target_dir.mkdir(parents=True, exist_ok=True)
1557
+
1558
+ # Create link or copy
1559
+ target_path = target_dir / source_path.name
1560
+
1561
+ # Handle existing files
1562
+ if target_path.exists() or target_path.is_symlink():
1563
+ target_path.unlink()
1564
+
1565
+ if use_symlinks:
1566
+ try:
1567
+ target_path.symlink_to(source_path.resolve())
1568
+ except OSError as e:
1569
+ # Fall back to copy on platforms that don't support symlinks
1570
+ logger.warning(f"Symlink failed, falling back to copy: {e}")
1571
+ shutil.copy2(source_path, target_path)
1572
+ else:
1573
+ shutil.copy2(source_path, target_path)
443
1574
 
1575
+ return output_dir
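A usage sketch feeding the restructured tree to PyTorch, assuming torchvision is installed and a "Diagnosis" vocabulary feature exists (both assumptions, not part of this package)::

    from torchvision import datasets, transforms  # assumed to be installed

    root = bag.restructure_assets(
        output_dir="./ml_data",
        group_by=["Diagnosis"],   # hypothetical vocabulary feature
        use_symlinks=True,
    )
    train_ds = datasets.ImageFolder(
        str(root / "training"),   # default type_to_dir_map directory name
        transform=transforms.ToTensor(),
    )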
444
1576
 
445
- # Add annotations after definition to deal with forward reference issues in pydantic
446
1577
 
447
- DatasetBag.list_dataset_children = validate_call(
448
- config=ConfigDict(arbitrary_types_allowed=True),
449
- validate_return=True,
450
- )(DatasetBag.list_dataset_children)
1578
+ # Note: validate_call decorators with Self return types were removed because
1579
+ # Pydantic doesn't support typing.Self in validate_call contexts.