deriva-ml 1.17.10__py3-none-any.whl → 1.17.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. deriva_ml/__init__.py +43 -1
  2. deriva_ml/asset/__init__.py +17 -0
  3. deriva_ml/asset/asset.py +357 -0
  4. deriva_ml/asset/aux_classes.py +100 -0
  5. deriva_ml/bump_version.py +254 -11
  6. deriva_ml/catalog/__init__.py +21 -0
  7. deriva_ml/catalog/clone.py +1199 -0
  8. deriva_ml/catalog/localize.py +426 -0
  9. deriva_ml/core/__init__.py +29 -0
  10. deriva_ml/core/base.py +817 -1067
  11. deriva_ml/core/config.py +169 -21
  12. deriva_ml/core/constants.py +120 -19
  13. deriva_ml/core/definitions.py +123 -13
  14. deriva_ml/core/enums.py +47 -73
  15. deriva_ml/core/ermrest.py +226 -193
  16. deriva_ml/core/exceptions.py +297 -14
  17. deriva_ml/core/filespec.py +99 -28
  18. deriva_ml/core/logging_config.py +225 -0
  19. deriva_ml/core/mixins/__init__.py +42 -0
  20. deriva_ml/core/mixins/annotation.py +915 -0
  21. deriva_ml/core/mixins/asset.py +384 -0
  22. deriva_ml/core/mixins/dataset.py +237 -0
  23. deriva_ml/core/mixins/execution.py +408 -0
  24. deriva_ml/core/mixins/feature.py +365 -0
  25. deriva_ml/core/mixins/file.py +263 -0
  26. deriva_ml/core/mixins/path_builder.py +145 -0
  27. deriva_ml/core/mixins/rid_resolution.py +204 -0
  28. deriva_ml/core/mixins/vocabulary.py +400 -0
  29. deriva_ml/core/mixins/workflow.py +322 -0
  30. deriva_ml/core/validation.py +389 -0
  31. deriva_ml/dataset/__init__.py +2 -1
  32. deriva_ml/dataset/aux_classes.py +20 -4
  33. deriva_ml/dataset/catalog_graph.py +575 -0
  34. deriva_ml/dataset/dataset.py +1242 -1008
  35. deriva_ml/dataset/dataset_bag.py +1311 -182
  36. deriva_ml/dataset/history.py +27 -14
  37. deriva_ml/dataset/upload.py +225 -38
  38. deriva_ml/demo_catalog.py +126 -110
  39. deriva_ml/execution/__init__.py +46 -2
  40. deriva_ml/execution/base_config.py +639 -0
  41. deriva_ml/execution/execution.py +543 -242
  42. deriva_ml/execution/execution_configuration.py +26 -11
  43. deriva_ml/execution/execution_record.py +592 -0
  44. deriva_ml/execution/find_caller.py +298 -0
  45. deriva_ml/execution/model_protocol.py +175 -0
  46. deriva_ml/execution/multirun_config.py +153 -0
  47. deriva_ml/execution/runner.py +595 -0
  48. deriva_ml/execution/workflow.py +223 -34
  49. deriva_ml/experiment/__init__.py +8 -0
  50. deriva_ml/experiment/experiment.py +411 -0
  51. deriva_ml/feature.py +6 -1
  52. deriva_ml/install_kernel.py +143 -6
  53. deriva_ml/interfaces.py +862 -0
  54. deriva_ml/model/__init__.py +99 -0
  55. deriva_ml/model/annotations.py +1278 -0
  56. deriva_ml/model/catalog.py +286 -60
  57. deriva_ml/model/database.py +144 -649
  58. deriva_ml/model/deriva_ml_database.py +308 -0
  59. deriva_ml/model/handles.py +14 -0
  60. deriva_ml/run_model.py +319 -0
  61. deriva_ml/run_notebook.py +507 -38
  62. deriva_ml/schema/__init__.py +18 -2
  63. deriva_ml/schema/annotations.py +62 -33
  64. deriva_ml/schema/create_schema.py +169 -69
  65. deriva_ml/schema/validation.py +601 -0
  66. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/METADATA +4 -4
  67. deriva_ml-1.17.11.dist-info/RECORD +77 -0
  68. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/WHEEL +1 -1
  69. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/entry_points.txt +1 -0
  70. deriva_ml/protocols/dataset.py +0 -19
  71. deriva_ml/test.py +0 -94
  72. deriva_ml-1.17.10.dist-info/RECORD +0 -45
  73. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/licenses/LICENSE +0 -0
  74. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/top_level.txt +0 -0
@@ -14,10 +14,16 @@ DerivaML class instances.
 
 Typical usage example:
 >>> ml = DerivaML('deriva.example.org', 'my_catalog')
->>> dataset_rid = ml.create_dataset('experiment', 'Experimental data')
->>> ml.add_dataset_members(dataset_rid=dataset_rid, members=['1-abc123', '1-def456'])
->>> ml.increment_dataset_version(datset_rid=dataset_rid, component=VersionPart.minor,
-...     description='Added new samples')
+>>> with ml.create_execution(config) as exe:
+...     dataset = exe.create_dataset(
+...         dataset_types=['experiment'],
+...         description='Experimental data'
+...     )
+...     dataset.add_dataset_members(members=['1-abc123', '1-def456'])
+...     dataset.increment_dataset_version(
+...         component=VersionPart.minor,
+...         description='Added new samples'
+...     )
 """
 
 from __future__ import annotations
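The headline change in this hunk is that dataset creation, membership, and version bumps move from DerivaML-level calls that take a dataset_rid to methods on an execution-scoped Dataset object. A minimal sketch of the new flow, pieced together from the docstring above; the host, catalog id, member RIDs, and the value of config are placeholders, and the top-level DerivaML import is an assumption:

# Sketch only: method names come from the updated docstring; everything else is illustrative.
from deriva_ml import DerivaML  # assumed top-level export
from deriva_ml.dataset.aux_classes import VersionPart

ml = DerivaML('deriva.example.org', 'my_catalog')
config = ...  # an execution configuration; how to build one is not shown in this diff

with ml.create_execution(config) as exe:
    # Datasets are now created through the execution, which records provenance.
    dataset = exe.create_dataset(
        dataset_types=['experiment'],
        description='Experimental data',
    )
    # Membership and versioning are methods on the Dataset object itself,
    # so no dataset_rid argument is threaded through each call.
    dataset.add_dataset_members(members=['1-abc123', '1-def456'])
    dataset.increment_dataset_version(
        component=VersionPart.minor,
        description='Added new samples',
    )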
@@ -29,21 +35,23 @@ from collections import defaultdict
 # Standard library imports
 from graphlib import TopologicalSorter
 from pathlib import Path
+
+# Local imports
+from pprint import pformat
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator
+from typing import Any, Generator, Iterable, Self
 from urllib.parse import urlparse
 
+# Deriva imports
 import deriva.core.utils.hash_utils as hash_utils
-import requests
 
 # Third-party imports
+import pandas as pd
+import requests
 from bdbag import bdbag_api as bdb
 from bdbag.fetch.fetcher import fetch_single_file
-
-# Deriva imports
 from deriva.core.ermrest_model import Table
 from deriva.core.utils.core_utils import format_exception
-from deriva.core.utils.core_utils import tag as deriva_tags
 from deriva.transfer.download.deriva_download import (
     DerivaDownloadAuthenticationError,
     DerivaDownloadAuthorizationError,
@@ -54,22 +62,25 @@ from deriva.transfer.download.deriva_download import (
 from deriva.transfer.download.deriva_export import DerivaExport
 from pydantic import ConfigDict, validate_call
 
-# Local imports
 try:
     from icecream import ic
 
-    ic.configureOutput(includeContext=True)
+    ic.configureOutput(
+        includeContext=True,
+        argToStringFunction=lambda x: pformat(x.model_dump() if hasattr(x, "model_dump") else x, width=80, depth=10),
+    )
+
 except ImportError:  # Graceful fallback if IceCream isn't installed.
     ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a)  # noqa
 
 from deriva_ml.core.constants import RID
 from deriva_ml.core.definitions import (
     DRY_RUN_RID,
-    ML_SCHEMA,
     MLVocab,
     Status,
+    VocabularyTerm,
 )
-from deriva_ml.core.exceptions import DerivaMLException, DerivaMLTableTypeError
+from deriva_ml.core.exceptions import DerivaMLException
 from deriva_ml.dataset.aux_classes import (
     DatasetHistory,
     DatasetMinid,
@@ -77,18 +88,12 @@ from deriva_ml.dataset.aux_classes import (
     DatasetVersion,
     VersionPart,
 )
+from deriva_ml.dataset.catalog_graph import CatalogGraph
 from deriva_ml.dataset.dataset_bag import DatasetBag
-from deriva_ml.model.catalog import DerivaModel
+from deriva_ml.feature import Feature
+from deriva_ml.interfaces import DerivaMLCatalog
 from deriva_ml.model.database import DatabaseModel
 
-from .history import iso_to_snap
-
-# Stop pycharm from complaining about undefined reference in docstring....
-ml: DerivaML
-
-if TYPE_CHECKING:
-    from deriva_ml.core.base import DerivaML
-
 
 class Dataset:
     """Manages dataset operations in a Deriva catalog.
@@ -96,139 +101,368 @@ class Dataset:
     The Dataset class provides functionality for creating, modifying, and tracking datasets
     in a Deriva catalog. It handles versioning, relationships between datasets, and data export.
 
+    A Dataset is a versioned collection of related data elements. Each dataset:
+    - Has a unique RID (Resource Identifier) within the catalog
+    - Maintains a version history using semantic versioning (major.minor.patch)
+    - Can contain nested datasets, forming a hierarchy
+    - Can be exported as a BDBag for offline use or sharing
+
+    The class implements the DatasetLike protocol, allowing code to work uniformly
+    with both live catalog datasets and downloaded DatasetBag objects.
+
     Attributes:
-        dataset_table (Table): ERMrest table storing dataset information.
-        _model (DerivaModel): Catalog model instance.
-        _ml_schema (str): Schema name for ML-specific tables.
-        _cache_dir (Path): Directory for caching downloaded datasets.
-        _working_dir (Path): Directory for working data.
-        _use_minid (bool): Whether to use MINID service for dataset identification.
-
-    Note:
-        This class is typically used as a base class, with its methods accessed through
-        DerivaML class instances rather than directly.
+        dataset_rid (RID): The unique Resource Identifier for this dataset.
+        dataset_types (list[str]): List of vocabulary terms describing the dataset type.
+        description (str): Human-readable description of the dataset.
+        execution_rid (RID | None): Optional RID of the execution that created this dataset.
+        _ml_instance (DerivaMLCatalog): Reference to the catalog containing this dataset.
+
+    Example:
+        >>> # Create a new dataset via an execution
+        >>> with ml.create_execution(config) as exe:
+        ...     dataset = exe.create_dataset(
+        ...         dataset_types=["training_data"],
+        ...         description="Image classification training set"
+        ...     )
+        ...     # Add members to the dataset
+        ...     dataset.add_dataset_members(members=["1-abc", "1-def"])
+        ...     # Increment version after changes
+        ...     new_version = dataset.increment_dataset_version(VersionPart.minor, "Added samples")
+        >>> # Download for offline use
+        >>> bag = dataset.download_dataset_bag(version=new_version)
     """
 
-    _Logger = logging.getLogger("deriva_ml")
-
+    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
     def __init__(
         self,
-        model: DerivaModel,
-        cache_dir: Path,
-        working_dir: Path,
-        use_minid: bool = True,
+        catalog: DerivaMLCatalog,
+        dataset_rid: RID,
+        description: str = "",
+        execution_rid: RID | None = None,
     ):
-        """Initializes a Dataset instance.
+        """Initialize a Dataset object from an existing dataset in the catalog.
+
+        This constructor wraps an existing dataset record. To create a new dataset
+        in the catalog, use the static method Dataset.create_dataset() instead.
 
         Args:
-            model: DerivaModel instance representing the catalog.
-            cache_dir: Directory path for caching downloaded datasets.
-            working_dir: Directory path for working data.
-            use_minid: Whether to use MINID service for dataset identification.
+            catalog: The DerivaMLCatalog instance containing this dataset.
+            dataset_rid: The RID of the existing dataset record.
+            description: Human-readable description of the dataset's purpose and contents.
+            execution_rid: Optional execution RID that created or is associated with this dataset.
+
+        Example:
+            >>> # Wrap an existing dataset
+            >>> dataset = Dataset(catalog=ml, dataset_rid="4HM")
        """
-        self._model = model
-        self._ml_schema = ML_SCHEMA
-        self._cache_dir = cache_dir
-        self._working_dir = working_dir
         self._logger = logging.getLogger("deriva_ml")
-        self._use_minid = use_minid
+        self.dataset_rid = dataset_rid
+        self.execution_rid = execution_rid
+        self._ml_instance = catalog
+        self.description = description
+
+    def __repr__(self) -> str:
+        """Return a string representation of the Dataset for debugging."""
+        return (f"<deriva_ml.Dataset object at {hex(id(self))}: rid='{self.dataset_rid}', "
+                f"version='{self.current_version}', types={self.dataset_types}>")
+
+    def __hash__(self) -> int:
+        """Return hash based on dataset RID for use in sets and as dict keys.
+
+        This allows Dataset objects to be stored in sets and used as dictionary keys.
+        Two Dataset objects with the same RID will hash to the same value.
+        """
+        return hash(self.dataset_rid)
+
+    def __eq__(self, other: object) -> bool:
+        """Check equality based on dataset RID.
+
+        Two Dataset objects are considered equal if they reference the same
+        dataset RID, regardless of other attributes like version or types.
+
+        Args:
+            other: Object to compare with.
+
+        Returns:
+            True if other is a Dataset with the same RID, False otherwise.
+            Returns NotImplemented for non-Dataset objects.
+        """
+        if not isinstance(other, Dataset):
+            return NotImplemented
+        return self.dataset_rid == other.dataset_rid
+
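Because __hash__ and __eq__ key purely on the RID, two Dataset handles that wrap the same catalog record deduplicate in sets and collide as dict keys, which is what the docstrings above promise. A small illustrative sketch; the lookup_dataset call and the "4HM" RID are borrowed from docstring examples elsewhere in this diff, and the rest is hypothetical:

# Two independent lookups of the same record behave as one key.
a = ml.lookup_dataset("4HM")
b = ml.lookup_dataset("4HM")
assert a == b
assert len({a, b}) == 1          # set deduplicates by RID

per_dataset_status = {a: "trained"}
print(per_dataset_status[b])     # "trained" -- same RID, same key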
194
+ def _get_dataset_type_association_table(self) -> tuple[str, Any]:
195
+ """Get the association table for dataset types.
196
+
197
+ Returns:
198
+ Tuple of (table_name, table_path) for the Dataset-Dataset_Type association table.
199
+ """
200
+ associations = list(
201
+ self._ml_instance.model.schemas[self._ml_instance.ml_schema]
202
+ .tables[MLVocab.dataset_type]
203
+ .find_associations()
204
+ )
205
+ atable_name = associations[0].name if associations else None
206
+ pb = self._ml_instance.pathBuilder()
207
+ atable_path = pb.schemas[self._ml_instance.ml_schema].tables[atable_name]
208
+ return atable_name, atable_path
135
209
 
136
210
  @property
137
- def _dataset_table(self):
138
- return self._model.schemas[self._ml_schema].tables["Dataset"]
211
+ def dataset_types(self) -> list[str]:
212
+ """Get the dataset types from the catalog.
139
213
 
140
- def _is_dataset_rid(self, dataset_rid: RID, deleted: bool = False) -> bool:
141
- try:
142
- rid_info = self._model.catalog.resolve_rid(dataset_rid, self._model.model)
143
- except KeyError as _e:
144
- raise DerivaMLException(f"Invalid RID {dataset_rid}")
145
- if rid_info.table != self._dataset_table:
146
- return False
147
- elif deleted:
148
- # Got a dataset rid. Now check to see if its deleted or not.
149
- return True
150
- else:
151
- return not list(rid_info.datapath.entities().fetch())[0]["Deleted"]
214
+ This property fetches the current dataset types directly from the catalog,
215
+ ensuring consistency when multiple Dataset instances reference the same
216
+ dataset or when types are modified externally.
152
217
 
153
- def _insert_dataset_versions(
154
- self,
155
- dataset_list: list[DatasetSpec],
156
- description: str | None = "",
157
- execution_rid: RID | None = None,
158
- ) -> None:
159
- schema_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema]
160
- # determine snapshot after changes were made
218
+ Returns:
219
+ List of dataset type term names from the Dataset_Type vocabulary.
220
+ """
221
+ _, atable_path = self._get_dataset_type_association_table()
222
+ ds_types = (
223
+ atable_path.filter(atable_path.Dataset == self.dataset_rid)
224
+ .attributes(atable_path.Dataset_Type)
225
+ .fetch()
226
+ )
227
+ return [ds[MLVocab.dataset_type] for ds in ds_types]
161
228
 
162
- # Construct version records for insert
163
- version_records = schema_path.tables["Dataset_Version"].insert(
229
+ @staticmethod
230
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
231
+ def create_dataset(
232
+ ml_instance: DerivaMLCatalog,
233
+ execution_rid: RID,
234
+ dataset_types: str | list[str] | None = None,
235
+ description: str = "",
236
+ version: DatasetVersion | None = None,
237
+ ) -> Self:
238
+ """Creates a new dataset in the catalog.
239
+
240
+ Creates a dataset with specified types and description. The dataset must be
241
+ associated with an execution for provenance tracking.
242
+
243
+ Args:
244
+ ml_instance: DerivaMLCatalog instance.
245
+ execution_rid: Execution RID to associate with dataset creation (required).
246
+ dataset_types: One or more dataset type terms from Dataset_Type vocabulary.
247
+ description: Description of the dataset's purpose and contents.
248
+ version: Optional initial version number. Defaults to 0.1.0.
249
+
250
+ Returns:
251
+ Dataset: The newly created dataset.
252
+
253
+ Raises:
254
+ DerivaMLException: If dataset_types are invalid or creation fails.
255
+
256
+ Example:
257
+ >>> with ml.create_execution(config) as exe:
258
+ ... dataset = exe.create_dataset(
259
+ ... dataset_types=["experiment", "raw_data"],
260
+ ... description="RNA sequencing experiment data",
261
+ ... version=DatasetVersion(1, 0, 0)
262
+ ... )
263
+ """
264
+
265
+ version = version or DatasetVersion(0, 1, 0)
266
+
267
+ # Validate dataset types
268
+ ds_types = [dataset_types] if isinstance(dataset_types, str) else dataset_types
269
+ dataset_types = [ml_instance.lookup_term(MLVocab.dataset_type, t) for t in ds_types]
270
+
271
+ # Create the entry for the new dataset_table and get its RID.
272
+ pb = ml_instance.pathBuilder()
273
+ dataset_table_path = pb.schemas[ml_instance._dataset_table.schema.name].tables[ml_instance._dataset_table.name]
274
+ dataset_rid = dataset_table_path.insert(
164
275
  [
165
276
  {
166
- "Dataset": dataset.rid,
167
- "Version": str(dataset.version),
168
277
  "Description": description,
169
- "Execution": execution_rid,
278
+ "Deleted": False,
170
279
  }
171
- for dataset in dataset_list
172
280
  ]
281
+ )[0]["RID"]
282
+
283
+ pb.schemas[ml_instance.model.ml_schema].Dataset_Execution.insert(
284
+ [{"Dataset": dataset_rid, "Execution": execution_rid}]
173
285
  )
174
- version_records = list(version_records)
175
- snap = self._model.catalog.get("/").json()["snaptime"]
176
- schema_path.tables["Dataset_Version"].update(
177
- [{"RID": v["RID"], "Dataset": v["Dataset"], "Snapshot": snap} for v in version_records]
286
+ Dataset._insert_dataset_versions(
287
+ ml_instance=ml_instance,
288
+ dataset_list=[DatasetSpec(rid=dataset_rid, version=version)],
289
+ execution_rid=execution_rid,
290
+ description="Initial dataset creation.",
291
+ )
292
+ dataset = Dataset(
293
+ catalog=ml_instance,
294
+ dataset_rid=dataset_rid,
295
+ description=description,
178
296
  )
179
297
 
180
- # And update the dataset records.
181
- schema_path.tables["Dataset"].update([{"Version": v["RID"], "RID": v["Dataset"]} for v in version_records])
298
+ # Skip version increment during initial creation (version already set above)
299
+ dataset.add_dataset_types(dataset_types, _skip_version_increment=True)
300
+ return dataset
182
301
 
183
- def _bootstrap_versions(self):
184
- datasets = [ds["RID"] for ds in self.find_datasets()]
185
- ds_version = [
186
- {
187
- "Dataset": d,
188
- "Version": "0.1.0",
189
- "Description": "Dataset at the time of conversion to versioned datasets",
190
- }
191
- for d in datasets
192
- ]
193
- schema_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema]
194
- version_path = schema_path.tables["Dataset_Version"]
195
- dataset_path = schema_path.tables["Dataset"]
196
- history = list(version_path.insert(ds_version))
197
- dataset_versions = [{"RID": h["Dataset"], "Version": h["Version"]} for h in history]
198
- dataset_path.update(dataset_versions)
199
-
200
- def _synchronize_dataset_versions(self):
201
- datasets = [ds["RID"] for ds in self.find_datasets()]
202
- for ds in datasets:
203
- self.dataset_version(ds)
204
- schema_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema]
205
- dataset_version_path = schema_path.tables["Dataset_Version"]
206
- # Get the maximum version number for each dataset.
207
- versions = {}
208
- for v in dataset_version_path.entities().fetch():
209
- if v["Version"] > versions.get("Dataset", DatasetVersion(0, 0, 0)):
210
- versions[v["Dataset"]] = v
211
- dataset_path = schema_path.tables["Dataset"]
212
-
213
- dataset_path.update([{"RID": dataset, "Version": version["RID"]} for dataset, version in versions.items()])
214
-
215
- def _set_version_snapshot(self):
216
- """Update the Snapshot column of the Dataset_Version table to the correct time."""
217
- dataset_version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
218
- versions = dataset_version_path.entities().fetch()
219
- dataset_version_path.update(
220
- [{"RID": h["RID"], "Snapshot": iso_to_snap(h["RCT"])} for h in versions if not h["Snapshot"]]
221
- )
302
+ def add_dataset_type(
303
+ self,
304
+ dataset_type: str | VocabularyTerm,
305
+ _skip_version_increment: bool = False,
306
+ ) -> None:
307
+ """Add a dataset type to this dataset.
308
+
309
+ Adds a type term to this dataset if it's not already present. The term must
310
+ exist in the Dataset_Type vocabulary. Also increments the dataset's minor
311
+ version to reflect the metadata change.
312
+
313
+ Args:
314
+ dataset_type: Term name (string) or VocabularyTerm object from Dataset_Type vocabulary.
315
+ _skip_version_increment: Internal parameter to skip version increment when
316
+ called from add_dataset_types (which handles versioning itself).
317
+
318
+ Raises:
319
+ DerivaMLInvalidTerm: If the term doesn't exist in the Dataset_Type vocabulary.
320
+
321
+ Example:
322
+ >>> dataset.add_dataset_type("Training")
323
+ >>> dataset.add_dataset_type("Validation")
324
+ """
325
+ # Convert to VocabularyTerm if needed (validates the term exists)
326
+ if isinstance(dataset_type, VocabularyTerm):
327
+ vocab_term = dataset_type
328
+ else:
329
+ vocab_term = self._ml_instance.lookup_term(MLVocab.dataset_type, dataset_type)
330
+
331
+ # Check if already present
332
+ if vocab_term.name in self.dataset_types:
333
+ return
334
+
335
+ # Insert into association table
336
+ _, atable_path = self._get_dataset_type_association_table()
337
+ atable_path.insert([{MLVocab.dataset_type: vocab_term.name, "Dataset": self.dataset_rid}])
338
+
339
+ # Increment minor version to reflect metadata change (unless called from add_dataset_types)
340
+ if not _skip_version_increment:
341
+ self.increment_dataset_version(
342
+ VersionPart.minor,
343
+ description=f"Added dataset type: {vocab_term.name}",
344
+ )
345
+
346
+ def remove_dataset_type(self, dataset_type: str | VocabularyTerm) -> None:
347
+ """Remove a dataset type from this dataset.
348
+
349
+ Removes a type term from this dataset if it's currently associated. The term
350
+ must exist in the Dataset_Type vocabulary.
351
+
352
+ Args:
353
+ dataset_type: Term name (string) or VocabularyTerm object from Dataset_Type vocabulary.
354
+
355
+ Raises:
356
+ DerivaMLInvalidTerm: If the term doesn't exist in the Dataset_Type vocabulary.
357
+
358
+ Example:
359
+ >>> dataset.remove_dataset_type("Training")
360
+ """
361
+ # Convert to VocabularyTerm if needed (validates the term exists)
362
+ if isinstance(dataset_type, VocabularyTerm):
363
+ vocab_term = dataset_type
364
+ else:
365
+ vocab_term = self._ml_instance.lookup_term(MLVocab.dataset_type, dataset_type)
366
+
367
+ # Check if present
368
+ if vocab_term.name not in self.dataset_types:
369
+ return
370
+
371
+ # Delete from association table
372
+ _, atable_path = self._get_dataset_type_association_table()
373
+ atable_path.filter(
374
+ (atable_path.Dataset == self.dataset_rid) & (atable_path.Dataset_Type == vocab_term.name)
375
+ ).delete()
376
+
377
+ def add_dataset_types(
378
+ self,
379
+ dataset_types: str | VocabularyTerm | list[str | VocabularyTerm],
380
+ _skip_version_increment: bool = False,
381
+ ) -> None:
382
+ """Add one or more dataset types to this dataset.
383
+
384
+ Convenience method for adding multiple types at once. Each term must exist
385
+ in the Dataset_Type vocabulary. Types that are already associated with the
386
+ dataset are silently skipped. Increments the dataset's minor version once
387
+ after all types are added.
222
388
 
223
- def dataset_history(self, dataset_rid: RID) -> list[DatasetHistory]:
389
+ Args:
390
+ dataset_types: Single term or list of terms. Can be strings (term names)
391
+ or VocabularyTerm objects.
392
+ _skip_version_increment: Internal parameter to skip version increment
393
+ (used during initial dataset creation).
394
+
395
+ Raises:
396
+ DerivaMLInvalidTerm: If any term doesn't exist in the Dataset_Type vocabulary.
397
+
398
+ Example:
399
+ >>> dataset.add_dataset_types(["Training", "Image"])
400
+ >>> dataset.add_dataset_types("Testing")
401
+ """
402
+ # Normalize input to a list
403
+ types_to_add = [dataset_types] if not isinstance(dataset_types, list) else dataset_types
404
+
405
+ # Track which types were actually added (not already present)
406
+ added_types: list[str] = []
407
+ for term in types_to_add:
408
+ # Get term name before calling add_dataset_type
409
+ if isinstance(term, VocabularyTerm):
410
+ term_name = term.name
411
+ else:
412
+ term_name = self._ml_instance.lookup_term(MLVocab.dataset_type, term).name
413
+
414
+ # Check if already present before adding
415
+ if term_name not in self.dataset_types:
416
+ self.add_dataset_type(term, _skip_version_increment=True)
417
+ added_types.append(term_name)
418
+
419
+ # Increment version once for all added types (if any were added)
420
+ if added_types and not _skip_version_increment:
421
+ type_names = ", ".join(added_types)
422
+ self.increment_dataset_version(
423
+ VersionPart.minor,
424
+ description=f"Added dataset type(s): {type_names}",
425
+ )
426
+
427
+ @property
428
+ def _dataset_table(self) -> Table:
429
+ """Get the Dataset table from the catalog schema.
430
+
431
+ Returns:
432
+ Table: The Deriva Table object for the Dataset table in the ML schema.
433
+ """
434
+ return self._ml_instance.model.schemas[self._ml_instance.ml_schema].tables["Dataset"]
435
+
436
+ # ==================== Read Interface Methods ====================
437
+ # These methods implement the DatasetLike protocol for read operations.
438
+ # They delegate to the catalog instance for actual data retrieval.
439
+ # This allows Dataset and DatasetBag to share a common interface.
440
+
441
+ def list_dataset_element_types(self) -> Iterable[Table]:
442
+ """List the types of elements that can be contained in this dataset.
443
+
444
+ Returns:
445
+ Iterable of Table objects representing element types.
446
+ """
447
+ return self._ml_instance.list_dataset_element_types()
448
+
449
+ def find_features(self, table: str | Table) -> Iterable[Feature]:
450
+ """Find features associated with a table.
451
+
452
+ Args:
453
+ table: Table to find features for.
454
+
455
+ Returns:
456
+ Iterable of Feature objects.
457
+ """
458
+ return self._ml_instance.find_features(table)
459
+
460
+ def dataset_history(self) -> list[DatasetHistory]:
224
461
  """Retrieves the version history of a dataset.
225
462
 
226
463
  Returns a chronological list of dataset versions, including their version numbers,
227
464
  creation times, and associated metadata.
228
465
 
229
- Args:
230
- dataset_rid: Resource Identifier of the dataset.
231
-
232
466
  Returns:
233
467
  list[DatasetHistory]: List of history entries, each containing:
234
468
  - dataset_version: Version number (major.minor.patch)
@@ -248,38 +482,36 @@ class Dataset:
248
482
  ... print(f"Version {entry.dataset_version}: {entry.description}")
249
483
  """
250
484
 
251
- if not self._is_dataset_rid(dataset_rid):
252
- raise DerivaMLException(f"RID is not for a data set: {dataset_rid}")
253
- version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
485
+ if not self._ml_instance.model.is_dataset_rid(self.dataset_rid):
486
+ raise DerivaMLException(f"RID is not for a data set: {self.dataset_rid}")
487
+ version_path = self._ml_instance.pathBuilder().schemas[self._ml_instance.ml_schema].tables["Dataset_Version"]
254
488
  return [
255
489
  DatasetHistory(
256
490
  dataset_version=DatasetVersion.parse(v["Version"]),
257
491
  minid=v["Minid"],
258
492
  snapshot=v["Snapshot"],
259
- dataset_rid=dataset_rid,
493
+ dataset_rid=self.dataset_rid,
260
494
  version_rid=v["RID"],
261
495
  description=v["Description"],
262
496
  execution_rid=v["Execution"],
263
497
  )
264
- for v in version_path.filter(version_path.Dataset == dataset_rid).entities().fetch()
498
+ for v in version_path.filter(version_path.Dataset == self.dataset_rid).entities().fetch()
265
499
  ]
266
500
 
267
- @validate_call
268
- def dataset_version(self, dataset_rid: RID) -> DatasetVersion:
501
+ @property
502
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
503
+ def current_version(self) -> DatasetVersion:
269
504
  """Retrieve the current version of the specified dataset_table.
270
505
 
271
- Given a rid, return the most recent version of the dataset. It is important to remember that this version
506
+ Return the most recent version of the dataset. It is important to remember that this version
272
507
  captures the state of the catalog at the time the version was created, not the current state of the catalog.
273
508
  This means that its possible that the values associated with an object in the catalog may be different
274
509
  from the values of that object in the dataset.
275
510
 
276
- Args:
277
- dataset_rid: The RID of the dataset to retrieve the version for.
278
-
279
511
  Returns:
280
512
  A tuple with the semantic version of the dataset_table.
281
513
  """
282
- history = self.dataset_history(dataset_rid)
514
+ history = self.dataset_history()
283
515
  if not history:
284
516
  return DatasetVersion(0, 1, 0)
285
517
  else:
@@ -287,28 +519,129 @@ class Dataset:
287
519
  versions = [h.dataset_version for h in history]
288
520
  return max(versions) if versions else DatasetVersion(0, 1, 0)
289
521
 
290
- def _build_dataset_graph(self, dataset_rid: RID) -> Iterable[RID]:
522
+ def get_chaise_url(self) -> str:
523
+ """Get the Chaise URL for viewing this dataset in the browser.
524
+
525
+ Returns:
526
+ URL string for the dataset record in Chaise.
527
+ """
528
+ return (
529
+ f"https://{self._ml_instance.host_name}/chaise/record/"
530
+ f"#{self._ml_instance.catalog_id}/deriva-ml:Dataset/RID={self.dataset_rid}"
531
+ )
532
+
533
+ def to_markdown(self, show_children: bool = False, indent: int = 0) -> str:
534
+ """Generate a markdown representation of this dataset.
535
+
536
+ Returns a formatted markdown string with a link to the dataset,
537
+ version, types, and description. Optionally includes nested children.
538
+
539
+ Args:
540
+ show_children: If True, include direct child datasets.
541
+ indent: Number of indent levels (each level is 2 spaces).
542
+
543
+ Returns:
544
+ Markdown-formatted string.
545
+
546
+ Example:
547
+ >>> ds = ml.lookup_dataset("4HM")
548
+ >>> print(ds.to_markdown())
549
+ """
550
+ prefix = " " * indent
551
+ version = str(self.current_version) if self.current_version else "n/a"
552
+ types = ", ".join(self.dataset_types) if self.dataset_types else ""
553
+ desc = self.description or ""
554
+
555
+ line = f"{prefix}- [{self.dataset_rid}]({self.get_chaise_url()}) v{version}"
556
+ if types:
557
+ line += f" [{types}]"
558
+ if desc:
559
+ line += f": {desc}"
560
+
561
+ lines = [line]
562
+
563
+ if show_children:
564
+ children = self.list_dataset_children(recurse=False)
565
+ for child in children:
566
+ lines.append(child.to_markdown(show_children=False, indent=indent + 1))
567
+
568
+ return "\n".join(lines)
569
+
570
+ def display_markdown(self, show_children: bool = False, indent: int = 0) -> None:
571
+ """Display a formatted markdown representation of this dataset in Jupyter.
572
+
573
+ Convenience method that calls to_markdown() and displays the result
574
+ using IPython.display.Markdown.
575
+
576
+ Args:
577
+ show_children: If True, include direct child datasets.
578
+ indent: Number of indent levels (each level is 2 spaces).
579
+
580
+ Example:
581
+ >>> ds = ml.lookup_dataset("4HM")
582
+ >>> ds.display_markdown(show_children=True)
583
+ """
584
+ from IPython.display import display, Markdown
585
+
586
+ display(Markdown(self.to_markdown(show_children, indent)))
587
+
588
+ def _build_dataset_graph(self) -> Iterable[Dataset]:
589
+ """Build a dependency graph of all related datasets and return in topological order.
590
+
591
+ This method is used when incrementing dataset versions. Because datasets can be
592
+ nested (parent-child relationships), changing the version of one dataset may
593
+ require updating related datasets.
594
+
595
+ The topological sort ensures that children are processed before parents,
596
+ so version updates propagate correctly through the hierarchy.
597
+
598
+ Returns:
599
+ Iterable[Dataset]: Datasets in topological order (children before parents).
600
+
601
+ Example:
602
+ If dataset A contains nested dataset B, which contains C:
603
+ A -> B -> C
604
+ The returned order would be [C, B, A], ensuring C's version is
605
+ updated before B's, and B's before A's.
606
+ """
291
607
  ts: TopologicalSorter = TopologicalSorter()
292
- self._build_dataset_graph_1(dataset_rid, ts, set())
608
+ self._build_dataset_graph_1(ts, set())
293
609
  return ts.static_order()
294
610
 
295
- def _build_dataset_graph_1(self, dataset_rid: RID, ts: TopologicalSorter, visited) -> None:
296
- """Use topological sort to return bottom up list of nested datasets"""
297
- ts.add(dataset_rid)
298
- if dataset_rid not in visited:
299
- visited.add(dataset_rid)
300
- children = self.list_dataset_children(dataset_rid=dataset_rid)
301
- parents = self.list_dataset_parents(dataset_rid=dataset_rid)
302
- for parent in parents:
303
- # Convert string to RID type
304
- self._build_dataset_graph_1(RID(parent), ts, visited)
305
- for child in children:
306
- self._build_dataset_graph_1(child, ts, visited)
611
+ def _build_dataset_graph_1(self, ts: TopologicalSorter, visited: set[str]) -> None:
612
+ """Recursively build the dataset dependency graph.
613
+
614
+ Uses topological sort where parents depend on their children, ensuring
615
+ children are processed before parents in the resulting order.
616
+
617
+ Args:
618
+ ts: TopologicalSorter instance to add nodes and dependencies to.
619
+ visited: Set of already-visited dataset RIDs to avoid cycles.
620
+ """
621
+ if self.dataset_rid in visited:
622
+ return
623
+
624
+ visited.add(self.dataset_rid)
625
+ # Use current catalog state for graph traversal, not version snapshot.
626
+ # Parent/child relationships need to reflect current state for version updates.
627
+ children = self._list_dataset_children_current()
628
+ parents = self._list_dataset_parents_current()
629
+
630
+ # Add this node with its children as dependencies.
631
+ # This means: self depends on children, so children will be ordered before self.
632
+ ts.add(self, *children)
633
+
634
+ # Recursively process children
635
+ for child in children:
636
+ child._build_dataset_graph_1(ts, visited)
637
+
638
+ # Recursively process parents (they will depend on this node)
639
+ for parent in parents:
640
+ parent._build_dataset_graph_1(ts, visited)
307
641
 
308
642
  @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
309
643
  def increment_dataset_version(
310
644
  self,
311
- dataset_rid: RID,
312
645
  component: VersionPart,
313
646
  description: str | None = "",
314
647
  execution_rid: RID | None = None,
@@ -320,7 +653,6 @@ class Dataset:
320
653
  and execution reference.
321
654
 
322
655
  Args:
323
- dataset_rid: Resource Identifier of the dataset to version.
324
656
  component: Which version component to increment ('major', 'minor', or 'patch').
325
657
  description: Optional description of the changes in this version.
326
658
  execution_rid: Optional execution RID to associate with this version.
@@ -341,190 +673,27 @@ class Dataset:
341
673
  """
342
674
 
343
675
  # Find all the datasets that are reachable from this dataset and determine their new version numbers.
344
- related_datasets = list(self._build_dataset_graph(dataset_rid=dataset_rid))
676
+ related_datasets = list(self._build_dataset_graph())
345
677
  version_update_list = [
346
678
  DatasetSpec(
347
- rid=ds_rid,
348
- version=self.dataset_version(ds_rid).increment_version(component),
679
+ rid=ds.dataset_rid,
680
+ version=ds.current_version.increment_version(component),
349
681
  )
350
- for ds_rid in related_datasets
682
+ for ds in related_datasets
351
683
  ]
352
- self._insert_dataset_versions(version_update_list, description=description, execution_rid=execution_rid)
353
- return next((d.version for d in version_update_list if d.rid == dataset_rid))
354
-
355
- @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
356
- def create_dataset(
357
- self,
358
- dataset_types: str | list[str] | None = None,
359
- description: str = "",
360
- execution_rid: RID | None = None,
361
- version: DatasetVersion | None = None,
362
- ) -> RID:
363
- """Creates a new dataset in the catalog.
364
-
365
- Creates a dataset with specified types and description. The dataset can be associated
366
- with an execution and initialized with a specific version.
367
-
368
- Args:
369
- dataset_types: One or more dataset type terms from Dataset_Type vocabulary.
370
- description: Description of the dataset's purpose and contents.
371
- execution_rid: Optional execution RID to associate with dataset creation.
372
- version: Optional initial version number. Defaults to 0.1.0.
373
-
374
- Returns:
375
- RID: Resource Identifier of the newly created dataset.
376
-
377
- Raises:
378
- DerivaMLException: If dataset_types are invalid or creation fails.
379
-
380
- Example:
381
- >>> rid = ml.create_dataset(
382
- ... dataset_types=["experiment", "raw_data"],
383
- ... description="RNA sequencing experiment data",
384
- ... version=DatasetVersion(1, 0, 0)
385
- ... )
386
- """
387
-
388
- version = version or DatasetVersion(0, 1, 0)
389
- dataset_types = dataset_types or []
390
-
391
- type_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables[MLVocab.dataset_type.value]
392
- defined_types = list(type_path.entities().fetch())
393
-
394
- def check_dataset_type(dtype: str) -> bool:
395
- for term in defined_types:
396
- if dtype == term["Name"] or (term["Synonyms"] and ds_type in term["Synonyms"]):
397
- return True
398
- return False
399
-
400
- # Create the entry for the new dataset_table and get its RID.
401
- ds_types = [dataset_types] if isinstance(dataset_types, str) else dataset_types
402
- pb = self._model.catalog.getPathBuilder()
403
- for ds_type in ds_types:
404
- if not check_dataset_type(ds_type):
405
- raise DerivaMLException("Dataset type must be a vocabulary term.")
406
- dataset_table_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
407
- dataset_rid = dataset_table_path.insert(
408
- [
409
- {
410
- "Description": description,
411
- "Deleted": False,
412
- }
413
- ]
414
- )[0]["RID"]
415
-
416
- # Get the name of the association table between dataset_table and dataset_type.
417
- associations = list(self._model.schemas[self._ml_schema].tables[MLVocab.dataset_type].find_associations())
418
- atable = associations[0].name if associations else None
419
- pb.schemas[self._ml_schema].tables[atable].insert(
420
- [{MLVocab.dataset_type: ds_type, "Dataset": dataset_rid} for ds_type in ds_types]
684
+ Dataset._insert_dataset_versions(
685
+ self._ml_instance, version_update_list, description=description, execution_rid=execution_rid
421
686
  )
422
- if execution_rid is not None:
423
- pb.schemas[self._ml_schema].Dataset_Execution.insert([{"Dataset": dataset_rid, "Execution": execution_rid}])
424
- self._insert_dataset_versions(
425
- [DatasetSpec(rid=dataset_rid, version=version)],
426
- execution_rid=execution_rid,
427
- description="Initial dataset creation.",
428
- )
429
- return dataset_rid
430
-
431
- @validate_call
432
- def delete_dataset(self, dataset_rid: RID, recurse: bool = False) -> None:
433
- """Delete a dataset_table from the catalog.
434
-
435
- Args:
436
- dataset_rid: RID of the dataset_table to delete.
437
- recurse: If True, delete the dataset_table along with any nested datasets. (Default value = False)
438
- """
439
- # Get association table entries for this dataset_table
440
- # Delete association table entries
441
- if not self._is_dataset_rid(dataset_rid):
442
- raise DerivaMLException("Dataset_rid is not a dataset.")
443
-
444
- if parents := self.list_dataset_parents(dataset_rid):
445
- raise DerivaMLException(f'Dataset_rid "{dataset_rid}" is in a nested dataset: {parents}.')
446
-
447
- pb = self._model.catalog.getPathBuilder()
448
- dataset_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
449
-
450
- rid_list = [dataset_rid] + (self.list_dataset_children(dataset_rid=dataset_rid) if recurse else [])
451
- dataset_path.update([{"RID": r, "Deleted": True} for r in rid_list])
452
-
453
- def find_datasets(self, deleted: bool = False) -> Iterable[dict[str, Any]]:
454
- """Returns a list of currently available datasets.
455
-
456
- Arguments:
457
- deleted: If True, included the datasets that have been deleted.
458
-
459
- Returns:
460
- list of currently available datasets.
461
- """
462
- # Get datapath to all the tables we will need: Dataset, DatasetType and the association table.
463
- pb = self._model.catalog.getPathBuilder()
464
- dataset_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
465
- associations = list(self._model.schemas[self._ml_schema].tables[MLVocab.dataset_type].find_associations())
466
- atable = associations[0].name if associations else None
467
- ml_path = pb.schemas[self._ml_schema]
468
- atable_path = ml_path.tables[atable]
469
-
470
- if deleted:
471
- filtered_path = dataset_path
472
- else:
473
- filtered_path = dataset_path.filter(
474
- (dataset_path.Deleted == False) | (dataset_path.Deleted == None) # noqa: E711, E712
475
- )
476
-
477
- # Get a list of all the dataset_type values associated with this dataset_table.
478
- datasets = []
479
- for dataset in filtered_path.entities().fetch():
480
- ds_types = (
481
- atable_path.filter(atable_path.Dataset == dataset["RID"]).attributes(atable_path.Dataset_Type).fetch()
482
- )
483
- datasets.append(dataset | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in ds_types]})
484
- return datasets
485
-
486
- def list_dataset_element_types(self) -> Iterable[Table]:
487
- """List the types of entities that can be added to a dataset_table.
488
-
489
- Returns:
490
- :return: An iterable of Table objects that can be included as an element of a dataset_table.
491
- """
492
-
493
- def domain_table(table: Table) -> bool:
494
- return table.schema.name == self._model.domain_schema or table.name == self._dataset_table.name
495
-
496
- return [t for a in self._dataset_table.find_associations() if domain_table(t := a.other_fkeys.pop().pk_table)]
687
+ return next((d.version for d in version_update_list if d.rid == self.dataset_rid))
497
688
 
498
689
  @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
499
- def add_dataset_element_type(self, element: str | Table) -> Table:
500
- """A dataset_table is a heterogeneous collection of objects, each of which comes from a different table. This
501
- routine makes it possible to add objects from the specified table to a dataset_table.
502
-
503
- Args:
504
- element: Name of the table or table object that is to be added to the dataset_table.
505
-
506
- Returns:
507
- The table object that was added to the dataset_table.
508
- """
509
- # Add table to map
510
- element_table = self._model.name_to_table(element)
511
- atable_def = Table.define_association([self._dataset_table, element_table])
512
- try:
513
- table = self._model.schemas[self._model.domain_schema].create_table(atable_def)
514
- except ValueError as e:
515
- if "already exists" in str(e):
516
- table = self._model.name_to_table(atable_def["table_name"])
517
- else:
518
- raise e
519
-
520
- # self.model = self.catalog.getCatalogModel()
521
- self._dataset_table.annotations.update(self._generate_dataset_download_annotations())
522
- self._model.model.apply()
523
- return table
524
-
525
- # @validate_call
526
690
  def list_dataset_members(
527
- self, dataset_rid: RID, recurse: bool = False, limit: int | None = None
691
+ self,
692
+ recurse: bool = False,
693
+ limit: int | None = None,
694
+ _visited: set[RID] | None = None,
695
+ version: DatasetVersion | str | None = None,
696
+ **kwargs: Any,
528
697
  ) -> dict[str, list[dict[str, Any]]]:
529
698
  """Lists members of a dataset.
530
699
 
@@ -532,9 +701,11 @@ class Dataset:
532
701
  recurse through nested datasets and limit the number of results.
533
702
 
534
703
  Args:
535
- dataset_rid: Resource Identifier of the dataset.
536
704
  recurse: Whether to include members of nested datasets. Defaults to False.
537
705
  limit: Maximum number of members to return per type. None for no limit.
706
+ _visited: Internal parameter to track visited datasets and prevent infinite recursion.
707
+ version: Dataset version to list members from. Defaults to the current version.
708
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
538
709
 
539
710
  Returns:
540
711
  dict[str, list[dict[str, Any]]]: Dictionary mapping member types to lists of members.
@@ -548,21 +719,27 @@ class Dataset:
548
719
  >>> for type_name, records in members.items():
549
720
  ... print(f"{type_name}: {len(records)} records")
550
721
  """
722
+ # Initialize visited set for recursion guard
723
+ if _visited is None:
724
+ _visited = set()
551
725
 
552
- if not self._is_dataset_rid(dataset_rid):
553
- raise DerivaMLException(f"RID is not for a dataset_table: {dataset_rid}")
726
+ # Prevent infinite recursion by checking if we've already visited this dataset
727
+ if self.dataset_rid in _visited:
728
+ return {}
729
+ _visited.add(self.dataset_rid)
554
730
 
555
731
  # Look at each of the element types that might be in the dataset_table and get the list of rid for them from
556
732
  # the appropriate association table.
557
733
  members = defaultdict(list)
558
- pb = self._model.catalog.getPathBuilder()
734
+ version_snapshot_catalog = self._version_snapshot_catalog(version)
735
+ pb = version_snapshot_catalog.pathBuilder()
559
736
  for assoc_table in self._dataset_table.find_associations():
560
737
  other_fkey = assoc_table.other_fkeys.pop()
561
738
  target_table = other_fkey.pk_table
562
739
  member_table = assoc_table.table
563
740
 
564
741
  # Look at domain tables and nested datasets.
565
- if target_table.schema.name != self._model.domain_schema and not (
742
+ if not self._ml_instance.model.is_domain_schema(target_table.schema.name) and not (
566
743
  target_table == self._dataset_table or target_table.name == "File"
567
744
  ):
568
745
  continue
@@ -573,7 +750,7 @@ class Dataset:
573
750
  target_path = pb.schemas[target_table.schema.name].tables[target_table.name]
574
751
  member_path = pb.schemas[member_table.schema.name].tables[member_table.name]
575
752
 
576
- path = member_path.filter(member_path.Dataset == dataset_rid).link(
753
+ path = member_path.filter(member_path.Dataset == self.dataset_rid).link(
577
754
  target_path,
578
755
  on=(member_path.columns[member_column] == target_path.columns["RID"]),
579
756
  )
@@ -582,15 +759,241 @@ class Dataset:
582
759
  if recurse and target_table == self._dataset_table:
583
760
  # Get the members for all the nested datasets and add to the member list.
584
761
  nested_datasets = [d["RID"] for d in target_entities]
585
- for ds in nested_datasets:
586
- for k, v in self.list_dataset_members(ds, recurse=recurse).items():
762
+ for ds_rid in nested_datasets:
763
+ ds = version_snapshot_catalog.lookup_dataset(ds_rid)
764
+ for k, v in ds.list_dataset_members(version=version, recurse=recurse, _visited=_visited).items():
587
765
  members[k].extend(v)
588
766
  return dict(members)
589
767
 
590
- @validate_call
768
+ def _denormalize_datapath(
769
+ self,
770
+ include_tables: list[str],
771
+ version: DatasetVersion | str | None = None,
772
+ ) -> Generator[dict[str, Any], None, None]:
773
+ """Denormalize dataset members by joining related tables.
774
+
775
+ This method creates a "wide table" view by joining related tables together using
776
+ the Deriva datapath API, producing rows that contain columns from all specified
777
+ tables. The result has outer join semantics - rows from tables without FK
778
+ relationships are included with NULL values for unrelated columns.
779
+
780
+ The method:
781
+ 1. Gets the list of dataset members for each included table
782
+ 2. For each member in the first table, follows foreign key relationships to
783
+ get related records from other tables
784
+ 3. Tables without FK connections to the first table are included with NULLs
785
+ 4. Includes nested dataset members recursively
786
+
787
+ Args:
788
+ include_tables: List of table names to include in the output.
789
+ version: Dataset version to query. Defaults to current version.
790
+
791
+ Yields:
792
+ dict[str, Any]: Rows with column names prefixed by table name (e.g., "Image_Filename").
793
+ Unrelated tables have NULL values for their columns.
794
+
795
+ Note:
796
+ Column names in the result are prefixed with the table name to avoid
797
+ collisions (e.g., "Image_Filename", "Subject_RID").
798
+ """
799
+ # Skip system columns in output
800
+ skip_columns = {"RCT", "RMT", "RCB", "RMB"}
801
+
802
+ # Get all members for the included tables (recursively includes nested datasets)
803
+ members = self.list_dataset_members(version=version, recurse=True)
804
+
805
+ # Build a lookup of columns for each table
806
+ table_columns: dict[str, list[str]] = {}
807
+ for table_name in include_tables:
808
+ table = self._ml_instance.model.name_to_table(table_name)
809
+ table_columns[table_name] = [
810
+ c.name for c in table.columns if c.name not in skip_columns
811
+ ]
812
+
813
+ # Find the primary table (first non-empty table in include_tables)
814
+ primary_table = None
815
+ for table_name in include_tables:
816
+ if table_name in members and members[table_name]:
817
+ primary_table = table_name
818
+ break
819
+
820
+ if primary_table is None:
821
+ # No data at all
822
+ return
823
+
824
+ primary_table_obj = self._ml_instance.model.name_to_table(primary_table)
825
+
826
+ for member in members[primary_table]:
827
+ # Build the row with all columns from all tables
828
+ row: dict[str, Any] = {}
829
+
830
+ # Add primary table columns
831
+ for col_name in table_columns[primary_table]:
832
+ prefixed_name = f"{primary_table}_{col_name}"
833
+ row[prefixed_name] = member.get(col_name)
834
+
835
+ # For each other table, try to join or add NULL values
836
+ for other_table_name in include_tables:
837
+ if other_table_name == primary_table:
838
+ continue
839
+
840
+ other_table = self._ml_instance.model.name_to_table(other_table_name)
841
+ other_cols = table_columns[other_table_name]
842
+
843
+ # Initialize all columns to None (outer join behavior)
844
+ for col_name in other_cols:
845
+ prefixed_name = f"{other_table_name}_{col_name}"
846
+ row[prefixed_name] = None
847
+
848
+ # Try to find FK relationship and join
849
+ if other_table_name in members:
850
+ try:
851
+ relationship = self._ml_instance.model._table_relationship(
852
+ primary_table_obj, other_table
853
+ )
854
+ fk_col, pk_col = relationship
855
+
856
+ # Look up the related record
857
+ fk_value = member.get(fk_col.name)
858
+ if fk_value:
859
+ for other_member in members.get(other_table_name, []):
860
+ if other_member.get(pk_col.name) == fk_value:
861
+ for col_name in other_cols:
862
+ prefixed_name = f"{other_table_name}_{col_name}"
863
+ row[prefixed_name] = other_member.get(col_name)
864
+ break
865
+ except DerivaMLException:
866
+ # No FK relationship - columns remain NULL (outer join)
867
+ pass
868
+
869
+ yield row
870
+
871
+ def denormalize_as_dataframe(
872
+ self,
873
+ include_tables: list[str],
874
+ version: DatasetVersion | str | None = None,
875
+ **kwargs: Any,
876
+ ) -> pd.DataFrame:
877
+ """Denormalize the dataset into a single wide table (DataFrame).
878
+
879
+ Denormalization transforms normalized relational data into a single "wide table"
880
+ (also called a "flat table" or "denormalized table") by joining related tables
881
+ together. This produces a DataFrame where each row contains all related information
882
+ from multiple source tables, with columns from each table combined side-by-side.
883
+
884
+ Wide tables are the standard input format for most machine learning frameworks,
885
+ which expect all features for a single observation to be in one row. This method
886
+ bridges the gap between normalized database schemas and ML-ready tabular data.
887
+
888
+ **How it works:**
889
+
890
+ Tables are joined based on their foreign key relationships. For example, if
891
+ Image has a foreign key to Subject, and Diagnosis has a foreign key to Image,
892
+ then denormalizing ["Subject", "Image", "Diagnosis"] produces rows where each
893
+ image appears with its subject's metadata and any associated diagnoses.
894
+
895
+ **Column naming:**
896
+
897
+ Column names are prefixed with the source table name using underscores
898
+ to avoid collisions (e.g., "Image_Filename", "Subject_RID").
899
+
900
+ Args:
901
+ include_tables: List of table names to include in the output. Tables
902
+ are joined based on their foreign key relationships.
903
+ Order doesn't matter - the join order is determined automatically.
904
+ version: Dataset version to query. Defaults to current version.
905
+ Use this to get a reproducible snapshot of the data.
906
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
907
+
908
+ Returns:
909
+ pd.DataFrame: Wide table with columns from all included tables.
910
+
911
+ Example:
912
+ Create a training dataset with images and their labels::
913
+
914
+ >>> # Get all images with their diagnoses in one table
915
+ >>> df = dataset.denormalize_as_dataframe(["Image", "Diagnosis"])
916
+ >>> print(df.columns.tolist())
917
+ ['Image_RID', 'Image_Filename', 'Image_URL', 'Diagnosis_RID',
918
+ 'Diagnosis_Label', 'Diagnosis_Confidence']
919
+
920
+ >>> # Use with scikit-learn
921
+ >>> X = df[["Image_Filename"]] # Features
922
+ >>> y = df["Diagnosis_Label"] # Labels
923
+
924
+ Include subject metadata for stratified splitting::
925
+
926
+ >>> df = dataset.denormalize_as_dataframe(
927
+ ... ["Subject", "Image", "Diagnosis"]
928
+ ... )
929
+ >>> # Now df has Subject_Age, Subject_Gender, etc.
930
+ >>> # for stratified train/test splits by subject
931
+
932
+ See Also:
933
+ denormalize_as_dict: Generator version for memory-efficient processing.
934
+ """
935
+ rows = list(self._denormalize_datapath(include_tables, version))
936
+ return pd.DataFrame(rows)
937
+
938
+ def denormalize_as_dict(
939
+ self,
940
+ include_tables: list[str],
941
+ version: DatasetVersion | str | None = None,
942
+ **kwargs: Any,
943
+ ) -> Generator[dict[str, Any], None, None]:
944
+ """Denormalize the dataset and yield rows as dictionaries.
945
+
946
+ This is a memory-efficient alternative to denormalize_as_dataframe() that
947
+ yields one row at a time as a dictionary instead of loading all data into
948
+ a DataFrame. Use this when processing large datasets that may not fit in
949
+ memory, or when you want to process rows incrementally.
950
+
951
+ Like denormalize_as_dataframe(), this produces a "wide table" representation
952
+ where each yielded dictionary contains all columns from the joined tables.
953
+ See denormalize_as_dataframe() for detailed explanation of how denormalization
954
+ works.
955
+
956
+ **Column naming:**
957
+
958
+ Column names are prefixed with the source table name using underscores
959
+ to avoid collisions (e.g., "Image_Filename", "Subject_RID").
960
+
961
+ Args:
962
+ include_tables: List of table names to include in the output.
963
+ Tables are joined based on their foreign key relationships.
964
+ version: Dataset version to query. Defaults to current version.
965
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
966
+
967
+ Yields:
968
+ dict[str, Any]: Dictionary representing one row of the wide table.
969
+ Keys are column names in "Table_Column" format.
970
+
971
+ Example:
972
+ Process images one at a time for training::
973
+
974
+ >>> for row in dataset.denormalize_as_dict(["Image", "Diagnosis"]):
975
+ ... # Load and preprocess each image
976
+ ... img = load_image(row["Image_Filename"])
977
+ ... label = row["Diagnosis_Label"]
978
+ ... yield img, label # Feed to training loop
979
+
980
+ Count labels without loading all data into memory::
981
+
982
+ >>> from collections import Counter
983
+ >>> labels = Counter()
984
+ >>> for row in dataset.denormalize_as_dict(["Image", "Diagnosis"]):
985
+ ... labels[row["Diagnosis_Label"]] += 1
986
+ >>> print(labels)
987
+ Counter({'Normal': 450, 'Abnormal': 150})
988
+
989
+ See Also:
990
+ denormalize_as_dataframe: Returns all data as a pandas DataFrame.
991
+ """
992
+ yield from self._denormalize_datapath(include_tables, version)
993
+
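A minimal sketch of streaming the generator straight to a CSV manifest with the standard-library csv module, assuming the `dataset` handle and example tables from the docstring and at least one yielded row; column names are discovered from the first row.

    import csv

    rows = dataset.denormalize_as_dict(["Image", "Diagnosis"])
    first = next(rows)  # peek at one row to learn the column names
    with open("training_manifest.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(first.keys()))
        writer.writeheader()
        writer.writerow(first)
        for row in rows:
            writer.writerow(row)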
994
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
591
995
  def add_dataset_members(
592
996
  self,
593
- dataset_rid: RID,
594
997
  members: list[RID] | dict[str, list[RID]],
595
998
  validate: bool = True,
596
999
  description: str | None = "",
@@ -598,30 +1001,58 @@ class Dataset:
598
1001
  ) -> None:
599
1002
  """Adds members to a dataset.
600
1003
 
601
- Associates one or more records with a dataset. Can optionally validate member types
602
- and create a new dataset version to track the changes.
1004
+ Associates one or more records with a dataset. Members can be provided in two forms:
1005
+
1006
+ **List of RIDs (simpler but slower):**
1007
+ When `members` is a list of RIDs, each RID is resolved to determine which table
1008
+ it belongs to. This uses batch RID resolution for efficiency, but still requires
1009
+ querying the catalog to identify each RID's table.
1010
+
1011
+ **Dictionary by table name (faster, recommended for large datasets):**
1012
+ When `members` is a dict mapping table names to lists of RIDs, no RID resolution
1013
+ is needed. The RIDs are inserted directly into the dataset. Use this form when
1014
+ you already know which table each RID belongs to.
1015
+
1016
+ **Important:** Members can only be added from tables that have been registered as
1017
+ dataset element types. Use :meth:`DerivaML.add_dataset_element_type` to register
1018
+ a table before adding its records to datasets.
1019
+
1020
+ Adding members automatically increments the dataset's minor version.
603
1021
 
604
1022
  Args:
605
- dataset_rid: Resource Identifier of the dataset.
606
- members: List of RIDs to add as dataset members. Can be orginized into a dictionary that indicates the
607
- table that the member rids belong to.
608
- validate: Whether to validate member types. Defaults to True.
1023
+ members: Either:
1024
+ - list[RID]: List of RIDs to add. Each RID will be resolved to find its table.
1025
+ - dict[str, list[RID]]: Mapping of table names to RID lists. Skips resolution.
1026
+ validate: Whether to validate that members don't already exist. Defaults to True.
609
1027
  description: Optional description of the member additions.
610
1028
  execution_rid: Optional execution RID to associate with changes.
611
1029
 
612
1030
  Raises:
613
1031
  DerivaMLException: If:
614
- - dataset_rid is invalid
615
- - members are invalid or of wrong type
616
- - adding members would create a cycle
617
- - validation fails
618
-
619
- Example:
620
- >>> ml.add_dataset_members(
621
- ... dataset_rid="1-abc123",
622
- ... members=["1-def456", "1-ghi789"],
623
- ... description="Added sample data"
624
- ... )
1032
+ - Any RID is invalid or cannot be resolved
1033
+ - Any RID belongs to a table that isn't registered as a dataset element type
1034
+ - Adding members would create a cycle (for nested datasets)
1035
+ - Validation finds duplicate members (when validate=True)
1036
+
1037
+ See Also:
1038
+ :meth:`DerivaML.add_dataset_element_type`: Register a table as a dataset element type.
1039
+ :meth:`DerivaML.list_dataset_element_types`: List registered dataset element types.
1040
+
1041
+ Examples:
1042
+ Using a list of RIDs (simpler):
1043
+ >>> dataset.add_dataset_members(
1044
+ ... members=["1-ABC", "1-DEF", "1-GHI"],
1045
+ ... description="Added sample images"
1046
+ ... )
1047
+
1048
+ Using a dict by table name (faster for large datasets):
1049
+ >>> dataset.add_dataset_members(
1050
+ ... members={
1051
+ ... "Image": ["1-ABC", "1-DEF"],
1052
+ ... "Subject": ["2-XYZ"]
1053
+ ... },
1054
+ ... description="Added images and subjects"
1055
+ ... )
625
1056
  """
626
1057
  description = description or "Updated dataset via add_dataset_members"
627
1058
 
@@ -635,410 +1066,535 @@ class Dataset:
635
1066
  Returns:
636
1067
 
637
1068
  """
638
- path = path or set(dataset_rid)
1069
+ path = path or {self.dataset_rid}
639
1070
  return member_rid in path
640
1071
 
641
1072
  if validate:
642
- existing_rids = set(m["RID"] for ms in self.list_dataset_members(dataset_rid).values() for m in ms)
1073
+ existing_rids = set(m["RID"] for ms in self.list_dataset_members().values() for m in ms)
643
1074
  if overlap := set(existing_rids).intersection(members):
644
- raise DerivaMLException(f"Attempting to add existing member to dataset_table {dataset_rid}: {overlap}")
1075
+ raise DerivaMLException(
1076
+ f"Attempting to add existing member to dataset_table {self.dataset_rid}: {overlap}"
1077
+ )
645
1078
 
646
1079
  # Now go through every rid to be added to the data set and sort them based on what association table entries
647
1080
  # need to be made.
648
- dataset_elements = {}
649
- association_map = {
650
- a.other_fkeys.pop().pk_table.name: a.table.name for a in self._dataset_table.find_associations()
651
- }
1081
+ dataset_elements: dict[str, list[RID]] = {}
1082
+
1083
+ # Build map of valid element tables to their association tables
1084
+ associations = list(self._dataset_table.find_associations())
1085
+ association_map = {a.other_fkeys.pop().pk_table.name: a.table.name for a in associations}
652
1086
 
653
1087
  # Get a list of all the object types that can be linked to a dataset_table.
654
1088
  if type(members) is list:
655
1089
  members = set(members)
656
- for m in members:
657
- try:
658
- rid_info = self._model.catalog.resolve_rid(m)
659
- except KeyError:
660
- raise DerivaMLException(f"Invalid RID: {m}")
661
- if rid_info.table.name not in association_map:
662
- raise DerivaMLException(f"RID table: {rid_info.table.name} not part of dataset_table")
1090
+
1091
+ # Get candidate tables for batch resolution (only tables that can be dataset elements)
1092
+ candidate_tables = [
1093
+ self._ml_instance.model.name_to_table(table_name) for table_name in association_map.keys()
1094
+ ]
1095
+
1096
+ # Batch resolve all RIDs at once instead of one-by-one
1097
+ rid_results = self._ml_instance.resolve_rids(members, candidate_tables=candidate_tables)
1098
+
1099
+ # Group by table and validate
1100
+ for rid, rid_info in rid_results.items():
1101
+ if rid_info.table_name not in association_map:
1102
+ raise DerivaMLException(f"RID table: {rid_info.table_name} not part of dataset_table")
663
1103
  if rid_info.table == self._dataset_table and check_dataset_cycle(rid_info.rid):
664
1104
  raise DerivaMLException("Creating cycle of datasets is not allowed")
665
- dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
1105
+ dataset_elements.setdefault(rid_info.table_name, []).append(rid_info.rid)
666
1106
  else:
667
- dataset_elements = {t: set(ms) for t, ms in members.items()}
1107
+ dataset_elements = {t: list(set(ms)) for t, ms in members.items()}
668
1108
  # Now make the entries into the association tables.
669
- pb = self._model.catalog.getPathBuilder()
1109
+ pb = self._ml_instance.pathBuilder()
670
1110
  for table, elements in dataset_elements.items():
671
- schema_path = pb.schemas[
672
- self._ml_schema if (table == "Dataset" or table == "File") else self._model.domain_schema
673
- ]
1111
+ # Determine schema: ML schema for Dataset/File, otherwise use the table's actual schema
1112
+ if table == "Dataset" or table == "File":
1113
+ schema_name = self._ml_instance.ml_schema
1114
+ else:
1115
+ # Find the table and use its schema
1116
+ table_obj = self._ml_instance.model.name_to_table(table)
1117
+ schema_name = table_obj.schema.name
1118
+ schema_path = pb.schemas[schema_name]
674
1119
  fk_column = "Nested_Dataset" if table == "Dataset" else table
675
1120
  if len(elements):
676
1121
  # Find out the name of the column in the association table.
677
1122
  schema_path.tables[association_map[table]].insert(
678
- [{"Dataset": dataset_rid, fk_column: e} for e in elements]
1123
+ [{"Dataset": self.dataset_rid, fk_column: e} for e in elements]
679
1124
  )
680
1125
  self.increment_dataset_version(
681
- dataset_rid,
682
1126
  VersionPart.minor,
683
1127
  description=description,
684
1128
  execution_rid=execution_rid,
685
1129
  )
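A sketch of the register-then-add flow referenced in the docstring's Important note, assuming an `ml` DerivaML instance whose domain schema has an Image table; the add_dataset_element_type signature is inferred from the docstring reference and the RIDs are placeholders.

    ml.add_dataset_element_type("Image")  # one-time registration of the table

    dataset.add_dataset_members(
        members={"Image": ["1-ABC", "1-DEF"]},  # dict form: no RID resolution needed
        description="Initial image cohort",
    )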
686
1130
 
687
- @validate_call
1131
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
688
1132
  def delete_dataset_members(
689
1133
  self,
690
- dataset_rid: RID,
691
1134
  members: list[RID],
692
1135
  description: str = "",
693
1136
  execution_rid: RID | None = None,
694
1137
  ) -> None:
695
- """Remove elements to an existing dataset_table.
1138
+ """Remove members from this dataset.
696
1139
 
697
- Delete elements from an existing dataset. In addition to deleting members, the minor version number of the
698
- dataset is incremented and the description, if provide is applied to that new version.
1140
+ Removes the specified members from the dataset. In addition to removing members,
1141
+ the minor version number of the dataset is incremented and the description,
1142
+ if provided, is applied to that new version.
699
1143
 
700
1144
  Args:
701
- dataset_rid: RID of dataset_table to extend or None if a new dataset_table is to be created.
702
- members: List of member RIDs to add to the dataset_table.
703
- description: Markdown description of the updated dataset.
1145
+ members: List of member RIDs to remove from the dataset.
1146
+ description: Optional description of the removal operation.
704
1147
  execution_rid: Optional RID of execution associated with this operation.
705
- """
706
1148
 
1149
+ Raises:
1150
+ DerivaMLException: If any RID is invalid or not part of this dataset.
1151
+
1152
+ Example:
1153
+ >>> dataset.delete_dataset_members(
1154
+ ... members=["1-ABC", "1-DEF"],
1155
+ ... description="Removed corrupted samples"
1156
+ ... )
1157
+ """
707
1158
  members = set(members)
708
- description = description or "Deletes dataset members"
1159
+ description = description or "Deleted dataset members"
709
1160
 
710
- # Now go through every rid to be added to the data set and sort them based on what association table entries
711
- # need to be made.
1161
+ # Go through every rid to be deleted and sort them based on what association table entries
1162
+ # need to be removed.
712
1163
  dataset_elements = {}
713
1164
  association_map = {
714
1165
  a.other_fkeys.pop().pk_table.name: a.table.name for a in self._dataset_table.find_associations()
715
1166
  }
716
- # Get a list of all the object types that can be linked to a dataset_table.
1167
+ # Get a list of all the object types that can be linked to a dataset.
717
1168
  for m in members:
718
1169
  try:
719
- rid_info = self._model.catalog.resolve_rid(m)
1170
+ rid_info = self._ml_instance.resolve_rid(m)
720
1171
  except KeyError:
721
1172
  raise DerivaMLException(f"Invalid RID: {m}")
722
1173
  if rid_info.table.name not in association_map:
723
- raise DerivaMLException(f"RID table: {rid_info.table.name} not part of dataset_table")
1174
+ raise DerivaMLException(f"RID table: {rid_info.table.name} not part of dataset")
724
1175
  dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
725
- # Now make the entries into the association tables.
726
- pb = self._model.catalog.getPathBuilder()
1176
+
1177
+ # Delete the entries from the association tables.
1178
+ pb = self._ml_instance.pathBuilder()
727
1179
  for table, elements in dataset_elements.items():
728
- schema_path = pb.schemas[self._ml_schema if table == "Dataset" else self._model.domain_schema]
1180
+ # Determine schema: ML schema for Dataset, otherwise use the table's actual schema
1181
+ if table == "Dataset":
1182
+ schema_name = self._ml_instance.ml_schema
1183
+ else:
1184
+ # Find the table and use its schema
1185
+ table_obj = self._ml_instance.model.name_to_table(table)
1186
+ schema_name = table_obj.schema.name
1187
+ schema_path = pb.schemas[schema_name]
729
1188
  fk_column = "Nested_Dataset" if table == "Dataset" else table
730
1189
 
731
1190
  if len(elements):
732
1191
  atable_path = schema_path.tables[association_map[table]]
733
- # Find out the name of the column in the association table.
734
1192
  for e in elements:
735
1193
  entity = atable_path.filter(
736
- (atable_path.Dataset == dataset_rid) & (atable_path.columns[fk_column] == e),
1194
+ (atable_path.Dataset == self.dataset_rid) & (atable_path.columns[fk_column] == e),
737
1195
  )
738
1196
  entity.delete()
1197
+
739
1198
  self.increment_dataset_version(
740
- dataset_rid,
741
1199
  VersionPart.minor,
742
1200
  description=description,
743
1201
  execution_rid=execution_rid,
744
1202
  )
745
1203
 
746
- @validate_call
747
- def list_dataset_parents(self, dataset_rid: RID) -> list[str]:
1204
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
1205
+ def list_dataset_parents(
1206
+ self,
1207
+ recurse: bool = False,
1208
+ _visited: set[RID] | None = None,
1209
+ version: DatasetVersion | str | None = None,
1210
+ **kwargs: Any,
1211
+ ) -> list[Self]:
748
1212
  """Given a dataset_table RID, return a list of RIDs of the parent datasets if this is included in a
749
1213
  nested dataset.
750
1214
 
751
1215
  Args:
752
- dataset_rid: return: RID of the parent dataset_table.
1216
+ recurse: If True, recursively return all ancestor datasets.
1217
+ _visited: Internal parameter to track visited datasets and prevent infinite recursion.
1218
+ version: Dataset version to list parents from. Defaults to the current version.
1219
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
753
1220
 
754
1221
  Returns:
755
- RID of the parent dataset_table.
1222
+ List of parent datasets.
756
1223
  """
757
- if not self._is_dataset_rid(dataset_rid):
758
- raise DerivaMLException(f"RID: {dataset_rid} does not belong to dataset_table {self._dataset_table.name}")
1224
+ # Initialize visited set for recursion guard
1225
+ if _visited is None:
1226
+ _visited = set()
1227
+
1228
+ # Prevent infinite recursion by checking if we've already visited this dataset
1229
+ if self.dataset_rid in _visited:
1230
+ return []
1231
+ _visited.add(self.dataset_rid)
1232
+
759
1233
  # Get association table for nested datasets
760
- pb = self._model.catalog.getPathBuilder()
761
- atable_path = pb.schemas[self._ml_schema].Dataset_Dataset
762
- return [p["Dataset"] for p in atable_path.filter(atable_path.Nested_Dataset == dataset_rid).entities().fetch()]
1234
+ version_snapshot_catalog = self._version_snapshot_catalog(version)
1235
+ pb = version_snapshot_catalog.pathBuilder()
1236
+ atable_path = pb.schemas[self._ml_instance.ml_schema].Dataset_Dataset
1237
+ parents = [
1238
+ version_snapshot_catalog.lookup_dataset(p["Dataset"])
1239
+ for p in atable_path.filter(atable_path.Nested_Dataset == self.dataset_rid).entities().fetch()
1240
+ ]
1241
+ if recurse:
1242
+ for parent in parents.copy():
1243
+ parents.extend(parent.list_dataset_parents(recurse=True, _visited=_visited, version=version))
1244
+ return parents
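A small usage sketch, assuming a `dataset` handle; the `_visited` guard above makes the recursive walk safe even if the nesting graph is malformed.

    ancestors = dataset.list_dataset_parents(recurse=True)
    if not ancestors:
        print(f"{dataset.dataset_rid} is a top-level dataset")
    for parent in ancestors:
        print("contained in", parent.dataset_rid)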
763
1245
 
764
- @validate_call
765
- def list_dataset_children(self, dataset_rid: RID, recurse: bool = False) -> list[RID]:
1246
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
1247
+ def list_dataset_children(
1248
+ self,
1249
+ recurse: bool = False,
1250
+ _visited: set[RID] | None = None,
1251
+ version: DatasetVersion | str | None = None,
1252
+ **kwargs: Any,
1253
+ ) -> list[Self]:
766
1254
  """Given a dataset_table RID, return a list of RIDs for any nested datasets.
767
1255
 
768
1256
  Args:
769
- dataset_rid: A dataset_table RID.
770
1257
  recurse: If True, recursively include datasets nested at any depth.
1258
+ _visited: Internal parameter to track visited datasets and prevent infinite recursion.
1259
+ version: Dataset version to list children from. Defaults to the current version.
1260
+ **kwargs: Additional arguments (ignored, for protocol compatibility).
771
1261
 
772
1262
  Returns:
773
1263
  List of the nested datasets.
774
1264
 
775
1265
  """
776
- dataset_dataset_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Dataset"]
1266
+ # Initialize visited set for recursion guard
1267
+ if _visited is None:
1268
+ _visited = set()
1269
+
1270
+ version = DatasetVersion.parse(version) if isinstance(version, str) else version
1271
+ version_snapshot_catalog = self._version_snapshot_catalog(version)
1272
+ dataset_dataset_path = (
1273
+ version_snapshot_catalog.pathBuilder().schemas[self._ml_instance.ml_schema].tables["Dataset_Dataset"]
1274
+ )
777
1275
  nested_datasets = list(dataset_dataset_path.entities().fetch())
778
1276
 
779
- def find_children(rid: RID):
1277
+ def find_children(rid: RID) -> list[RID]:
1278
+ # Prevent infinite recursion by checking if we've already visited this dataset
1279
+ if rid in _visited:
1280
+ return []
1281
+ _visited.add(rid)
1282
+
780
1283
  children = [child["Nested_Dataset"] for child in nested_datasets if child["Dataset"] == rid]
781
1284
  if recurse:
782
1285
  for child in children.copy():
783
1286
  children.extend(find_children(child))
784
1287
  return children
785
1288
 
786
- return find_children(dataset_rid)
1289
+ return [version_snapshot_catalog.lookup_dataset(rid) for rid in find_children(self.dataset_rid)]
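A small usage sketch for the recursive form, assuming a `dataset` handle; the version string is a placeholder and pins the listing to that version's catalog snapshot.

    children = dataset.list_dataset_children(recurse=True, version="1.2.0")
    print(f"{dataset.dataset_rid} contains {len(children)} nested dataset(s)")
    for child in children:
        print("  nested:", child.dataset_rid)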
787
1290
 
788
- def _export_vocabulary(self, writer: Callable[[str, str, Table], list[dict[str, Any]]]) -> list[dict[str, Any]]:
789
- """
1291
+ def _list_dataset_parents_current(self) -> list[Self]:
1292
+ """Return parent datasets using current catalog state (not version snapshot).
790
1293
 
791
- Args:
792
- writer: Callable[[list[Table]]: list[dict[str: Any]]]:
1294
+ Used by _build_dataset_graph_1 to find all related datasets for version updates.
1295
+ """
1296
+ pb = self._ml_instance.pathBuilder()
1297
+ atable_path = pb.schemas[self._ml_instance.ml_schema].Dataset_Dataset
1298
+ return [
1299
+ self._ml_instance.lookup_dataset(p["Dataset"])
1300
+ for p in atable_path.filter(atable_path.Nested_Dataset == self.dataset_rid).entities().fetch()
1301
+ ]
793
1302
 
794
- Returns:
1303
+ def _list_dataset_children_current(self) -> list[Self]:
1304
+ """Return child datasets using current catalog state (not version snapshot).
795
1305
 
1306
+ Used by _build_dataset_graph_1 to find all related datasets for version updates.
796
1307
  """
797
- vocabs = [
798
- table
799
- for s in self._model.schemas.values()
800
- for table in s.tables.values()
801
- if self._model.is_vocabulary(table)
802
- ]
803
- return [o for table in vocabs for o in writer(f"{table.schema.name}:{table.name}", table.name, table)]
1308
+ dataset_dataset_path = (
1309
+ self._ml_instance.pathBuilder().schemas[self._ml_instance.ml_schema].tables["Dataset_Dataset"]
1310
+ )
1311
+ nested_datasets = list(dataset_dataset_path.entities().fetch())
804
1312
 
805
- def _table_paths(
806
- self,
807
- dataset: DatasetSpec | None = None,
808
- snapshot_catalog: DerivaML | None = None,
809
- ) -> Iterator[tuple[str, str, Table]]:
810
- paths = self._collect_paths(dataset and dataset.rid, snapshot_catalog)
811
-
812
- def source_path(path: tuple[Table, ...]) -> list[str]:
813
- """Convert a tuple representing a path into a source path component with FK linkage"""
814
- path = list(path)
815
- p = [f"{self._model.ml_schema}:Dataset/RID={{RID}}"]
816
- for table in path[1:]:
817
- if table.name == "Dataset_Dataset":
818
- p.append("(RID)=(deriva-ml:Dataset_Dataset:Dataset)")
819
- elif table.name == "Dataset":
820
- p.append("(Nested_Dataset)=(deriva-ml:Dataset:RID)")
821
- elif table.name == "Dataset_Version":
822
- p.append(f"(RID)=({self._model.ml_schema}:Dataset_Version:Dataset)")
823
- else:
824
- p.append(f"{table.schema.name}:{table.name}")
825
- return p
826
-
827
- src_paths = ["/".join(source_path(p)) for p in paths]
828
- dest_paths = ["/".join([t.name for t in p]) for p in paths]
829
- target_tables = [p[-1] for p in paths]
830
- return zip(src_paths, dest_paths, target_tables)
831
-
832
- def _collect_paths(
833
- self,
834
- dataset_rid: RID | None = None,
835
- snapshot: Dataset | None = None,
836
- dataset_nesting_depth: int | None = None,
837
- ) -> set[tuple[Table, ...]]:
838
- snapshot_catalog = snapshot if snapshot else self
839
-
840
- dataset_table = snapshot_catalog._model.schemas[self._ml_schema].tables["Dataset"]
841
- dataset_dataset = snapshot_catalog._model.schemas[self._ml_schema].tables["Dataset_Dataset"]
842
-
843
- # Figure out what types of elements the dataset contains.
844
- dataset_associations = [
845
- a
846
- for a in self._dataset_table.find_associations()
847
- if a.table.schema.name != self._ml_schema or a.table.name == "Dataset_Dataset"
848
- ]
849
- if dataset_rid:
850
- # Get a list of the members of the dataset so we can figure out which tables to query.
851
- dataset_elements = [
852
- snapshot_catalog._model.name_to_table(e)
853
- for e, m in snapshot_catalog.list_dataset_members(
854
- dataset_rid=dataset_rid, # limit=1 Limit seems to make things run slow.
855
- ).items()
856
- if m
857
- ]
858
- included_associations = [
859
- a.table for a in dataset_table.find_associations() if a.other_fkeys.pop().pk_table in dataset_elements
860
- ]
861
- else:
862
- included_associations = dataset_associations
863
-
864
- # Get the paths through the schema and filter out all the dataset paths not used by this dataset.
865
- paths = {
866
- tuple(p)
867
- for p in snapshot_catalog._model._schema_to_paths()
868
- if (len(p) == 1)
869
- or (p[1] not in dataset_associations) # Tables in the domain schema
870
- or (p[1] in included_associations) # Tables that include members of the dataset
871
- }
872
- # Now get paths for nested datasets
873
- nested_paths = set()
874
- if dataset_rid:
875
- for c in snapshot_catalog.list_dataset_children(dataset_rid=dataset_rid):
876
- nested_paths |= self._collect_paths(c, snapshot=snapshot_catalog)
877
- else:
878
- # Initialize nesting depth if not already provided.
879
- dataset_nesting_depth = (
880
- self._dataset_nesting_depth() if dataset_nesting_depth is None else dataset_nesting_depth
881
- )
882
- if dataset_nesting_depth:
883
- nested_paths = self._collect_paths(dataset_nesting_depth=dataset_nesting_depth - 1)
884
- if nested_paths:
885
- paths |= {
886
- tuple([dataset_table]),
887
- (dataset_table, dataset_dataset),
888
- }
889
- paths |= {(self._dataset_table, dataset_dataset) + p for p in nested_paths}
890
- return paths
891
-
892
- def _dataset_nesting_depth(self, dataset_rid: RID | None = None) -> int:
893
- """Determine the maximum dataset nesting depth in the current catalog.
1313
+ def find_children(rid: RID) -> list[RID]:
1314
+ return [child["Nested_Dataset"] for child in nested_datasets if child["Dataset"] == rid]
1315
+
1316
+ return [self._ml_instance.lookup_dataset(rid) for rid in find_children(self.dataset_rid)]
1317
+
1318
+ def list_executions(self) -> list["Execution"]:
1319
+ """List all executions associated with this dataset.
1320
+
1321
+ Returns all executions that used this dataset as input. This is
1322
+ tracked through the Dataset_Execution association table.
894
1323
 
895
1324
  Returns:
1325
+ List of Execution objects associated with this dataset.
896
1326
 
1327
+ Example:
1328
+ >>> dataset = ml.lookup_dataset("1-abc123")
1329
+ >>> executions = dataset.list_executions()
1330
+ >>> for exe in executions:
1331
+ ... print(f"Execution {exe.execution_rid}: {exe.status}")
897
1332
  """
1333
+ # Import here to avoid circular dependency
1334
+ from deriva_ml.execution.execution import Execution
898
1335
 
899
- def children_depth(dataset_rid: RID, nested_datasets: dict[str, list[str]]) -> int:
900
- """Return the number of nested datasets for the dataset_rid if provided, otherwise in the current catalog"""
901
- try:
902
- children = nested_datasets[dataset_rid]
903
- return max(map(lambda x: children_depth(x, nested_datasets), children)) + 1 if children else 1
904
- except KeyError:
905
- return 0
1336
+ pb = self._ml_instance.pathBuilder()
1337
+ dataset_execution_path = pb.schemas[self._ml_instance.ml_schema].Dataset_Execution
1338
+
1339
+ # Query for all executions associated with this dataset
1340
+ records = list(
1341
+ dataset_execution_path.filter(dataset_execution_path.Dataset == self.dataset_rid)
1342
+ .entities()
1343
+ .fetch()
1344
+ )
1345
+
1346
+ return [self._ml_instance.lookup_execution(record["Execution"]) for record in records]
1347
+
1348
+ @staticmethod
1349
+ def _insert_dataset_versions(
1350
+ ml_instance: DerivaMLCatalog,
1351
+ dataset_list: list[DatasetSpec],
1352
+ description: str | None = "",
1353
+ execution_rid: RID | None = None,
1354
+ ) -> None:
1355
+ """Insert new version records for a list of datasets.
1356
+
1357
+ This internal method creates Dataset_Version records in the catalog for
1358
+ each dataset in the list. It also captures a catalog snapshot timestamp
1359
+ to associate with these versions.
1360
+
1361
+ The version record links:
1362
+ - The dataset RID to its new version number
1363
+ - An optional description of what changed
1364
+ - An optional execution that triggered the version change
1365
+ - The catalog snapshot time for reproducibility
1366
+
1367
+ Args:
1368
+ ml_instance: The catalog instance to insert versions into.
1369
+ dataset_list: List of DatasetSpec objects containing RID and version info.
1370
+ description: Optional description of the version change.
1371
+ execution_rid: Optional execution RID to associate with the version.
1372
+ """
1373
+ schema_path = ml_instance.pathBuilder().schemas[ml_instance.ml_schema]
906
1374
 
907
- # Build up the dataset_table nesting graph...
908
- pb = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Dataset"]
909
- dataset_children = (
1375
+ # Insert version records for all datasets in the list
1376
+ version_records = schema_path.tables["Dataset_Version"].insert(
910
1377
  [
911
1378
  {
912
- "Dataset": dataset_rid,
913
- "Nested_Dataset": c,
914
- } # Make uniform with return from datapath
915
- for c in self.list_dataset_children(dataset_rid=dataset_rid)
1379
+ "Dataset": dataset.rid,
1380
+ "Version": str(dataset.version),
1381
+ "Description": description,
1382
+ "Execution": execution_rid,
1383
+ }
1384
+ for dataset in dataset_list
916
1385
  ]
917
- if dataset_rid
918
- else pb.entities().fetch()
919
1386
  )
920
- nested_dataset = defaultdict(list)
921
- for ds in dataset_children:
922
- nested_dataset[ds["Dataset"]].append(ds["Nested_Dataset"])
923
- return max(map(lambda d: children_depth(d, dict(nested_dataset)), nested_dataset)) if nested_dataset else 0
1387
+ version_records = list(version_records)
924
1388
 
925
- def _dataset_specification(
926
- self,
927
- writer: Callable[[str, str, Table], list[dict[str, Any]]],
928
- dataset: DatasetSpec | None = None,
929
- snapshot_catalog: DerivaML | None = None,
930
- ) -> list[dict[str, Any]]:
931
- """Output a download/export specification for a dataset_table. Each element of the dataset_table
932
- will be placed in its own directory.
933
- The top level data directory of the resulting BDBag will have one subdirectory for element type.
934
- The subdirectory will contain the CSV indicating which elements of that type are present in the
935
- dataset_table, and then there will be a subdirectory for each object that is reachable from the
936
- dataset_table members.
937
-
938
- To simplify reconstructing the relationship between tables, the CVS for each element is included.
939
- The top level data directory will also contain a subdirectory for any controlled vocabularies used in
940
- the dataset_table. All assets will be placed into a directory named asset in a subdirectory with the
941
- asset table name.
942
-
943
- For example, consider a dataset_table that consists of two element types, T1 and T2. T1 has foreign
944
- key relationships to objects in tables T3 and T4. There are also two controlled vocabularies, CV1 and
945
- CV2. T2 is an asset table which has two assets in it. The layout of the resulting bdbag would be:
946
- data
947
- CV1/
948
- cv1.csv
949
- CV2/
950
- cv2.csv
951
- Dataset/
952
- T1/
953
- t1.csv
954
- T3/
955
- t3.csv
956
- T4/
957
- t4.csv
958
- T2/
959
- t2.csv
960
- asset/
961
- T2
962
- f1
963
- f2
1389
+ # Capture the current catalog snapshot timestamp. This allows us to
1390
+ # recreate the exact state of the catalog when this version was created.
1391
+ snap = ml_instance.catalog.get("/").json()["snaptime"]
964
1392
 
965
- Args:
966
- writer: Callable[[list[Table]]: list[dict[str: Any]]]:
1393
+ # Update version records with the snapshot timestamp
1394
+ schema_path.tables["Dataset_Version"].update(
1395
+ [{"RID": v["RID"], "Dataset": v["Dataset"], "Snapshot": snap} for v in version_records]
1396
+ )
967
1397
 
968
- Returns:
969
- A dataset_table specification.
970
- """
971
- element_spec = self._export_vocabulary(writer)
972
- for path in self._table_paths(dataset=dataset, snapshot_catalog=snapshot_catalog):
973
- element_spec.extend(writer(*path))
974
- return element_spec
1398
+ # Update each dataset's current version pointer to the new version record
1399
+ schema_path.tables["Dataset"].update([{"Version": v["RID"], "RID": v["Dataset"]} for v in version_records])
975
1400
 
976
- def _download_dataset_bag(
1401
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
1402
+ def download_dataset_bag(
977
1403
  self,
978
- dataset: DatasetSpec,
979
- execution_rid: RID | None = None,
980
- snapshot_catalog: DerivaML | None = None,
1404
+ version: DatasetVersion | str,
1405
+ materialize: bool = True,
1406
+ use_minid: bool = False,
981
1407
  ) -> DatasetBag:
982
- """Download a dataset onto the local file system. Create a MINID for the dataset if one doesn't already exist.
1408
+ """Downloads a dataset to the local filesystem and optionally creates a MINID.
1409
+
1410
+ Downloads a dataset to the local file system. If the dataset has a version set, that version is used.
1411
+ If the dataset has a version and a version is provided, the version specified takes precedence.
983
1412
 
984
1413
  Args:
985
- dataset: Specification of the dataset to be downloaded.
986
- execution_rid: Execution RID for the dataset.
987
- snapshot_catalog: Snapshot catalog for the dataset version if specified.
1414
+ version: Dataset version to download, given as a DatasetVersion or a version string (e.g., "1.0.0").
1415
+ materialize: If True, materialize the dataset after downloading.
1416
+ use_minid: If True, upload the bag to S3 and create a MINID for the dataset.
1417
+ Requires s3_bucket to be configured on the catalog. Defaults to False.
988
1418
 
989
1419
  Returns:
990
- Tuple consisting of the path to the dataset, the RID of the dataset that was downloaded and the MINID
991
- for the dataset.
1420
+ DatasetBag: Object containing:
1421
+ - path: Local filesystem path to downloaded dataset
1422
+ - rid: Dataset's Resource Identifier
1423
+ - minid: Dataset's Minimal Viable Identifier (if use_minid=True)
1424
+
1425
+ Raises:
1426
+ DerivaMLException: If use_minid=True but s3_bucket is not configured on the catalog.
1427
+
1428
+ Examples:
1429
+ Download without MINID (default):
1430
+ >>> bag = dataset.download_dataset_bag(version="1.0.0")
1431
+ >>> print(f"Downloaded to {bag.path}")
1432
+
1433
+ Download with MINID (requires s3_bucket configured):
1434
+ >>> # Catalog must be created with s3_bucket="s3://my-bucket"
1435
+ >>> bag = dataset.download_dataset_bag(version="1.0.0", use_minid=True)
992
1436
  """
993
- if (
994
- execution_rid
995
- and execution_rid != DRY_RUN_RID
996
- and self._model.catalog.resolve_rid(execution_rid).table.name != "Execution"
997
- ):
998
- raise DerivaMLException(f"RID {execution_rid} is not an execution")
999
- minid = self._get_dataset_minid(dataset, snapshot_catalog=snapshot_catalog)
1437
+ if isinstance(version, str):
1438
+ version = DatasetVersion.parse(version)
1439
+
1440
+ # Validate use_minid requires s3_bucket configuration
1441
+ if use_minid and not self._ml_instance.s3_bucket:
1442
+ raise DerivaMLException(
1443
+ "Cannot use use_minid=True without s3_bucket configured. "
1444
+ "Configure s3_bucket when creating the DerivaML instance to enable MINID support."
1445
+ )
1446
+
1447
+ minid = self._get_dataset_minid(version, create=True, use_minid=use_minid)
1000
1448
 
1001
1449
  bag_path = (
1002
- self._materialize_dataset_bag(minid, execution_rid=execution_rid)
1003
- if dataset.materialize
1004
- else self._download_dataset_minid(minid)
1450
+ self._materialize_dataset_bag(minid, use_minid=use_minid)
1451
+ if materialize
1452
+ else self._download_dataset_minid(minid, use_minid)
1005
1453
  )
1006
- return DatabaseModel(minid, bag_path, self._working_dir).get_dataset()
1454
+ from deriva_ml.model.deriva_ml_database import DerivaMLDatabase
1455
+ db_model = DatabaseModel(minid, bag_path, self._ml_instance.working_dir)
1456
+ return DerivaMLDatabase(db_model).lookup_dataset(self.dataset_rid)
1457
+
1458
+ def _version_snapshot_catalog(self, dataset_version: DatasetVersion | str | None) -> DerivaMLCatalog:
1459
+ """Get a catalog instance bound to a specific version's snapshot.
1460
+
1461
+ Dataset versions are associated with catalog snapshots, which represent
1462
+ the exact state of the catalog at the time the version was created.
1463
+ This method returns a catalog instance that queries against that snapshot,
1464
+ ensuring reproducible access to historical data.
1465
+
1466
+ Args:
1467
+ dataset_version: The version to get a snapshot for, or None to use
1468
+ the current catalog state.
1469
+
1470
+ Returns:
1471
+ DerivaMLCatalog: Either a snapshot-bound catalog or the current catalog.
1472
+ """
1473
+ if isinstance(dataset_version, str) and dataset_version:
1474
+ dataset_version = DatasetVersion.parse(dataset_version)
1475
+ if dataset_version:
1476
+ return self._ml_instance.catalog_snapshot(self._version_snapshot_catalog_id(dataset_version))
1477
+ else:
1478
+ return self._ml_instance
1479
+
1480
+ def _version_snapshot_catalog_id(self, version: DatasetVersion | str) -> str:
1481
+ """Get the catalog ID with snapshot suffix for a specific version.
1482
+
1483
+ Constructs a catalog identifier in the format "catalog_id@snapshot_time"
1484
+ that can be used to access the catalog state at the time the version
1485
+ was created.
1486
+
1487
+ Args:
1488
+ version: The dataset version to get the snapshot for.
1489
+
1490
+ Returns:
1491
+ str: Catalog ID with snapshot suffix (e.g., "1@2023-01-15T10:30:00").
1007
1492
 
1008
- def _version_snapshot(self, dataset: DatasetSpec) -> str:
1009
- """Return a catalog with snapshot for the specified dataset version"""
1493
+ Raises:
1494
+ DerivaMLException: If the specified version doesn't exist.
1495
+ """
1496
+ version = str(version)
1010
1497
  try:
1011
- version_record = next(
1012
- h for h in self.dataset_history(dataset_rid=dataset.rid) if h.dataset_version == dataset.version
1013
- )
1498
+ version_record = next(h for h in self.dataset_history() if h.dataset_version == version)
1014
1499
  except StopIteration:
1015
- raise DerivaMLException(f"Dataset version {dataset.version} not found for dataset {dataset.rid}")
1016
- return f"{self._model.catalog.catalog_id}@{version_record.snapshot}"
1500
+ raise DerivaMLException(f"Dataset version {version} not found for dataset {self.dataset_rid}")
1501
+ return (
1502
+ f"{self._ml_instance.catalog.catalog_id}@{version_record.snapshot}"
1503
+ if version_record.snapshot
1504
+ else self._ml_instance.catalog.catalog_id
1505
+ )
1506
+
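A sketch of how the snapshot-qualified identifier is assembled and split apart, using placeholder values in the format shown in the docstring.

    catalog_id, snapshot = "1", "2023-01-15T10:30:00"
    snapshot_id = f"{catalog_id}@{snapshot}" if snapshot else catalog_id  # "1@2023-01-15T10:30:00"
    base_id, _, snaptime = snapshot_id.partition("@")  # back to ("1", "2023-01-15T10:30:00")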
1507
+ def _download_dataset_minid(self, minid: DatasetMinid, use_minid: bool) -> Path:
1508
+ """Download and extract a dataset bag from a MINID or direct URL.
1509
+
1510
+ This method handles the download of a BDBag archive, either from S3 storage
1511
+ (if using MINIDs) or directly from the catalog server. Downloaded bags are
1512
+ cached by checksum to avoid redundant downloads.
1513
+
1514
+ Args:
1515
+ minid: DatasetMinid containing the bag URL and metadata.
1516
+ use_minid: If True, download from S3 using the MINID URL.
1517
+ If False, download directly from the catalog server.
1518
+
1519
+ Returns:
1520
+ Path: The path to the extracted and validated bag directory.
1521
+
1522
+ Note:
1523
+ Bags are cached in the cache_dir with the naming convention:
1524
+ "{dataset_rid}_{checksum}/Dataset_{dataset_rid}"
1525
+ """
1526
+
1527
+ # Check to see if we have an existing idempotent materialization of the desired bag. If so, then reuse
1528
+ # it. If not, then we need to extract the contents of the archive into our cache directory.
1529
+ bag_dir = self._ml_instance.cache_dir / f"{minid.dataset_rid}_{minid.checksum}"
1530
+ if bag_dir.exists():
1531
+ self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
1532
+ return Path(bag_dir / f"Dataset_{minid.dataset_rid}")
1533
+
1534
+ # Either bag hasn't been downloaded yet, or we are not using a Minid, so we don't know the checksum yet.
1535
+ with TemporaryDirectory() as tmp_dir:
1536
+ if use_minid:
1537
+ # Get bag from S3
1538
+ bag_path = Path(tmp_dir) / Path(urlparse(minid.bag_url).path).name
1539
+ archive_path = fetch_single_file(minid.bag_url, output_path=bag_path)
1540
+ else:
1541
+ exporter = DerivaExport(host=self._ml_instance.catalog.deriva_server.server, output_dir=tmp_dir)
1542
+ archive_path = exporter.retrieve_file(minid.bag_url)
1543
+ hashes = hash_utils.compute_file_hashes(archive_path, hashes=["md5", "sha256"])
1544
+ checksum = hashes["sha256"][0]
1545
+ bag_dir = self._ml_instance.cache_dir / f"{minid.dataset_rid}_{checksum}"
1546
+ if bag_dir.exists():
1547
+ self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
1548
+ return Path(bag_dir / f"Dataset_{minid.dataset_rid}")
1549
+ bag_path = bdb.extract_bag(archive_path, bag_dir.as_posix())
1550
+ bdb.validate_bag_structure(bag_path)
1551
+ return Path(bag_path)
1552
+
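A sketch of the checksum-keyed cache layout described in the Note above, with a placeholder cache location, RID, and checksum.

    from pathlib import Path

    cache_dir = Path("~/.deriva-ml/cache").expanduser()  # placeholder location
    dataset_rid, checksum = "1-ABC", "deadbeef"  # placeholder sha256 of the bag archive
    bag_dir = cache_dir / f"{dataset_rid}_{checksum}"
    extracted_bag = bag_dir / f"Dataset_{dataset_rid}"
    if extracted_bag.exists():
        print("reusing cached bag at", extracted_bag)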
1553
+ def _create_dataset_minid(self, version: DatasetVersion, use_minid: bool = True) -> str:
1554
+ """Create a new MINID (Minimal Viable Identifier) for the dataset.
1555
+
1556
+ This method generates a BDBag export of the dataset and optionally
1557
+ registers it with a MINID service for persistent identification.
1558
+ The bag is uploaded to S3 storage when using MINIDs.
1559
+
1560
+ Args:
1561
+ version: The dataset version to create a MINID for.
1562
+ use_minid: If True, register with MINID service and upload to S3.
1563
+ If False, just generate the bag and return a local URL.
1017
1564
 
1018
- def _create_dataset_minid(self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None) -> str:
1565
+ Returns:
1566
+ str: URL to the MINID landing page (if use_minid=True) or
1567
+ the direct bag download URL.
1568
+ """
1019
1569
  with TemporaryDirectory() as tmp_dir:
1020
1570
  # Generate a download specification file for the current catalog schema. By default, this spec
1021
1571
  # will generate a minid and place the bag into S3 storage.
1022
1572
  spec_file = Path(tmp_dir) / "download_spec.json"
1573
+ version_snapshot_catalog = self._version_snapshot_catalog(version)
1023
1574
  with spec_file.open("w", encoding="utf-8") as ds:
1024
- json.dump(self._generate_dataset_download_spec(dataset, snapshot_catalog), ds)
1575
+ downloader = CatalogGraph(
1576
+ version_snapshot_catalog,
1577
+ s3_bucket=self._ml_instance.s3_bucket,
1578
+ use_minid=use_minid,
1579
+ )
1580
+ json.dump(downloader.generate_dataset_download_spec(self), ds)
1025
1581
  try:
1026
1582
  self._logger.info(
1027
1583
  "Downloading dataset %s for catalog: %s@%s"
1028
1584
  % (
1029
- "minid" if self._use_minid else "bag",
1030
- dataset.rid,
1031
- str(dataset.version),
1585
+ "minid" if use_minid else "bag",
1586
+ self.dataset_rid,
1587
+ str(version),
1032
1588
  )
1033
1589
  )
1034
1590
  # Generate the bag and put into S3 storage.
1035
1591
  exporter = DerivaExport(
1036
- host=self._model.catalog.deriva_server.server,
1592
+ host=self._ml_instance.catalog.deriva_server.server,
1037
1593
  config_file=spec_file,
1038
1594
  output_dir=tmp_dir,
1039
1595
  defer_download=True,
1040
1596
  timeout=(10, 610),
1041
- envars={"RID": dataset.rid},
1597
+ envars={"RID": self.dataset_rid},
1042
1598
  )
1043
1599
  minid_page_url = exporter.export()[0] # Get the MINID launch page
1044
1600
  except (
@@ -1050,131 +1606,117 @@ class Dataset:
1050
1606
  ) as e:
1051
1607
  raise DerivaMLException(format_exception(e))
1052
1608
  # Update version table with MINID.
1053
- if self._use_minid:
1054
- version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
1055
- version_rid = [
1056
- h for h in self.dataset_history(dataset_rid=dataset.rid) if h.dataset_version == dataset.version
1057
- ][0].version_rid
1609
+ if use_minid:
1610
+ version_path = (
1611
+ self._ml_instance.pathBuilder().schemas[self._ml_instance.ml_schema].tables["Dataset_Version"]
1612
+ )
1613
+ version_rid = [h for h in self.dataset_history() if h.dataset_version == version][0].version_rid
1058
1614
  version_path.update([{"RID": version_rid, "Minid": minid_page_url}])
1059
1615
  return minid_page_url
1060
1616
 
1061
1617
  def _get_dataset_minid(
1062
1618
  self,
1063
- dataset: DatasetSpec,
1064
- snapshot_catalog: DerivaML | None = None,
1065
- create: bool = True,
1619
+ version: DatasetVersion,
1620
+ create: bool,
1621
+ use_minid: bool,
1066
1622
  ) -> DatasetMinid | None:
1067
- """Return a MINID for the specified dataset. If no version is specified, use the latest.
1623
+ """Get or create a MINID for the specified dataset version.
1624
+
1625
+ This method retrieves the MINID associated with a specific dataset version,
1626
+ optionally creating one if it doesn't exist.
1068
1627
 
1069
1628
  Args:
1070
- dataset: Specification of the dataset.
1071
- snapshot_catalog: Snapshot catalog for the dataset version if specified.
1072
- create: Create a new MINID if one doesn't already exist.
1629
+ version: The dataset version to get the MINID for.
1630
+ create: If True, create a new MINID if one doesn't already exist.
1631
+ If False, raise an exception if no MINID exists.
1632
+ use_minid: If True, use the MINID service for persistent identification.
1633
+ If False, generate a direct download URL without MINID registration.
1073
1634
 
1074
1635
  Returns:
1075
- New or existing MINID for the dataset.
1076
- """
1077
- rid = dataset.rid
1078
-
1079
- # Case 1: RID is already a MINID or direct URL
1080
- if rid.startswith("minid"):
1081
- return self._fetch_minid_metadata(f"https://identifiers.org/{rid}", dataset.version)
1082
- if rid.startswith("http"):
1083
- return self._fetch_minid_metadata(rid, dataset.version)
1636
+ DatasetMinid: Object containing the MINID URL, checksum, and metadata.
1084
1637
 
1085
- # Case 2: RID is a dataset RID – validate existence
1086
- if not any(rid == ds["RID"] for ds in self.find_datasets()):
1087
- raise DerivaMLTableTypeError("Dataset", rid)
1638
+ Raises:
1639
+ DerivaMLException: If the version doesn't exist, or if create=False
1640
+ and no MINID exists.
1641
+ """
1088
1642
 
1089
1643
  # Find dataset version record
1090
- version_str = str(dataset.version)
1091
- history = self.dataset_history(rid)
1644
+ version_str = str(version)
1645
+ history = self.dataset_history()
1092
1646
  try:
1093
1647
  version_record = next(v for v in history if v.dataset_version == version_str)
1094
1648
  except StopIteration:
1095
- raise DerivaMLException(f"Version {version_str} does not exist for RID {rid}")
1649
+ raise DerivaMLException(f"Version {version_str} does not exist for RID {self.dataset_rid}")
1096
1650
 
1097
1651
  # Check or create MINID
1098
1652
  minid_url = version_record.minid
1099
1653
  # If we either don't have a MINID, or we have a MINID, but we don't want to use it, generate a new one.
1100
- if (not minid_url) or (not self._use_minid):
1654
+ if (not minid_url) or (not use_minid):
1101
1655
  if not create:
1102
- raise DerivaMLException(f"Minid for dataset {rid} doesn't exist")
1103
- if self._use_minid:
1104
- self._logger.info("Creating new MINID for dataset %s", rid)
1105
- minid_url = self._create_dataset_minid(dataset, snapshot_catalog)
1656
+ raise DerivaMLException(f"Minid for dataset {self.dataset_rid} doesn't exist")
1657
+ if use_minid:
1658
+ self._logger.info("Creating new MINID for dataset %s", self.dataset_rid)
1659
+ minid_url = self._create_dataset_minid(version, use_minid=use_minid)
1106
1660
 
1107
1661
  # Return based on MINID usage
1108
- if self._use_minid:
1109
- return self._fetch_minid_metadata(minid_url, dataset.version)
1662
+ if use_minid:
1663
+ return self._fetch_minid_metadata(version, minid_url)
1110
1664
  return DatasetMinid(
1111
- dataset_version=dataset.version,
1112
- RID=f"{rid}@{version_record.snapshot}",
1665
+ dataset_version=version,
1666
+ RID=f"{self.dataset_rid}@{version_record.snapshot}",
1113
1667
  location=minid_url,
1114
1668
  )
1115
1669
 
1116
- def _fetch_minid_metadata(self, url: str, version: DatasetVersion) -> DatasetMinid:
1117
- r = requests.get(url, headers={"accept": "application/json"})
1118
- r.raise_for_status()
1119
- return DatasetMinid(dataset_version=version, **r.json())
1120
-
1121
- def _download_dataset_minid(self, minid: DatasetMinid) -> Path:
1122
- """Given a RID to a dataset_table, or a MINID to an existing bag, download the bag file, extract it, and
1123
- validate that all the metadata is correct
1670
+ def _fetch_minid_metadata(self, version: DatasetVersion, url: str) -> DatasetMinid:
1671
+ """Fetch MINID metadata from the MINID service.
1124
1672
 
1125
1673
  Args:
1126
- minid: The RID of a dataset_table or a minid to an existing bag.
1127
- Returns:
1128
- the location of the unpacked and validated dataset_table bag and the RID of the bag and the bag MINID
1129
- """
1674
+ version: The dataset version associated with this MINID.
1675
+ url: The MINID landing page URL.
1130
1676
 
1131
- # Check to see if we have an existing idempotent materialization of the desired bag. If so, then reuse
1132
- # it. If not, then we need to extract the contents of the archive into our cache directory.
1133
- bag_dir = self._cache_dir / f"{minid.dataset_rid}_{minid.checksum}"
1134
- if bag_dir.exists():
1135
- self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
1136
- return Path(bag_dir / f"Dataset_{minid.dataset_rid}")
1677
+ Returns:
1678
+ DatasetMinid: Parsed metadata including bag URL, checksum, and identifiers.
1137
1679
 
1138
- # Either bag hasn't been downloaded yet, or we are not using a Minid, so we don't know the checksum yet.
1139
- with TemporaryDirectory() as tmp_dir:
1140
- if self._use_minid:
1141
- # Get bag from S3
1142
- bag_path = Path(tmp_dir) / Path(urlparse(minid.bag_url).path).name
1143
- archive_path = fetch_single_file(minid.bag_url, output_path=bag_path)
1144
- else:
1145
- exporter = DerivaExport(host=self._model.catalog.deriva_server.server, output_dir=tmp_dir)
1146
- archive_path = exporter.retrieve_file(minid.bag_url)
1147
- hashes = hash_utils.compute_file_hashes(archive_path, hashes=["md5", "sha256"])
1148
- checksum = hashes["sha256"][0]
1149
- bag_dir = self._cache_dir / f"{minid.dataset_rid}_{checksum}"
1150
- if bag_dir.exists():
1151
- self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
1152
- return Path(bag_dir / f"Dataset_{minid.dataset_rid}")
1153
- bag_path = bdb.extract_bag(archive_path, bag_dir.as_posix())
1154
- bdb.validate_bag_structure(bag_path)
1155
- return Path(bag_path)
1680
+ Raises:
1681
+ requests.HTTPError: If the MINID service request fails.
1682
+ """
1683
+ r = requests.get(url, headers={"accept": "application/json"})
1684
+ r.raise_for_status()
1685
+ return DatasetMinid(dataset_version=version, **r.json())
1156
1686
 
1157
1687
  def _materialize_dataset_bag(
1158
1688
  self,
1159
1689
  minid: DatasetMinid,
1160
- execution_rid: RID | None = None,
1690
+ use_minid: bool,
1161
1691
  ) -> Path:
1162
- """Materialize a dataset_table bag into a local directory
1692
+ """Materialize a dataset bag by downloading all referenced files.
1693
+
1694
+ This method downloads a BDBag and then "materializes" it by fetching
1695
+ all files referenced in the bag's fetch.txt manifest. This includes
1696
+ data files, assets, and any other content referenced by the bag.
1697
+
1698
+ Progress is reported through callbacks that update the execution status
1699
+ if this download is associated with an execution.
1163
1700
 
1164
1701
  Args:
1165
- minid: A MINID to an existing bag or a RID of the dataset_table that should be downloaded.
1702
+ minid: DatasetMinid containing the bag URL and metadata.
1703
+ use_minid: If True, download from S3 using the MINID URL.
1166
1704
 
1167
1705
  Returns:
1168
- A tuple containing the path to the bag, the RID of the bag, and the MINID to the bag.
1706
+ Path: The path to the fully materialized bag directory.
1707
+
1708
+ Note:
1709
+ Materialization status is cached via a 'validated_check.txt' marker
1710
+ file to avoid re-downloading already-materialized bags.
1169
1711
  """
1170
1712
 
1171
1713
  def update_status(status: Status, msg: str) -> None:
1172
1714
  """Update the current status for this execution in the catalog"""
1173
- if execution_rid and execution_rid != DRY_RUN_RID:
1174
- self._model.catalog.getPathBuilder().schemas[self._ml_schema].Execution.update(
1715
+ if self.execution_rid and self.execution_rid != DRY_RUN_RID:
1716
+ self._ml_instance.pathBuilder().schemas[self._ml_instance.ml_schema].Execution.update(
1175
1717
  [
1176
1718
  {
1177
- "RID": execution_rid,
1719
+ "RID": self.execution_rid,
1178
1720
  "Status": status.value,
1179
1721
  "Status_Detail": msg,
1180
1722
  }
@@ -1184,18 +1726,18 @@ class Dataset:
1184
1726
 
1185
1727
  def fetch_progress_callback(current, total):
1186
1728
  msg = f"Materializing bag: {current} of {total} file(s) downloaded."
1187
- if execution_rid:
1729
+ if self.execution_rid:
1188
1730
  update_status(Status.running, msg)
1189
1731
  return True
1190
1732
 
1191
1733
  def validation_progress_callback(current, total):
1192
1734
  msg = f"Validating bag: {current} of {total} file(s) validated."
1193
- if execution_rid:
1735
+ if self.execution_rid:
1194
1736
  update_status(Status.running, msg)
1195
1737
  return True
1196
1738
 
1197
1739
  # request metadata
1198
- bag_path = self._download_dataset_minid(minid)
1740
+ bag_path = self._download_dataset_minid(minid, use_minid)
1199
1741
  bag_dir = bag_path.parent
1200
1742
  validated_check = bag_dir / "validated_check.txt"
1201
1743
 
@@ -1209,311 +1751,3 @@ class Dataset:
1209
1751
  )
1210
1752
  validated_check.touch()
1211
1753
  return Path(bag_path)
1212
-
1213
- def _export_annotation(
1214
- self,
1215
- snapshot_catalog: DerivaML | None = None,
1216
- ) -> list[dict[str, Any]]:
1217
- """Return and output specification for the datasets in the provided model
1218
-
1219
- Returns:
1220
- An export specification suitable for Chaise.
1221
- """
1222
-
1223
- # Export specification is a specification for the datasets, plus any controlled vocabulary
1224
- return [
1225
- {
1226
- "source": {"api": False, "skip_root_path": True},
1227
- "destination": {"type": "env", "params": {"query_keys": ["snaptime"]}},
1228
- },
1229
- {
1230
- "source": {"api": "entity"},
1231
- "destination": {
1232
- "type": "env",
1233
- "params": {"query_keys": ["RID", "Description"]},
1234
- },
1235
- },
1236
- {
1237
- "source": {"api": "schema", "skip_root_path": True},
1238
- "destination": {"type": "json", "name": "schema"},
1239
- },
1240
- ] + self._dataset_specification(
1241
- self._export_annotation_dataset_element,
1242
- None,
1243
- snapshot_catalog=snapshot_catalog,
1244
- )
1245
-
1246
- def _export_specification(
1247
- self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None
1248
- ) -> list[dict[str, Any]]:
1249
- """
1250
- Generate a specification for export engine for specific dataset.
1251
-
1252
- Returns:
1253
- a download specification for the datasets in the provided model.
1254
-
1255
- """
1256
-
1257
- # Download spec is the spec for any controlled vocabulary and for the dataset_table.
1258
- return [
1259
- {
1260
- "processor": "json",
1261
- "processor_params": {"query_path": "/schema", "output_path": "schema"},
1262
- }
1263
- ] + self._dataset_specification(self._export_specification_dataset_element, dataset, snapshot_catalog)
1264
-
1265
- @staticmethod
1266
- def _export_specification_dataset_element(spath: str, dpath: str, table: Table) -> list[dict[str, Any]]:
1267
- """Return the download specification for the data object indicated by a path through the data model.
1268
-
1269
- Args:
1270
- spath: Source path
1271
- dpath: Destination path
1272
- table: Table referenced to by the path
1273
-
1274
- Returns:
1275
- The download specification that will retrieve that data from the catalog and place it into a BDBag.
1276
- """
1277
- exports = [
1278
- {
1279
- "processor": "csv",
1280
- "processor_params": {
1281
- "query_path": f"/entity/{spath}",
1282
- "output_path": dpath,
1283
- },
1284
- }
1285
- ]
1286
-
1287
- # If this table is an asset table, then we need to output the files associated with the asset.
1288
- asset_columns = {"Filename", "URL", "Length", "MD5", "Description"}
1289
- if asset_columns.issubset({c.name for c in table.columns}):
1290
- exports.append(
1291
- {
1292
- "processor": "fetch",
1293
- "processor_params": {
1294
- "query_path": f"/attribute/{spath}/!(URL::null::)/url:=URL,length:=Length,filename:=Filename,md5:=MD5,asset_rid:=RID",
1295
- "output_path": "asset/{asset_rid}/" + table.name,
1296
- },
1297
- }
1298
- )
1299
- return exports
1300
-
1301
- def _export_annotation_dataset_element(self, spath: str, dpath: str, table: Table) -> list[dict[str, Any]]:
1302
- """Given a path in the data model, output an export specification for the path taken to get to the
1303
- current table.
1304
-
1305
- Args:
1306
- spath: Source path
1307
- dpath: Destination path
1308
- table: Table referenced to by the path
1309
-
1310
- Returns:
1311
- The export specification that will retrieve that data from the catalog and place it into a BDBag.
1312
- """
1313
- # The table is the last element of the path. Generate the ERMRest query by converting the list of tables
1314
- # into a path in the form of /S:T1/S:T2/S:Table
1315
- # Generate the destination path in the file system using just the table names.
1316
-
1317
- skip_root_path = False
1318
- if spath.startswith(f"{self._ml_schema}:Dataset/"):
1319
- # Chaise will add table name and RID filter, so strip it off.
1320
- spath = "/".join(spath.split("/")[2:])
1321
- if spath == "":
1322
- # This path is to just the dataset table.
1323
- return []
1324
- else:
1325
- # A vocabulary table, so we don't want the root_path.
1326
- skip_root_path = True
1327
- exports = [
1328
- {
1329
- "source": {
1330
- "api": "entity",
1331
- "path": spath,
1332
- "skip_root_path": skip_root_path,
1333
- },
1334
- "destination": {"name": dpath, "type": "csv"},
1335
- }
1336
- ]
1337
-
1338
- # If this table is an asset table, then we need to output the files associated with the asset.
1339
- asset_columns = {"Filename", "URL", "Length", "MD5", "Description"}
1340
- if asset_columns.issubset({c.name for c in table.columns}):
1341
- exports.append(
1342
- {
1343
- "source": {
1344
- "skip_root_path": False,
1345
- "api": "attribute",
1346
- "path": f"{spath}/!(URL::null::)/url:=URL,length:=Length,filename:=Filename,md5:=MD5, asset_rid:=RID",
1347
- },
1348
- "destination": {"name": "asset/{asset_rid}/" + table.name, "type": "fetch"},
1349
- }
1350
- )
1351
- return exports
1352
-
1353
-    def _generate_dataset_download_spec(
-        self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None
-    ) -> dict[str, Any]:
-        """
-        Generate a specification for downloading a specific dataset.
-
-        This routine creates a download specification that can be used by the Deriva export processor to download
-        a specific dataset as a MINID.
-        Returns:
-        """
-        s3_target = "s3://eye-ai-shared"
-        minid_test = False
-
-        catalog_id = self._version_snapshot(dataset)
-        post_processors = (
-            {
-                "post_processors": [
-                    {
-                        "processor": "cloud_upload",
-                        "processor_params": {
-                            "acl": "public-read",
-                            "target_url": s3_target,
-                        },
-                    },
-                    {
-                        "processor": "identifier",
-                        "processor_params": {
-                            "test": minid_test,
-                            "env_column_map": {
-                                "RID": "{RID}@{snaptime}",
-                                "Description": "{Description}",
-                            },
-                        },
-                    },
-                ]
-            }
-            if self._use_minid
-            else {}
-        )
-        return post_processors | {
-            "env": {"RID": "{RID}"},
-            "bag": {
-                "bag_name": "Dataset_{RID}",
-                "bag_algorithms": ["md5"],
-                "bag_archiver": "zip",
-                "bag_metadata": {},
-                "bag_idempotent": True,
-            },
-            "catalog": {
-                "host": f"{self._model.catalog.deriva_server.scheme}://{self._model.catalog.deriva_server.server}",
-                "catalog_id": catalog_id,
-                "query_processors": [
-                    {
-                        "processor": "env",
-                        "processor_params": {
-                            "output_path": "Dataset",
-                            "query_keys": ["snaptime"],
-                            "query_path": "/",
-                        },
-                    },
-                    {
-                        "processor": "env",
-                        "processor_params": {
-                            "query_path": "/entity/M:=deriva-ml:Dataset/RID={RID}",
-                            "output_path": "Dataset",
-                            "query_keys": ["RID", "Description"],
-                        },
-                    },
-                ]
-                + self._export_specification(dataset, snapshot_catalog),
-            },
-        }
-
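The removed method assembles its result with a conditional dict merge: when MINID minting is disabled the post-processor block is an empty dict, so the "|" union leaves the base specification untouched. A minimal sketch, with use_minid and the trimmed dicts standing in for self._use_minid and the full spec:

    use_minid = False  # stand-in for self._use_minid

    post_processors = (
        {"post_processors": [{"processor": "cloud_upload"}, {"processor": "identifier"}]}
        if use_minid
        else {}
    )
    spec = post_processors | {
        "env": {"RID": "{RID}"},
        "bag": {"bag_name": "Dataset_{RID}", "bag_archiver": "zip"},
    }
    print("post_processors" in spec)  # False -- no MINID/cloud-upload steps when the flag is off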
-    def _generate_dataset_download_annotations(self) -> dict[str, Any]:
-        post_processors = (
-            {
-                "type": "BAG",
-                "outputs": [{"fragment_key": "dataset_export_outputs"}],
-                "displayname": "BDBag to Cloud",
-                "bag_idempotent": True,
-                "postprocessors": [
-                    {
-                        "processor": "cloud_upload",
-                        "processor_params": {
-                            "acl": "public-read",
-                            "target_url": "s3://eye-ai-shared/",
-                        },
-                    },
-                    {
-                        "processor": "identifier",
-                        "processor_params": {
-                            "test": False,
-                            "env_column_map": {
-                                "RID": "{RID}@{snaptime}",
-                                "Description": "{Description}",
-                            },
-                        },
-                    },
-                ],
-            }
-            if self._use_minid
-            else {}
-        )
-        return {
-            deriva_tags.export_fragment_definitions: {"dataset_export_outputs": self._export_annotation()},
-            deriva_tags.visible_foreign_keys: self._dataset_visible_fkeys(),
-            deriva_tags.export_2019: {
-                "detailed": {
-                    "templates": [
-                        {
-                            "type": "BAG",
-                            "outputs": [{"fragment_key": "dataset_export_outputs"}],
-                            "displayname": "BDBag Download",
-                            "bag_idempotent": True,
-                        }
-                        | post_processors
-                    ]
-                }
-            },
-        }
-
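The same conditional-merge pattern drives the export_2019 template in the method above: the base BDBag template is always emitted, and when MINIDs are enabled the merged block layers on the post-processors and overrides the display name, since the right-hand operand of "|" wins. Roughly, with an illustrative flag and trimmed dicts:

    use_minid = True  # stand-in for self._use_minid

    post_processors = (
        {"displayname": "BDBag to Cloud", "postprocessors": [{"processor": "cloud_upload"}]}
        if use_minid
        else {}
    )
    template = {
        "type": "BAG",
        "outputs": [{"fragment_key": "dataset_export_outputs"}],
        "displayname": "BDBag Download",
        "bag_idempotent": True,
    } | post_processors
    print(template["displayname"])  # 'BDBag to Cloud' when use_minid, otherwise 'BDBag Download'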
-    def _dataset_visible_fkeys(self) -> dict[str, Any]:
-        def fkey_name(fk):
-            return [fk.name[0].name, fk.name[1]]
-
-        dataset_table = self._model.schemas["deriva-ml"].tables["Dataset"]
-
-        source_list = [
-            {
-                "source": [
-                    {"inbound": ["deriva-ml", "Dataset_Version_Dataset_fkey"]},
-                    "RID",
-                ],
-                "markdown_name": "Previous Versions",
-                "entity": True,
-            },
-            {
-                "source": [
-                    {"inbound": ["deriva-ml", "Dataset_Dataset_Nested_Dataset_fkey"]},
-                    {"outbound": ["deriva-ml", "Dataset_Dataset_Dataset_fkey"]},
-                    "RID",
-                ],
-                "markdown_name": "Parent Datasets",
-            },
-            {
-                "source": [
-                    {"inbound": ["deriva-ml", "Dataset_Dataset_Dataset_fkey"]},
-                    {"outbound": ["deriva-ml", "Dataset_Dataset_Nested_Dataset_fkey"]},
-                    "RID",
-                ],
-                "markdown_name": "Child Datasets",
-            },
-        ]
-        source_list.extend(
-            [
-                {
-                    "source": [
-                        {"inbound": fkey_name(fkey.self_fkey)},
-                        {"outbound": fkey_name(other_fkey := fkey.other_fkeys.pop())},
-                        "RID",
-                    ],
-                    "markdown_name": other_fkey.pk_table.name,
-                }
-                for fkey in dataset_table.find_associations(max_arity=3, pure=False)
-            ]
-        )
-        return {"detailed": source_list}