deriva-ml 1.13.3__py3-none-any.whl → 1.14.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. deriva_ml/__init__.py +25 -30
  2. deriva_ml/core/__init__.py +39 -0
  3. deriva_ml/core/base.py +1489 -0
  4. deriva_ml/core/constants.py +36 -0
  5. deriva_ml/core/definitions.py +74 -0
  6. deriva_ml/core/enums.py +222 -0
  7. deriva_ml/core/ermrest.py +288 -0
  8. deriva_ml/core/exceptions.py +28 -0
  9. deriva_ml/core/filespec.py +116 -0
  10. deriva_ml/dataset/__init__.py +4 -0
  11. deriva_ml/{dataset_aux_classes.py → dataset/aux_classes.py} +16 -12
  12. deriva_ml/{dataset.py → dataset/dataset.py} +408 -416
  13. deriva_ml/{dataset_bag.py → dataset/dataset_bag.py} +137 -97
  14. deriva_ml/{history.py → dataset/history.py} +52 -33
  15. deriva_ml/{upload.py → dataset/upload.py} +48 -70
  16. deriva_ml/demo_catalog.py +233 -183
  17. deriva_ml/execution/environment.py +290 -0
  18. deriva_ml/{execution.py → execution/execution.py} +365 -252
  19. deriva_ml/execution/execution_configuration.py +163 -0
  20. deriva_ml/{execution_configuration.py → execution/workflow.py} +206 -218
  21. deriva_ml/feature.py +83 -46
  22. deriva_ml/model/__init__.py +0 -0
  23. deriva_ml/{deriva_model.py → model/catalog.py} +113 -132
  24. deriva_ml/{database_model.py → model/database.py} +52 -74
  25. deriva_ml/model/sql_mapper.py +44 -0
  26. deriva_ml/run_notebook.py +19 -11
  27. deriva_ml/schema/__init__.py +3 -0
  28. deriva_ml/{schema_setup → schema}/annotations.py +31 -22
  29. deriva_ml/schema/check_schema.py +104 -0
  30. deriva_ml/{schema_setup → schema}/create_schema.py +151 -104
  31. deriva_ml/schema/deriva-ml-reference.json +8525 -0
  32. deriva_ml/schema/table_comments_utils.py +57 -0
  33. {deriva_ml-1.13.3.dist-info → deriva_ml-1.14.26.dist-info}/METADATA +5 -4
  34. deriva_ml-1.14.26.dist-info/RECORD +40 -0
  35. {deriva_ml-1.13.3.dist-info → deriva_ml-1.14.26.dist-info}/entry_points.txt +1 -0
  36. deriva_ml/deriva_definitions.py +0 -372
  37. deriva_ml/deriva_ml_base.py +0 -1046
  38. deriva_ml/execution_environment.py +0 -139
  39. deriva_ml/schema_setup/table_comments_utils.py +0 -56
  40. deriva_ml/test-files/execution-parameters.json +0 -1
  41. deriva_ml/test-files/notebook-parameters.json +0 -5
  42. deriva_ml/test_functions.py +0 -141
  43. deriva_ml/test_notebook.ipynb +0 -197
  44. deriva_ml-1.13.3.dist-info/RECORD +0 -31
  45. /deriva_ml/{schema_setup → execution}/__init__.py +0 -0
  46. /deriva_ml/{schema_setup → schema}/policy.json +0 -0
  47. {deriva_ml-1.13.3.dist-info → deriva_ml-1.14.26.dist-info}/WHEEL +0 -0
  48. {deriva_ml-1.13.3.dist-info → deriva_ml-1.14.26.dist-info}/licenses/LICENSE +0 -0
  49. {deriva_ml-1.13.3.dist-info → deriva_ml-1.14.26.dist-info}/top_level.txt +0 -0
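The renames above reorganize the flat deriva_ml modules into core/, dataset/, execution/, model/, and schema/ subpackages. A minimal sketch of how the internal import paths move in 1.14.26, based only on the renames listed above and the import hunks shown below; whether the old flat paths still resolve depends on deriva_ml/__init__.py, which this release also changes:

    # Hedged sketch of the module moves between 1.13.3 and 1.14.26.
    # 1.13.3 internal paths (modules now removed):
    #   from deriva_ml.deriva_ml_base import DerivaML
    #   from deriva_ml.deriva_definitions import RID, ML_SCHEMA, MLVocab, Status, DerivaMLException
    #   from deriva_ml.deriva_model import DerivaModel
    #   from deriva_ml.database_model import DatabaseModel

    # 1.14.26 internal paths, as they appear in the import hunks below:
    from deriva_ml.core.base import DerivaML
    from deriva_ml.core.constants import RID
    from deriva_ml.core.definitions import ML_SCHEMA, MLVocab, Status
    from deriva_ml.core.exceptions import DerivaMLException
    from deriva_ml.dataset.dataset_bag import DatasetBag
    from deriva_ml.model.catalog import DerivaModel
    from deriva_ml.model.database import DatabaseModel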
@@ -1,74 +1,110 @@
1
- """
2
- This module defines the DataSet class with is used to manipulate datasets in DerivaML.
3
- The intended use of this class is as a base class in DerivaML, so all the methods documented here are
4
- accessible via a DerivaML class instance.
5
-
1
+ """Dataset management for DerivaML.
2
+
3
+ This module provides functionality for managing datasets in DerivaML. A dataset represents a collection
4
+ of related data that can be versioned, downloaded, and tracked. The module includes:
5
+
6
+ - Dataset class: Core class for dataset operations
7
+ - Version management: Track and update dataset versions
8
+ - History tracking: Record dataset changes over time
9
+ - Download capabilities: Export datasets as BDBags
10
+ - Relationship management: Handle dataset dependencies and hierarchies
11
+
12
+ The Dataset class serves as a base class in DerivaML, making its methods accessible through
13
+ DerivaML class instances.
14
+
15
+ Typical usage example:
16
+ >>> ml = DerivaML('deriva.example.org', 'my_catalog')
17
+ >>> dataset_rid = ml.create_dataset('experiment', 'Experimental data')
18
+ >>> ml.add_dataset_members(dataset_rid=dataset_rid, members=['1-abc123', '1-def456'])
19
+ >>> ml.increment_dataset_version(dataset_rid=dataset_rid, component=VersionPart.minor,
20
+ ... description='Added new samples')
6
21
  """
7
22
 
8
23
  from __future__ import annotations
9
- from bdbag import bdbag_api as bdb
10
- from bdbag.fetch.fetcher import fetch_single_file
11
- from collections import defaultdict
12
- from graphlib import TopologicalSorter
24
+
25
+ # Standard library imports
13
26
  import json
14
27
  import logging
28
+ from collections import defaultdict
29
+ from graphlib import TopologicalSorter
15
30
  from pathlib import Path
16
- from pydantic import (
17
- validate_call,
18
- ConfigDict,
19
- )
20
- import requests
21
31
  from tempfile import TemporaryDirectory
22
- from typing import Any, Callable, Optional, Iterable, Iterator, TYPE_CHECKING
32
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator
23
33
 
34
+ import deriva.core.utils.hash_utils as hash_utils
35
+ import requests
24
36
 
37
+ # Third-party imports
38
+ from bdbag import bdbag_api as bdb
39
+ from bdbag.fetch.fetcher import fetch_single_file
40
+
41
+ # Deriva imports
25
42
  from deriva.core.ermrest_model import Table
26
- from deriva.core.utils.core_utils import tag as deriva_tags, format_exception
27
- import deriva.core.utils.hash_utils as hash_utils
28
- from deriva.transfer.download.deriva_export import DerivaExport
43
+ from deriva.core.utils.core_utils import format_exception
44
+ from deriva.core.utils.core_utils import tag as deriva_tags
29
45
  from deriva.transfer.download.deriva_download import (
30
- DerivaDownloadConfigurationError,
31
- DerivaDownloadError,
32
46
  DerivaDownloadAuthenticationError,
33
47
  DerivaDownloadAuthorizationError,
48
+ DerivaDownloadConfigurationError,
49
+ DerivaDownloadError,
34
50
  DerivaDownloadTimeoutError,
35
51
  )
52
+ from deriva.transfer.download.deriva_export import DerivaExport
53
+ from pydantic import ConfigDict, validate_call
36
54
 
37
-
55
+ # Local imports
38
56
  try:
39
57
  from icecream import ic
58
+
59
+ ic.configureOutput(includeContext=True)
40
60
  except ImportError: # Graceful fallback if IceCream isn't installed.
41
61
  ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a) # noqa
42
62
 
43
- from deriva_ml import DatasetBag
44
- from .deriva_definitions import (
63
+ from deriva_ml.core.constants import RID
64
+ from deriva_ml.core.definitions import (
65
+ DRY_RUN_RID,
45
66
  ML_SCHEMA,
46
- DerivaMLException,
47
67
  MLVocab,
48
68
  Status,
49
- RID,
50
- DRY_RUN_RID,
51
69
  )
52
- from .deriva_model import DerivaModel
53
- from .database_model import DatabaseModel
54
- from .dataset_aux_classes import (
55
- DatasetVersion,
56
- DatasetMinid,
70
+ from deriva_ml.core.exceptions import DerivaMLException, DerivaMLTableTypeError
71
+ from deriva_ml.dataset.aux_classes import (
57
72
  DatasetHistory,
58
- VersionPart,
73
+ DatasetMinid,
59
74
  DatasetSpec,
75
+ DatasetVersion,
76
+ VersionPart,
60
77
  )
78
+ from deriva_ml.dataset.dataset_bag import DatasetBag
79
+ from deriva_ml.model.catalog import DerivaModel
80
+ from deriva_ml.model.database import DatabaseModel
81
+
82
+ from .history import iso_to_snap
83
+
84
+ # Stop pycharm from complaining about undefined reference in docstring....
85
+ ml: DerivaML
61
86
 
62
87
  if TYPE_CHECKING:
63
- from .deriva_ml_base import DerivaML
88
+ from deriva_ml.core.base import DerivaML
64
89
 
65
90
 
66
91
  class Dataset:
67
- """
68
- Class to manipulate a dataset.
92
+ """Manages dataset operations in a Deriva catalog.
93
+
94
+ The Dataset class provides functionality for creating, modifying, and tracking datasets
95
+ in a Deriva catalog. It handles versioning, relationships between datasets, and data export.
69
96
 
70
97
  Attributes:
71
- dataset_table (Table): ERMRest table holding dataset information.
98
+ dataset_table (Table): ERMrest table storing dataset information.
99
+ _model (DerivaModel): Catalog model instance.
100
+ _ml_schema (str): Schema name for ML-specific tables.
101
+ _cache_dir (Path): Directory for caching downloaded datasets.
102
+ _working_dir (Path): Directory for working data.
103
+ _use_minid (bool): Whether to use MINID service for dataset identification.
104
+
105
+ Note:
106
+ This class is typically used as a base class, with its methods accessed through
107
+ DerivaML class instances rather than directly.
72
108
  """
73
109
 
74
110
  _Logger = logging.getLogger("deriva_ml")
@@ -80,20 +116,31 @@ class Dataset:
80
116
  working_dir: Path,
81
117
  use_minid: bool = True,
82
118
  ):
119
+ """Initializes a Dataset instance.
120
+
121
+ Args:
122
+ model: DerivaModel instance representing the catalog.
123
+ cache_dir: Directory path for caching downloaded datasets.
124
+ working_dir: Directory path for working data.
125
+ use_minid: Whether to use MINID service for dataset identification.
126
+ """
83
127
  self._model = model
84
128
  self._ml_schema = ML_SCHEMA
85
- self.dataset_table = self._model.schemas[self._ml_schema].tables["Dataset"]
86
129
  self._cache_dir = cache_dir
87
130
  self._working_dir = working_dir
88
131
  self._logger = logging.getLogger("deriva_ml")
89
132
  self._use_minid = use_minid
90
133
 
134
+ @property
135
+ def _dataset_table(self):
136
+ return self._model.schemas[self._ml_schema].tables["Dataset"]
137
+
91
138
  def _is_dataset_rid(self, dataset_rid: RID, deleted: bool = False) -> bool:
92
139
  try:
93
140
  rid_info = self._model.catalog.resolve_rid(dataset_rid, self._model.model)
94
141
  except KeyError as _e:
95
142
  raise DerivaMLException(f"Invalid RID {dataset_rid}")
96
- if rid_info.table != self.dataset_table:
143
+ if rid_info.table != self._dataset_table:
97
144
  return False
98
145
  elif deleted:
99
146
  # Got a dataset rid. Now check to see if its deleted or not.
@@ -104,12 +151,12 @@ class Dataset:
104
151
  def _insert_dataset_versions(
105
152
  self,
106
153
  dataset_list: list[DatasetSpec],
107
- description: Optional[str] = "",
108
- execution_rid: Optional[RID] = None,
154
+ description: str | None = "",
155
+ execution_rid: RID | None = None,
109
156
  ) -> None:
110
157
  schema_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema]
111
158
  # determine snapshot after changes were made
112
- snap = self._model.catalog.get("/").json()["snaptime"]
159
+
113
160
  # Construct version records for insert
114
161
  version_records = schema_path.tables["Dataset_Version"].insert(
115
162
  [
@@ -118,16 +165,18 @@ class Dataset:
118
165
  "Version": str(dataset.version),
119
166
  "Description": description,
120
167
  "Execution": execution_rid,
121
- "Snapshot": snap,
122
168
  }
123
169
  for dataset in dataset_list
124
170
  ]
125
171
  )
172
+ version_records = list(version_records)
173
+ snap = self._model.catalog.get("/").json()["snaptime"]
174
+ schema_path.tables["Dataset_Version"].update(
175
+ [{"RID": v["RID"], "Dataset": v["Dataset"], "Snapshot": snap} for v in version_records]
176
+ )
126
177
 
127
178
  # And update the dataset records.
128
- schema_path.tables["Dataset"].update(
129
- [{"Version": v["RID"], "RID": v["Dataset"]} for v in version_records]
130
- )
179
+ schema_path.tables["Dataset"].update([{"Version": v["RID"], "RID": v["Dataset"]} for v in version_records])
131
180
 
132
181
  def _bootstrap_versions(self):
133
182
  datasets = [ds["RID"] for ds in self.find_datasets()]
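The hunk above also changes when the catalog snapshot is recorded: the Dataset_Version rows are inserted first, the snapshot time is read afterwards, and the rows are then patched, so the recorded snapshot includes the version rows themselves. A minimal sketch of that insert-then-patch datapath pattern, assuming an already-reachable catalog; the host, catalog id, schema name, RID, and version string are illustrative:

    from deriva.core import ErmrestCatalog

    # Connect to an example catalog (host and catalog id are illustrative).
    catalog = ErmrestCatalog("https", "deriva.example.org", "1")
    ml_schema = catalog.getPathBuilder().schemas["deriva-ml"]  # ML schema name is illustrative
    version_table = ml_schema.tables["Dataset_Version"]

    # 1. Insert the new version rows; no Snapshot value yet.
    version_records = list(
        version_table.insert([{"Dataset": "1-abc123", "Version": "0.2.0", "Description": "example update"}])
    )

    # 2. Read the catalog snapshot time *after* the insert so it covers the rows just added.
    snap = catalog.get("/").json()["snaptime"]

    # 3. Patch the version rows with the snapshot, then point each Dataset at its new version row.
    version_table.update(
        [{"RID": v["RID"], "Dataset": v["Dataset"], "Snapshot": snap} for v in version_records]
    )
    ml_schema.tables["Dataset"].update(
        [{"RID": v["Dataset"], "Version": v["RID"]} for v in version_records]
    )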
@@ -143,9 +192,7 @@ class Dataset:
143
192
  version_path = schema_path.tables["Dataset_Version"]
144
193
  dataset_path = schema_path.tables["Dataset"]
145
194
  history = list(version_path.insert(ds_version))
146
- dataset_versions = [
147
- {"RID": h["Dataset"], "Version": h["Version"]} for h in history
148
- ]
195
+ dataset_versions = [{"RID": h["Dataset"], "Version": h["Version"]} for h in history]
149
196
  dataset_path.update(dataset_versions)
150
197
 
151
198
  def _synchronize_dataset_versions(self):
@@ -161,30 +208,46 @@ class Dataset:
161
208
  versions[v["Dataset"]] = v
162
209
  dataset_path = schema_path.tables["Dataset"]
163
210
 
164
- dataset_path.update(
165
- [
166
- {"RID": dataset, "Version": version["RID"]}
167
- for dataset, version in versions.items()
168
- ]
211
+ dataset_path.update([{"RID": dataset, "Version": version["RID"]} for dataset, version in versions.items()])
212
+
213
+ def _set_version_snapshot(self):
214
+ dataset_version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
215
+ versions = dataset_version_path.entities().fetch()
216
+ dataset_version_path.update(
217
+ [{"RID": h["RID"], "Snapshot": iso_to_snap(h["RCT"])} for h in versions if not h["Snapshot"]]
169
218
  )
170
219
 
171
220
  def dataset_history(self, dataset_rid: RID) -> list[DatasetHistory]:
172
- """Return a list of DatasetHistory objects representing the dataset
221
+ """Retrieves the version history of a dataset.
222
+
223
+ Returns a chronological list of dataset versions, including their version numbers,
224
+ creation times, and associated metadata.
173
225
 
174
226
  Args:
175
- dataset_rid: A RID to the dataset for which history is to be fetched.
227
+ dataset_rid: Resource Identifier of the dataset.
176
228
 
177
229
  Returns:
178
- A list of DatasetHistory objects which indicate the version-number, creation time, and bag instantiation of the dataset.
230
+ list[DatasetHistory]: List of history entries, each containing:
231
+ - dataset_version: Version number (major.minor.patch)
232
+ - minid: Minimal Viable Identifier
233
+ - snapshot: Catalog snapshot time
234
+ - dataset_rid: Dataset Resource Identifier
235
+ - version_rid: Version Resource Identifier
236
+ - description: Version description
237
+ - execution_rid: Associated execution RID
238
+
239
+ Raises:
240
+ DerivaMLException: If dataset_rid is not a valid dataset RID.
241
+
242
+ Example:
243
+ >>> history = ml.dataset_history("1-abc123")
244
+ >>> for entry in history:
245
+ ... print(f"Version {entry.dataset_version}: {entry.description}")
179
246
  """
180
247
 
181
248
  if not self._is_dataset_rid(dataset_rid):
182
249
  raise DerivaMLException(f"RID is not for a data set: {dataset_rid}")
183
- version_path = (
184
- self._model.catalog.getPathBuilder()
185
- .schemas[self._ml_schema]
186
- .tables["Dataset_Version"]
187
- )
250
+ version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
188
251
  return [
189
252
  DatasetHistory(
190
253
  dataset_version=DatasetVersion.parse(v["Version"]),
@@ -195,9 +258,7 @@ class Dataset:
195
258
  description=v["Description"],
196
259
  execution_rid=v["Execution"],
197
260
  )
198
- for v in version_path.filter(version_path.Dataset == dataset_rid)
199
- .entities()
200
- .fetch()
261
+ for v in version_path.filter(version_path.Dataset == dataset_rid).entities().fetch()
201
262
  ]
202
263
 
203
264
  @validate_call
@@ -219,14 +280,16 @@ class Dataset:
219
280
  if not history:
220
281
  return DatasetVersion(0, 1, 0)
221
282
  else:
222
- return max([h.dataset_version for h in self.dataset_history(dataset_rid)])
283
+ # Ensure we return a DatasetVersion, not a string
284
+ versions = [h.dataset_version for h in history]
285
+ return max(versions) if versions else DatasetVersion(0, 1, 0)
223
286
 
224
287
  def _build_dataset_graph(self, dataset_rid: RID) -> Iterable[RID]:
225
- ts = TopologicalSorter()
288
+ ts: TopologicalSorter = TopologicalSorter()
226
289
  self._build_dataset_graph_1(dataset_rid, ts, set())
227
290
  return ts.static_order()
228
291
 
229
- def _build_dataset_graph_1(self, dataset_rid: RID, ts, visited) -> None:
292
+ def _build_dataset_graph_1(self, dataset_rid: RID, ts: TopologicalSorter, visited) -> None:
230
293
  """Use topological sort to return bottom up list of nested datasets"""
231
294
  ts.add(dataset_rid)
232
295
  if dataset_rid not in visited:
@@ -234,7 +297,8 @@ class Dataset:
234
297
  children = self.list_dataset_children(dataset_rid=dataset_rid)
235
298
  parents = self.list_dataset_parents(dataset_rid=dataset_rid)
236
299
  for parent in parents:
237
- self._build_dataset_graph_1(parent, ts, visited)
300
+ # Convert string to RID type
301
+ self._build_dataset_graph_1(RID(parent), ts, visited)
238
302
  for child in children:
239
303
  self._build_dataset_graph_1(child, ts, visited)
240
304
 
@@ -243,22 +307,34 @@ class Dataset:
243
307
  self,
244
308
  dataset_rid: RID,
245
309
  component: VersionPart,
246
- description: Optional[str] = "",
247
- execution_rid: Optional[RID] = None,
310
+ description: str | None = "",
311
+ execution_rid: RID | None = None,
248
312
  ) -> DatasetVersion:
249
- """Increment the version of the specified dataset_table.
313
+ """Increments a dataset's version number.
314
+
315
+ Creates a new version of the dataset by incrementing the specified version component
316
+ (major, minor, or patch). The new version is recorded with an optional description
317
+ and execution reference.
250
318
 
251
319
  Args:
252
- dataset_rid: RID of the dataset whose version is to be incremented.
253
- component: Which version of the dataset_table to increment. Major, Minor, or Patch
254
- description: Description of the version update of the dataset_table.
255
- execution_rid: Which execution is performing increment.
320
+ dataset_rid: Resource Identifier of the dataset to version.
321
+ component: Which version component to increment ('major', 'minor', or 'patch').
322
+ description: Optional description of the changes in this version.
323
+ execution_rid: Optional execution RID to associate with this version.
256
324
 
257
325
  Returns:
258
- new semantic version of the dataset_table as a 3-tuple
326
+ DatasetVersion: The new version number.
259
327
 
260
328
  Raises:
261
- DerivaMLException: if provided, RID is not to a dataset_table.
329
+ DerivaMLException: If dataset_rid is invalid or version increment fails.
330
+
331
+ Example:
332
+ >>> new_version = ml.increment_dataset_version(
333
+ ... dataset_rid="1-abc123",
334
+ ... component="minor",
335
+ ... description="Added new samples"
336
+ ... )
337
+ >>> print(f"New version: {new_version}") # e.g., "1.2.0"
262
338
  """
263
339
 
264
340
  # Find all the datasets that are reachable from this dataset and determine their new version numbers.
@@ -270,46 +346,51 @@ class Dataset:
270
346
  )
271
347
  for ds_rid in related_datasets
272
348
  ]
273
- self._insert_dataset_versions(
274
- version_update_list, description=description, execution_rid=execution_rid
275
- )
276
- return [d.version for d in version_update_list if d.rid == dataset_rid][0]
349
+ self._insert_dataset_versions(version_update_list, description=description, execution_rid=execution_rid)
350
+ return next((d.version for d in version_update_list if d.rid == dataset_rid))
277
351
 
278
352
  @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
279
353
  def create_dataset(
280
354
  self,
281
- dataset_types: str | list[str],
282
- description: str,
283
- execution_rid: Optional[RID] = None,
284
- version: Optional[DatasetVersion] = None,
355
+ dataset_types: str | list[str] | None = None,
356
+ description: str = "",
357
+ execution_rid: RID | None = None,
358
+ version: DatasetVersion | None = None,
285
359
  ) -> RID:
286
- """Create a new dataset_table from the specified list of RIDs.
360
+ """Creates a new dataset in the catalog.
361
+
362
+ Creates a dataset with specified types and description. The dataset can be associated
363
+ with an execution and initialized with a specific version.
287
364
 
288
365
  Args:
289
- dataset_types: One or more dataset_table types. Must be a term from the DatasetType controlled vocabulary.
290
- description: Description of the dataset_table.
291
- execution_rid: Execution under which the dataset_table will be created.
292
- version: Version of the dataset_table.
366
+ dataset_types: One or more dataset type terms from Dataset_Type vocabulary.
367
+ description: Description of the dataset's purpose and contents.
368
+ execution_rid: Optional execution RID to associate with dataset creation.
369
+ version: Optional initial version number. Defaults to 0.1.0.
293
370
 
294
371
  Returns:
295
- New dataset_table RID.
372
+ RID: Resource Identifier of the newly created dataset.
296
373
 
374
+ Raises:
375
+ DerivaMLException: If dataset_types are invalid or creation fails.
376
+
377
+ Example:
378
+ >>> rid = ml.create_dataset(
379
+ ... dataset_types=["experiment", "raw_data"],
380
+ ... description="RNA sequencing experiment data",
381
+ ... version=DatasetVersion(1, 0, 0)
382
+ ... )
297
383
  """
298
384
 
299
385
  version = version or DatasetVersion(0, 1, 0)
386
+ dataset_types = dataset_types or []
300
387
 
301
- type_path = (
302
- self._model.catalog.getPathBuilder()
303
- .schemas[self._ml_schema]
304
- .tables[MLVocab.dataset_type.value]
305
- )
388
+ type_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables[MLVocab.dataset_type.value]
306
389
  defined_types = list(type_path.entities().fetch())
307
390
 
308
391
  def check_dataset_type(dtype: str) -> bool:
309
392
  for term in defined_types:
310
- if dtype == term["Name"] or (
311
- term["Synonyms"] and ds_type in term["Synonyms"]
312
- ):
393
+ if dtype == term["Name"] or (term["Synonyms"] and ds_type in term["Synonyms"]):
313
394
  return True
314
395
  return False
315
396
 
@@ -319,9 +400,7 @@ class Dataset:
319
400
  for ds_type in ds_types:
320
401
  if not check_dataset_type(ds_type):
321
402
  raise DerivaMLException("Dataset type must be a vocabulary term.")
322
- dataset_table_path = pb.schemas[self.dataset_table.schema.name].tables[
323
- self.dataset_table.name
324
- ]
403
+ dataset_table_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
325
404
  dataset_rid = dataset_table_path.insert(
326
405
  [
327
406
  {
@@ -332,21 +411,13 @@ class Dataset:
332
411
  )[0]["RID"]
333
412
 
334
413
  # Get the name of the association table between dataset_table and dataset_type.
335
- atable = next(
336
- self._model.schemas[self._ml_schema]
337
- .tables[MLVocab.dataset_type]
338
- .find_associations()
339
- ).name
414
+ associations = list(self._model.schemas[self._ml_schema].tables[MLVocab.dataset_type].find_associations())
415
+ atable = associations[0].name if associations else None
340
416
  pb.schemas[self._ml_schema].tables[atable].insert(
341
- [
342
- {MLVocab.dataset_type: ds_type, "Dataset": dataset_rid}
343
- for ds_type in ds_types
344
- ]
417
+ [{MLVocab.dataset_type: ds_type, "Dataset": dataset_rid} for ds_type in ds_types]
345
418
  )
346
419
  if execution_rid is not None:
347
- pb.schemas[self._ml_schema].Dataset_Execution.insert(
348
- [{"Dataset": dataset_rid, "Execution": execution_rid}]
349
- )
420
+ pb.schemas[self._ml_schema].Dataset_Execution.insert([{"Dataset": dataset_rid, "Execution": execution_rid}])
350
421
  self._insert_dataset_versions(
351
422
  [DatasetSpec(rid=dataset_rid, version=version)],
352
423
  execution_rid=execution_rid,
@@ -368,18 +439,12 @@ class Dataset:
368
439
  raise DerivaMLException("Dataset_rid is not a dataset.")
369
440
 
370
441
  if parents := self.list_dataset_parents(dataset_rid):
371
- raise DerivaMLException(
372
- f'Dataset_rid "{dataset_rid}" is in a nested dataset: {parents}.'
373
- )
442
+ raise DerivaMLException(f'Dataset_rid "{dataset_rid}" is in a nested dataset: {parents}.')
374
443
 
375
444
  pb = self._model.catalog.getPathBuilder()
376
- dataset_path = pb.schemas[self.dataset_table.schema.name].tables[
377
- self.dataset_table.name
378
- ]
445
+ dataset_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
379
446
 
380
- rid_list = [dataset_rid] + (
381
- self.list_dataset_children(dataset_rid) if recurse else []
382
- )
447
+ rid_list = [dataset_rid] + (self.list_dataset_children(dataset_rid=dataset_rid) if recurse else [])
383
448
  dataset_path.update([{"RID": r, "Deleted": True} for r in rid_list])
384
449
 
385
450
  def find_datasets(self, deleted: bool = False) -> Iterable[dict[str, Any]]:
@@ -393,14 +458,9 @@ class Dataset:
393
458
  """
394
459
  # Get datapath to all the tables we will need: Dataset, DatasetType and the association table.
395
460
  pb = self._model.catalog.getPathBuilder()
396
- dataset_path = pb.schemas[self.dataset_table.schema.name].tables[
397
- self.dataset_table.name
398
- ]
399
- atable = next(
400
- self._model.schemas[self._ml_schema]
401
- .tables[MLVocab.dataset_type]
402
- .find_associations()
403
- ).name
461
+ dataset_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
462
+ associations = list(self._model.schemas[self._ml_schema].tables[MLVocab.dataset_type].find_associations())
463
+ atable = associations[0].name if associations else None
404
464
  ml_path = pb.schemas[self._ml_schema]
405
465
  atable_path = ml_path.tables[atable]
406
466
 
@@ -408,21 +468,16 @@ class Dataset:
408
468
  filtered_path = dataset_path
409
469
  else:
410
470
  filtered_path = dataset_path.filter(
411
- (dataset_path.Deleted == False) | (dataset_path.Deleted == None) # noqa: E712
471
+ (dataset_path.Deleted == False) | (dataset_path.Deleted == None) # noqa: E711, E712
412
472
  )
413
473
 
414
474
  # Get a list of all the dataset_type values associated with this dataset_table.
415
475
  datasets = []
416
476
  for dataset in filtered_path.entities().fetch():
417
477
  ds_types = (
418
- atable_path.filter(atable_path.Dataset == dataset["RID"])
419
- .attributes(atable_path.Dataset_Type)
420
- .fetch()
421
- )
422
- datasets.append(
423
- dataset
424
- | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in ds_types]}
478
+ atable_path.filter(atable_path.Dataset == dataset["RID"]).attributes(atable_path.Dataset_Type).fetch()
425
479
  )
480
+ datasets.append(dataset | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in ds_types]})
426
481
  return datasets
427
482
 
428
483
  def list_dataset_element_types(self) -> Iterable[Table]:
@@ -433,16 +488,9 @@ class Dataset:
433
488
  """
434
489
 
435
490
  def domain_table(table: Table) -> bool:
436
- return (
437
- table.schema.name == self._model.domain_schema
438
- or table.name == self.dataset_table.name
439
- )
491
+ return table.schema.name == self._model.domain_schema or table.name == self._dataset_table.name
440
492
 
441
- return [
442
- t
443
- for a in self.dataset_table.find_associations()
444
- if domain_table(t := a.other_fkeys.pop().pk_table)
445
- ]
493
+ return [t for a in self._dataset_table.find_associations() if domain_table(t := a.other_fkeys.pop().pk_table)]
446
494
 
447
495
  @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
448
496
  def add_dataset_element_type(self, element: str | Table) -> Table:
@@ -457,31 +505,45 @@ class Dataset:
457
505
  """
458
506
  # Add table to map
459
507
  element_table = self._model.name_to_table(element)
460
- table = self._model.schemas[self._model.domain_schema].create_table(
461
- Table.define_association([self.dataset_table, element_table])
462
- )
508
+ atable_def = Table.define_association([self._dataset_table, element_table])
509
+ try:
510
+ table = self._model.schemas[self._model.domain_schema].create_table(atable_def)
511
+ except ValueError as e:
512
+ if "already exists" in str(e):
513
+ table = self._model.name_to_table(atable_def["table_name"])
514
+ else:
515
+ raise e
463
516
 
464
517
  # self.model = self.catalog.getCatalogModel()
465
- self.dataset_table.annotations.update(
466
- self._generate_dataset_download_annotations()
467
- )
518
+ self._dataset_table.annotations.update(self._generate_dataset_download_annotations())
468
519
  self._model.model.apply()
469
520
  return table
470
521
 
471
522
  # @validate_call
472
523
  def list_dataset_members(
473
- self, dataset_rid: RID, recurse: bool = False, limit: Optional[int] = None
524
+ self, dataset_rid: RID, recurse: bool = False, limit: int | None = None
474
525
  ) -> dict[str, list[dict[str, Any]]]:
475
- """Return a list of entities associated with a specific dataset_table.
526
+ """Lists members of a dataset.
527
+
528
+ Returns a dictionary mapping member types to lists of member records. Can optionally
529
+ recurse through nested datasets and limit the number of results.
476
530
 
477
531
  Args:
478
- dataset_rid: param recurse: If this is a nested dataset_table, list the members of the contained datasets
479
- recurse: (Default value = False)
480
- limit: If provided, the maximum number of members to return for each element type.
532
+ dataset_rid: Resource Identifier of the dataset.
533
+ recurse: Whether to include members of nested datasets. Defaults to False.
534
+ limit: Maximum number of members to return per type. None for no limit.
481
535
 
482
536
  Returns:
483
- Dictionary of entities associated with a specific dataset_table. Key is the table from which the elements
484
- were taken.
537
+ dict[str, list[dict[str, Any]]]: Dictionary mapping member types to lists of members.
538
+ Each member is a dictionary containing the record's attributes.
539
+
540
+ Raises:
541
+ DerivaMLException: If dataset_rid is invalid.
542
+
543
+ Example:
544
+ >>> members = ml.list_dataset_members("1-abc123", recurse=True)
545
+ >>> for type_name, records in members.items():
546
+ ... print(f"{type_name}: {len(records)} records")
485
547
  """
486
548
 
487
549
  if not self._is_dataset_rid(dataset_rid):
@@ -491,21 +553,18 @@ class Dataset:
491
553
  # the appropriate association table.
492
554
  members = defaultdict(list)
493
555
  pb = self._model.catalog.getPathBuilder()
494
- for assoc_table in self.dataset_table.find_associations():
556
+ for assoc_table in self._dataset_table.find_associations():
495
557
  other_fkey = assoc_table.other_fkeys.pop()
496
558
  target_table = other_fkey.pk_table
497
559
  member_table = assoc_table.table
498
560
 
499
561
  # Look at domain tables and nested datasets.
500
- if (
501
- target_table.schema.name != self._model.domain_schema
502
- and target_table != self.dataset_table
562
+ if target_table.schema.name != self._model.domain_schema and not (
563
+ target_table == self._dataset_table or target_table.name == "File"
503
564
  ):
504
565
  continue
505
566
  member_column = (
506
- "Nested_Dataset"
507
- if target_table == self.dataset_table
508
- else other_fkey.foreign_key_columns[0].name
567
+ "Nested_Dataset" if target_table == self._dataset_table else other_fkey.foreign_key_columns[0].name
509
568
  )
510
569
 
511
570
  target_path = pb.schemas[target_table.schema.name].tables[target_table.name]
@@ -515,15 +574,13 @@ class Dataset:
515
574
  target_path,
516
575
  on=(member_path.columns[member_column] == target_path.columns["RID"]),
517
576
  )
518
- target_entities = list(
519
- path.entities().fetch(limit=limit) if limit else path.entities().fetch()
520
- )
577
+ target_entities = list(path.entities().fetch(limit=limit) if limit else path.entities().fetch())
521
578
  members[target_table.name].extend(target_entities)
522
- if recurse and target_table == self.dataset_table:
579
+ if recurse and target_table == self._dataset_table:
523
580
  # Get the members for all the nested datasets and add to the member list.
524
581
  nested_datasets = [d["RID"] for d in target_entities]
525
582
  for ds in nested_datasets:
526
- for k, v in self.list_dataset_members(ds, recurse=False).items():
583
+ for k, v in self.list_dataset_members(ds, recurse=recurse).items():
527
584
  members[k].extend(v)
528
585
  return dict(members)
529
586
 
@@ -531,24 +588,38 @@ class Dataset:
531
588
  def add_dataset_members(
532
589
  self,
533
590
  dataset_rid: RID,
534
- members: list[RID],
591
+ members: list[RID] | dict[str, list[RID]],
535
592
  validate: bool = True,
536
- description: Optional[str] = "",
537
- execution_rid: Optional[RID] = None,
593
+ description: str | None = "",
594
+ execution_rid: RID | None = None,
538
595
  ) -> None:
539
- """Add additional elements to an existing dataset_table.
596
+ """Adds members to a dataset.
540
597
 
541
- Add new elements to an existing dataset. In addition to adding new members, the minor version number of the
542
- dataset is incremented and the description, if provide is applied to that new version.
598
+ Associates one or more records with a dataset. Can optionally validate member types
599
+ and create a new dataset version to track the changes.
543
600
 
544
601
  Args:
545
- dataset_rid: RID of dataset_table to extend or None if a new dataset_table is to be created.
546
- members: List of member RIDs to add to the dataset_table.
547
- validate: Check rid_list to make sure elements are not already in the dataset_table.
548
- description: Markdown description of the updated dataset.
549
- execution_rid: Optional RID of execution associated with this dataset.
602
+ dataset_rid: Resource Identifier of the dataset.
603
+ members: List of RIDs to add as dataset members. Can be organized into a dictionary that indicates the
604
+ table that the member RIDs belong to.
605
+ validate: Whether to validate member types. Defaults to True.
606
+ description: Optional description of the member additions.
607
+ execution_rid: Optional execution RID to associate with changes.
608
+
609
+ Raises:
610
+ DerivaMLException: If:
611
+ - dataset_rid is invalid
612
+ - members are invalid or of wrong type
613
+ - adding members would create a cycle
614
+ - validation fails
615
+
616
+ Example:
617
+ >>> ml.add_dataset_members(
618
+ ... dataset_rid="1-abc123",
619
+ ... members=["1-def456", "1-ghi789"],
620
+ ... description="Added sample data"
621
+ ... )
550
622
  """
551
- members = set(members)
552
623
  description = description or "Updated dataset via add_dataset_members"
553
624
 
554
625
  def check_dataset_cycle(member_rid, path=None):
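As the updated signature and docstring above indicate, members may now be either a flat list of RIDs or a dictionary keyed by table name; the dict form skips the per-RID resolve_rid lookups. A minimal usage sketch mirroring the docstring example; the host, catalog name, table names "Image" and "Subject", and all RIDs are illustrative:

    from deriva_ml.core.base import DerivaML

    # Host and catalog mirror the module docstring example; all RIDs are illustrative.
    ml = DerivaML("deriva.example.org", "my_catalog")

    # Flat list form: each RID is resolved to discover which table it belongs to.
    ml.add_dataset_members(dataset_rid="1-abc123", members=["1-def456", "1-ghi789"])

    # Dict form: member RIDs grouped by table name ("Image" and "Subject" are hypothetical tables).
    ml.add_dataset_members(
        dataset_rid="1-abc123",
        members={"Image": ["1-def456", "1-ghi789"], "Subject": ["1-jkl012"]},
        description="Added images and subjects grouped by table",
    )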
@@ -565,43 +636,37 @@ class Dataset:
565
636
  return member_rid in path
566
637
 
567
638
  if validate:
568
- existing_rids = set(
569
- m["RID"]
570
- for ms in self.list_dataset_members(dataset_rid).values()
571
- for m in ms
572
- )
639
+ existing_rids = set(m["RID"] for ms in self.list_dataset_members(dataset_rid).values() for m in ms)
573
640
  if overlap := set(existing_rids).intersection(members):
574
- raise DerivaMLException(
575
- f"Attempting to add existing member to dataset_table {dataset_rid}: {overlap}"
576
- )
641
+ raise DerivaMLException(f"Attempting to add existing member to dataset_table {dataset_rid}: {overlap}")
577
642
 
578
643
  # Now go through every rid to be added to the data set and sort them based on what association table entries
579
644
  # need to be made.
580
645
  dataset_elements = {}
581
646
  association_map = {
582
- a.other_fkeys.pop().pk_table.name: a.table.name
583
- for a in self.dataset_table.find_associations()
647
+ a.other_fkeys.pop().pk_table.name: a.table.name for a in self._dataset_table.find_associations()
584
648
  }
649
+
585
650
  # Get a list of all the object types that can be linked to a dataset_table.
586
- for m in members:
587
- try:
588
- rid_info = self._model.catalog.resolve_rid(m)
589
- except KeyError:
590
- raise DerivaMLException(f"Invalid RID: {m}")
591
- if rid_info.table.name not in association_map:
592
- raise DerivaMLException(
593
- f"RID table: {rid_info.table.name} not part of dataset_table"
594
- )
595
- if rid_info.table == self.dataset_table and check_dataset_cycle(
596
- rid_info.rid
597
- ):
598
- raise DerivaMLException("Creating cycle of datasets is not allowed")
599
- dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
651
+ if type(members) is list:
652
+ members = set(members)
653
+ for m in members:
654
+ try:
655
+ rid_info = self._model.catalog.resolve_rid(m)
656
+ except KeyError:
657
+ raise DerivaMLException(f"Invalid RID: {m}")
658
+ if rid_info.table.name not in association_map:
659
+ raise DerivaMLException(f"RID table: {rid_info.table.name} not part of dataset_table")
660
+ if rid_info.table == self._dataset_table and check_dataset_cycle(rid_info.rid):
661
+ raise DerivaMLException("Creating cycle of datasets is not allowed")
662
+ dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
663
+ else:
664
+ dataset_elements = {t: set(ms) for t, ms in members.items()}
600
665
  # Now make the entries into the association tables.
601
666
  pb = self._model.catalog.getPathBuilder()
602
667
  for table, elements in dataset_elements.items():
603
668
  schema_path = pb.schemas[
604
- self._ml_schema if table == "Dataset" else self._model.domain_schema
669
+ self._ml_schema if (table == "Dataset" or table == "File") else self._model.domain_schema
605
670
  ]
606
671
  fk_column = "Nested_Dataset" if table == "Dataset" else table
607
672
  if len(elements):
@@ -622,7 +687,7 @@ class Dataset:
622
687
  dataset_rid: RID,
623
688
  members: list[RID],
624
689
  description: str = "",
625
- execution_rid: Optional[RID] = None,
690
+ execution_rid: RID | None = None,
626
691
  ) -> None:
627
692
  """Remove elements to an existing dataset_table.
628
693
 
@@ -643,8 +708,7 @@ class Dataset:
643
708
  # need to be made.
644
709
  dataset_elements = {}
645
710
  association_map = {
646
- a.other_fkeys.pop().pk_table.name: a.table.name
647
- for a in self.dataset_table.find_associations()
711
+ a.other_fkeys.pop().pk_table.name: a.table.name for a in self._dataset_table.find_associations()
648
712
  }
649
713
  # Get a list of all the object types that can be linked to a dataset_table.
650
714
  for m in members:
@@ -653,16 +717,12 @@ class Dataset:
653
717
  except KeyError:
654
718
  raise DerivaMLException(f"Invalid RID: {m}")
655
719
  if rid_info.table.name not in association_map:
656
- raise DerivaMLException(
657
- f"RID table: {rid_info.table.name} not part of dataset_table"
658
- )
720
+ raise DerivaMLException(f"RID table: {rid_info.table.name} not part of dataset_table")
659
721
  dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
660
722
  # Now make the entries into the association tables.
661
723
  pb = self._model.catalog.getPathBuilder()
662
724
  for table, elements in dataset_elements.items():
663
- schema_path = pb.schemas[
664
- self._ml_schema if table == "Dataset" else self._model.domain_schema
665
- ]
725
+ schema_path = pb.schemas[self._ml_schema if table == "Dataset" else self._model.domain_schema]
666
726
  fk_column = "Nested_Dataset" if table == "Dataset" else table
667
727
 
668
728
  if len(elements):
@@ -670,8 +730,7 @@ class Dataset:
670
730
  # Find out the name of the column in the association table.
671
731
  for e in elements:
672
732
  entity = atable_path.filter(
673
- (atable_path.Dataset == dataset_rid)
674
- & (atable_path.columns[fk_column] == e),
733
+ (atable_path.Dataset == dataset_rid) & (atable_path.columns[fk_column] == e),
675
734
  )
676
735
  entity.delete()
677
736
  self.increment_dataset_version(
@@ -693,21 +752,14 @@ class Dataset:
693
752
  RID of the parent dataset_table.
694
753
  """
695
754
  if not self._is_dataset_rid(dataset_rid):
696
- raise DerivaMLException(
697
- f"RID: {dataset_rid} does not belong to dataset_table {self.dataset_table.name}"
698
- )
755
+ raise DerivaMLException(f"RID: {dataset_rid} does not belong to dataset_table {self._dataset_table.name}")
699
756
  # Get association table for nested datasets
700
757
  pb = self._model.catalog.getPathBuilder()
701
758
  atable_path = pb.schemas[self._ml_schema].Dataset_Dataset
702
- return [
703
- p["Dataset"]
704
- for p in atable_path.filter(atable_path.Nested_Dataset == dataset_rid)
705
- .entities()
706
- .fetch()
707
- ]
759
+ return [p["Dataset"] for p in atable_path.filter(atable_path.Nested_Dataset == dataset_rid).entities().fetch()]
708
760
 
709
761
  @validate_call
710
- def list_dataset_children(self, dataset_rid: RID, recurse=False) -> list[RID]:
762
+ def list_dataset_children(self, dataset_rid: RID, recurse: bool = False) -> list[RID]:
711
763
  """Given a dataset_table RID, return a list of RIDs for any nested datasets.
712
764
 
713
765
  Args:
@@ -718,19 +770,11 @@ class Dataset:
718
770
  list of nested dataset RIDs.
719
771
 
720
772
  """
721
- dataset_dataset_path = (
722
- self._model.catalog.getPathBuilder()
723
- .schemas[self._ml_schema]
724
- .tables["Dataset_Dataset"]
725
- )
773
+ dataset_dataset_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Dataset"]
726
774
  nested_datasets = list(dataset_dataset_path.entities().fetch())
727
775
 
728
776
  def find_children(rid: RID):
729
- children = [
730
- child["Nested_Dataset"]
731
- for child in nested_datasets
732
- if child["Dataset"] == rid
733
- ]
777
+ children = [child["Nested_Dataset"] for child in nested_datasets if child["Dataset"] == rid]
734
778
  if recurse:
735
779
  for child in children.copy():
736
780
  children.extend(find_children(child))
@@ -738,9 +782,7 @@ class Dataset:
738
782
 
739
783
  return find_children(dataset_rid)
740
784
 
741
- def _export_vocabulary(
742
- self, writer: Callable[[str, str, Table], list[dict[str, Any]]]
743
- ) -> list[dict[str, Any]]:
785
+ def _export_vocabulary(self, writer: Callable[[str, str, Table], list[dict[str, Any]]]) -> list[dict[str, Any]]:
744
786
  """
745
787
 
746
788
  Args:
@@ -755,16 +797,12 @@ class Dataset:
755
797
  for table in s.tables.values()
756
798
  if self._model.is_vocabulary(table)
757
799
  ]
758
- return [
759
- o
760
- for table in vocabs
761
- for o in writer(f"{table.schema.name}:{table.name}", table.name, table)
762
- ]
800
+ return [o for table in vocabs for o in writer(f"{table.schema.name}:{table.name}", table.name, table)]
763
801
 
764
802
  def _table_paths(
765
803
  self,
766
- dataset: Optional[DatasetSpec] = None,
767
- snapshot_catalog: Optional[DerivaML] = None,
804
+ dataset: DatasetSpec | None = None,
805
+ snapshot_catalog: DerivaML | None = None,
768
806
  ) -> Iterator[tuple[str, str, Table]]:
769
807
  paths = self._collect_paths(dataset and dataset.rid, snapshot_catalog)
770
808
 
@@ -790,25 +828,20 @@ class Dataset:
790
828
 
791
829
  def _collect_paths(
792
830
  self,
793
- dataset_rid: Optional[RID] = None,
794
- snapshot: Optional[Dataset] = None,
795
- dataset_nesting_depth: Optional[int] = None,
831
+ dataset_rid: RID | None = None,
832
+ snapshot: Dataset | None = None,
833
+ dataset_nesting_depth: int | None = None,
796
834
  ) -> set[tuple[Table, ...]]:
797
835
  snapshot_catalog = snapshot if snapshot else self
798
836
 
799
- dataset_table = snapshot_catalog._model.schemas[self._ml_schema].tables[
800
- "Dataset"
801
- ]
802
- dataset_dataset = snapshot_catalog._model.schemas[self._ml_schema].tables[
803
- "Dataset_Dataset"
804
- ]
837
+ dataset_table = snapshot_catalog._model.schemas[self._ml_schema].tables["Dataset"]
838
+ dataset_dataset = snapshot_catalog._model.schemas[self._ml_schema].tables["Dataset_Dataset"]
805
839
 
806
840
  # Figure out what types of elements the dataset contains.
807
841
  dataset_associations = [
808
842
  a
809
- for a in self.dataset_table.find_associations()
810
- if a.table.schema.name != self._ml_schema
811
- or a.table.name == "Dataset_Dataset"
843
+ for a in self._dataset_table.find_associations()
844
+ if a.table.schema.name != self._ml_schema or a.table.name == "Dataset_Dataset"
812
845
  ]
813
846
  if dataset_rid:
814
847
  # Get a list of the members of the dataset so we can figure out which tables to query.
@@ -820,9 +853,7 @@ class Dataset:
820
853
  if m
821
854
  ]
822
855
  included_associations = [
823
- a.table
824
- for a in dataset_table.find_associations()
825
- if a.other_fkeys.pop().pk_table in dataset_elements
856
+ a.table for a in dataset_table.find_associations() if a.other_fkeys.pop().pk_table in dataset_elements
826
857
  ]
827
858
  else:
828
859
  included_associations = dataset_associations
@@ -833,9 +864,7 @@ class Dataset:
833
864
  for p in snapshot_catalog._model._schema_to_paths()
834
865
  if (len(p) == 1)
835
866
  or (p[1] not in dataset_associations) # Tables in the domain schema
836
- or (
837
- p[1] in included_associations
838
- ) # Tables that include members of the dataset
867
+ or (p[1] in included_associations) # Tables that include members of the dataset
839
868
  }
840
869
  # Now get paths for nested datasets
841
870
  nested_paths = set()
@@ -845,56 +874,42 @@ class Dataset:
845
874
  else:
846
875
  # Initialize nesting depth if not already provided.
847
876
  dataset_nesting_depth = (
848
- self._dataset_nesting_depth()
849
- if dataset_nesting_depth is None
850
- else dataset_nesting_depth
877
+ self._dataset_nesting_depth() if dataset_nesting_depth is None else dataset_nesting_depth
851
878
  )
852
879
  if dataset_nesting_depth:
853
- nested_paths = self._collect_paths(
854
- dataset_nesting_depth=dataset_nesting_depth - 1
855
- )
880
+ nested_paths = self._collect_paths(dataset_nesting_depth=dataset_nesting_depth - 1)
856
881
  if nested_paths:
857
882
  paths |= {
858
883
  tuple([dataset_table]),
859
884
  (dataset_table, dataset_dataset),
860
885
  }
861
- paths |= {(self.dataset_table, dataset_dataset) + p for p in nested_paths}
886
+ paths |= {(self._dataset_table, dataset_dataset) + p for p in nested_paths}
862
887
  return paths
863
888
 
864
- def _dataset_nesting_depth(self, dataset_rid: Optional[RID] = None) -> int:
889
+ def _dataset_nesting_depth(self, dataset_rid: RID | None = None) -> int:
865
890
  """Determine the maximum dataset nesting depth in the current catalog.
866
891
 
867
892
  Returns:
868
893
 
869
894
  """
870
895
 
871
- def children_depth(
872
- dataset_rid: RID, nested_datasets: dict[str, list[str]]
873
- ) -> int:
896
+ def children_depth(dataset_rid: RID, nested_datasets: dict[str, list[str]]) -> int:
874
897
  """Return the number of nested datasets for the dataset_rid if provided, otherwise in the current catalog"""
875
898
  try:
876
899
  children = nested_datasets[dataset_rid]
877
- return (
878
- max(map(lambda x: children_depth(x, nested_datasets), children)) + 1
879
- if children
880
- else 1
881
- )
900
+ return max(map(lambda x: children_depth(x, nested_datasets), children)) + 1 if children else 1
882
901
  except KeyError:
883
902
  return 0
884
903
 
885
904
  # Build up the dataset_table nesting graph...
886
- pb = (
887
- self._model.catalog.getPathBuilder()
888
- .schemas[self._ml_schema]
889
- .tables["Dataset_Dataset"]
890
- )
905
+ pb = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Dataset"]
891
906
  dataset_children = (
892
907
  [
893
908
  {
894
909
  "Dataset": dataset_rid,
895
910
  "Nested_Dataset": c,
896
911
  } # Make uniform with return from datapath
897
- for c in self.list_dataset_children(dataset_rid)
912
+ for c in self.list_dataset_children(dataset_rid=dataset_rid)
898
913
  ]
899
914
  if dataset_rid
900
915
  else pb.entities().fetch()
@@ -902,30 +917,29 @@ class Dataset:
902
917
  nested_dataset = defaultdict(list)
903
918
  for ds in dataset_children:
904
919
  nested_dataset[ds["Dataset"]].append(ds["Nested_Dataset"])
905
- return (
906
- max(map(lambda d: children_depth(d, dict(nested_dataset)), nested_dataset))
907
- if nested_dataset
908
- else 0
909
- )
920
+ return max(map(lambda d: children_depth(d, dict(nested_dataset)), nested_dataset)) if nested_dataset else 0
910
921
 
911
922
  def _dataset_specification(
912
923
  self,
913
924
  writer: Callable[[str, str, Table], list[dict[str, Any]]],
914
- dataset: Optional[DatasetSpec] = None,
915
- snapshot_catalog: Optional[DerivaML] = None,
925
+ dataset: DatasetSpec | None = None,
926
+ snapshot_catalog: DerivaML | None = None,
916
927
  ) -> list[dict[str, Any]]:
917
- """Output a download/export specification for a dataset_table. Each element of the dataset_table will be placed in its own dir
918
- The top level data directory of the resulting BDBag will have one subdirectory for element type. The subdirectory
919
- will contain the CSV indicating which elements of that type are present in the dataset_table, and then there will be a
920
- subdirectory for each object that is reachable from the dataset_table members.
921
-
922
- To simplify reconstructing the relationship between tables, the CVS for each
923
- The top level data directory will also contain a subdirectory for any controlled vocabularies used in the dataset_table.
924
- All assets will be placed into a directory named asset in a subdirectory with the asset table name.
925
-
926
- For example, consider a dataset_table that consists of two element types, T1 and T2. T1 has foreign key relationships to
927
- objects in tables T3 and T4. There are also two controlled vocabularies, CV1 and CV2. T2 is an asset table
928
- which has two assets in it. The layout of the resulting bdbag would be:
928
+ """Output a download/export specification for a dataset_table. Each element of the dataset_table
929
+ will be placed in its own directory.
930
+ The top-level data directory of the resulting BDBag will have one subdirectory for each element type.
931
+ The subdirectory will contain the CSV indicating which elements of that type are present in the
932
+ dataset_table, and then there will be a subdirectory for each object that is reachable from the
933
+ dataset_table members.
934
+
935
+ To simplify reconstructing the relationship between tables, the CSV for each element is included.
936
+ The top level data directory will also contain a subdirectory for any controlled vocabularies used in
937
+ the dataset_table. All assets will be placed into a directory named asset in a subdirectory with the
938
+ asset table name.
939
+
940
+ For example, consider a dataset_table that consists of two element types, T1 and T2. T1 has foreign
941
+ key relationships to objects in tables T3 and T4. There are also two controlled vocabularies, CV1 and
942
+ CV2. T2 is an asset table which has two assets in it. The layout of the resulting bdbag would be:
929
943
  data
930
944
  CV1/
931
945
  cv1.csv
@@ -952,17 +966,15 @@ class Dataset:
952
966
  A dataset_table specification.
953
967
  """
954
968
  element_spec = self._export_vocabulary(writer)
955
- for path in self._table_paths(
956
- dataset=dataset, snapshot_catalog=snapshot_catalog
957
- ):
969
+ for path in self._table_paths(dataset=dataset, snapshot_catalog=snapshot_catalog):
958
970
  element_spec.extend(writer(*path))
959
971
  return element_spec
960
972
 
961
973
  def _download_dataset_bag(
962
974
  self,
963
975
  dataset: DatasetSpec,
964
- execution_rid: Optional[RID] = None,
965
- snapshot_catalog: Optional[DerivaML] = None,
976
+ execution_rid: RID | None = None,
977
+ snapshot_catalog: DerivaML | None = None,
966
978
  ) -> DatasetBag:
967
979
  """Download a dataset onto the local file system. Create a MINID for the dataset if one doesn't already exist.
968
980
 
@@ -992,27 +1004,29 @@ class Dataset:
992
1004
 
993
1005
  def _version_snapshot(self, dataset: DatasetSpec) -> str:
994
1006
  """Return a catalog with snapshot for the specified dataset version"""
995
- version_record = [
996
- h
997
- for h in self.dataset_history(dataset_rid=dataset.rid)
998
- if h.dataset_version == dataset.version
999
- ][0]
1007
+ try:
1008
+ version_record = next(
1009
+ h for h in self.dataset_history(dataset_rid=dataset.rid) if h.dataset_version == dataset.version
1010
+ )
1011
+ except StopIteration:
1012
+ raise DerivaMLException(f"Dataset version {dataset.version} not found for dataset {dataset.rid}")
1000
1013
  return f"{self._model.catalog.catalog_id}@{version_record.snapshot}"
1001
1014
 
1002
- def _create_dataset_minid(
1003
- self, dataset: DatasetSpec, snapshot_catalog: Optional[DerivaML] = None
1004
- ) -> str:
1015
+ def _create_dataset_minid(self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None) -> str:
1005
1016
  with TemporaryDirectory() as tmp_dir:
1006
1017
  # Generate a download specification file for the current catalog schema. By default, this spec
1007
1018
  # will generate a minid and place the bag into S3 storage.
1008
- spec_file = f"{tmp_dir}/download_spec.json"
1009
- with open(spec_file, "w", encoding="utf-8") as ds:
1010
- json.dump(
1011
- self._generate_dataset_download_spec(dataset, snapshot_catalog), ds
1012
- )
1019
+ spec_file = Path(tmp_dir) / "download_spec.json"
1020
+ with spec_file.open("w", encoding="utf-8") as ds:
1021
+ json.dump(self._generate_dataset_download_spec(dataset, snapshot_catalog), ds)
1013
1022
  try:
1014
1023
  self._logger.info(
1015
- f"Downloading dataset {'minid' if self._use_minid else 'bag'} for catalog: {dataset.rid}@{str(dataset.version)}"
1024
+ "Downloading dataset %s for catalog: %s@%s"
1025
+ % (
1026
+ 'minid' if self._use_minid else 'bag',
1027
+ dataset.rid,
1028
+ str(dataset.version),
1029
+ )
1016
1030
  )
1017
1031
  # Generate the bag and put into S3 storage.
1018
1032
  exporter = DerivaExport(
@@ -1035,15 +1049,9 @@ class Dataset:
1035
1049
  raise DerivaMLException(format_exception(e))
1036
1050
  # Update version table with MINID.
1037
1051
  if self._use_minid:
1038
- version_path = (
1039
- self._model.catalog.getPathBuilder()
1040
- .schemas[self._ml_schema]
1041
- .tables["Dataset_Version"]
1042
- )
1052
+ version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
1043
1053
  version_rid = [
1044
- h
1045
- for h in self.dataset_history(dataset_rid=dataset.rid)
1046
- if h.dataset_version == dataset.version
1054
+ h for h in self.dataset_history(dataset_rid=dataset.rid) if h.dataset_version == dataset.version
1047
1055
  ][0].version_rid
1048
1056
  version_path.update([{"RID": version_rid, "Minid": minid_page_url}])
1049
1057
  return minid_page_url
@@ -1051,10 +1059,10 @@ class Dataset:
1051
1059
  def _get_dataset_minid(
1052
1060
  self,
1053
1061
  dataset: DatasetSpec,
1054
- snapshot_catalog: Optional[DerivaML] = None,
1062
+ snapshot_catalog: DerivaML | None = None,
1055
1063
  create: bool = True,
1056
- ) -> DatasetMinid:
1057
- """Return a MINID to the specified dataset. If no version is specified, use the latest.
1064
+ ) -> DatasetMinid | None:
1065
+ """Return a MINID for the specified dataset. If no version is specified, use the latest.
1058
1066
 
1059
1067
  Args:
1060
1068
  dataset: Specification of the dataset.
@@ -1064,50 +1072,53 @@ class Dataset:
1064
1072
  Returns:
1065
1073
  New or existing MINID for the dataset.
1066
1074
  """
1067
- if dataset.rid.startswith("minid"):
1068
- minid_url = f"https://identifiers.org/{dataset.rid}"
1069
- elif dataset.rid.startswith("http"):
1070
- minid_url = dataset.rid
1071
- else:
1072
- if not any([dataset.rid == ds["RID"] for ds in self.find_datasets()]):
1073
- raise DerivaMLException(f"RID {dataset.rid} is not a dataset_table")
1074
-
1075
- # Get the history record for the version we are looking for.
1076
- dataset_version_record = [
1077
- v
1078
- for v in self.dataset_history(dataset.rid)
1079
- if v.dataset_version == str(dataset.version)
1080
- ][0]
1081
- if not dataset_version_record:
1082
- raise DerivaMLException(
1083
- f"Version {str(dataset.version)} does not exist for RID {dataset.rid}"
1084
- )
1085
- minid_url = dataset_version_record.minid
1086
- if not minid_url:
1087
- if not create:
1088
- raise DerivaMLException(
1089
- f"Minid for dataset {dataset.rid} doesn't exist"
1090
- )
1091
- if self._use_minid:
1092
- self._logger.info("Creating new MINID for dataset %s", dataset.rid)
1093
- minid_url = self._create_dataset_minid(dataset, snapshot_catalog)
1094
- # If provided a MINID, use the MINID metadata to get the checksum and download the bag.
1075
+ rid = dataset.rid
1076
+
1077
+ # Case 1: RID is already a MINID or direct URL
1078
+ if rid.startswith("minid"):
1079
+ return self._fetch_minid_metadata(f"https://identifiers.org/{rid}", dataset.version)
1080
+ if rid.startswith("http"):
1081
+ return self._fetch_minid_metadata(rid, dataset.version)
1082
+
1083
+ # Case 2: RID is a dataset RID; validate existence
1084
+ if not any(rid == ds["RID"] for ds in self.find_datasets()):
1085
+ raise DerivaMLTableTypeError("Dataset", rid)
1086
+
1087
+ # Find dataset version record
1088
+ version_str = str(dataset.version)
1089
+ history = self.dataset_history(rid)
1090
+ try:
1091
+ version_record = next(v for v in history if v.dataset_version == version_str)
1092
+ except StopIteration:
1093
+ raise DerivaMLException(f"Version {version_str} does not exist for RID {rid}")
1094
+
1095
+ # Check or create MINID
1096
+ minid_url = version_record.minid
1097
+ if not minid_url:
1098
+ if not create:
1099
+ raise DerivaMLException(f"Minid for dataset {rid} doesn't exist")
1095
1100
  if self._use_minid:
1096
- r = requests.get(minid_url, headers={"accept": "application/json"})
1097
- dataset_minid = DatasetMinid(
1098
- dataset_version=dataset.version, **r.json()
1099
- )
1100
- else:
1101
- dataset_minid = DatasetMinid(
1102
- dataset_version=dataset.version,
1103
- RID=f"{dataset.rid}@{dataset_version_record.snapshot}",
1104
- location=minid_url,
1105
- )
1106
- return dataset_minid
1101
+ self._logger.info("Creating new MINID for dataset %s", rid)
1102
+ minid_url = self._create_dataset_minid(dataset, snapshot_catalog)
1103
+
1104
+ # Return based on MINID usage
1105
+ if self._use_minid:
1106
+ return self._fetch_minid_metadata(minid_url, dataset.version)
1107
+
1108
+ return DatasetMinid(
1109
+ dataset_version=dataset.version,
1110
+ RID=f"{rid}@{version_record.snapshot}",
1111
+ location=minid_url,
1112
+ )
1113
+
1114
+ def _fetch_minid_metadata(self, url: str, version: DatasetVersion) -> DatasetMinid:
1115
+ r = requests.get(url, headers={"accept": "application/json"})
1116
+ r.raise_for_status()
1117
+ return DatasetMinid(dataset_version=version, **r.json())
1107
1118
 
1108
1119
  def _download_dataset_minid(self, minid: DatasetMinid) -> Path:
1109
- """Given a RID to a dataset_table, or a MINID to an existing bag, download the bag file, extract it, and validate
1110
- that all the metadata is correct
1120
+ """Given a RID to a dataset_table, or a MINID to an existing bag, download the bag file, extract it, and
1121
+ validate that all the metadata is correct
1111
1122
 
1112
1123
  Args:
1113
1124
  minid: The RID of a dataset_table or a minid to an existing bag.
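The refactor above replaces the inline `requests` call with a `_fetch_minid_metadata` helper that resolves a MINID through identifiers.org and builds a `DatasetMinid` from the JSON response. A minimal standalone sketch of that lookup, using only `requests`; the example identifier and the response fields read at the end are illustrative assumptions, not part of deriva-ml:

```python
import requests


def fetch_minid_record(identifier: str) -> dict:
    """Resolve a MINID (or a direct URL) and return its JSON landing-page metadata."""
    # identifiers.org redirects "minid:..." CURIEs to the hosting landing page.
    url = identifier if identifier.startswith("http") else f"https://identifiers.org/{identifier}"
    response = requests.get(url, headers={"accept": "application/json"}, timeout=30)
    response.raise_for_status()  # fail loudly instead of parsing an HTML error page
    return response.json()


# Hypothetical usage (identifier and fields are illustrative):
# record = fetch_minid_record("minid:b9XXXX")
# print(record.get("location"), record.get("checksums"))
```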
@@ -1119,9 +1130,7 @@ class Dataset:
         # it. If not, then we need to extract the contents of the archive into our cache directory.
         bag_dir = self._cache_dir / f"{minid.dataset_rid}_{minid.checksum}"
         if bag_dir.exists():
-            self._logger.info(
-                f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}"
-            )
+            self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
             return Path(bag_dir / f"Dataset_{minid.dataset_rid}")

         # Either bag hasn't been downloaded yet, or we are not using a Minid, so we don't know the checksum yet.
@@ -1130,19 +1139,13 @@ class Dataset:
             # Get bag from S3
             archive_path = fetch_single_file(minid.bag_url)
         else:
-            exporter = DerivaExport(
-                host=self._model.catalog.deriva_server.server, output_dir=tmp_dir
-            )
+            exporter = DerivaExport(host=self._model.catalog.deriva_server.server, output_dir=tmp_dir)
             archive_path = exporter.retrieve_file(minid.bag_url)
-        hashes = hash_utils.compute_file_hashes(
-            archive_path, hashes=["md5", "sha256"]
-        )
+        hashes = hash_utils.compute_file_hashes(archive_path, hashes=["md5", "sha256"])
         checksum = hashes["sha256"][0]
         bag_dir = self._cache_dir / f"{minid.dataset_rid}_{checksum}"
         if bag_dir.exists():
-            self._logger.info(
-                f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}"
-            )
+            self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
             return Path(bag_dir / f"Dataset_{minid.dataset_rid}")
         bag_path = bdb.extract_bag(archive_path, bag_dir.as_posix())
         bdb.validate_bag_structure(bag_path)
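`_download_dataset_minid` keys its cache directory on the archive's SHA-256 checksum, so a bag that has already been extracted for that exact content is reused. A stdlib-only sketch of the same caching scheme (the cache root and RID below are illustrative assumptions; deriva-ml itself uses `hash_utils.compute_file_hashes`):

```python
import hashlib
from pathlib import Path


def cached_bag_dir(cache_root: Path, dataset_rid: str, archive: Path) -> tuple[Path, bool]:
    """Return the checksum-keyed cache directory for an archive and whether it already exists."""
    digest = hashlib.sha256()
    with archive.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            digest.update(chunk)
    bag_dir = cache_root / f"{dataset_rid}_{digest.hexdigest()}"
    return bag_dir, bag_dir.exists()


# Hypothetical usage:
# bag_dir, cached = cached_bag_dir(Path.home() / ".deriva-cache", "1-ABCD", Path("Dataset_1-ABCD.zip"))
# if not cached:
#     ...extract the downloaded archive into bag_dir...
```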
@@ -1151,7 +1154,7 @@ class Dataset:
     def _materialize_dataset_bag(
         self,
         minid: DatasetMinid,
-        execution_rid: Optional[RID] = None,
+        execution_rid: RID | None = None,
     ) -> Path:
         """Materialize a dataset_table bag into a local directory

@@ -1165,9 +1168,7 @@ class Dataset:
         def update_status(status: Status, msg: str) -> None:
             """Update the current status for this execution in the catalog"""
             if execution_rid and execution_rid != DRY_RUN_RID:
-                self._model.catalog.getPathBuilder().schemas[
-                    self._ml_schema
-                ].Execution.update(
+                self._model.catalog.getPathBuilder().schemas[self._ml_schema].Execution.update(
                     [
                         {
                             "RID": execution_rid,
@@ -1197,9 +1198,7 @@ class Dataset:

         # If this bag has already been validated, our work is done. Otherwise, materialize the bag.
         if not validated_check.exists():
-            self._logger.info(
-                f"Materializing bag {minid.dataset_rid} Version:{minid.dataset_version}"
-            )
+            self._logger.info(f"Materializing bag {minid.dataset_rid} Version:{minid.dataset_version}")
             bdb.materialize(
                 bag_path.as_posix(),
                 fetch_callback=fetch_progress_callback,
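Materialization itself is delegated to bdbag, as the hunk above shows (`bdb` is the `bdbag_api` module). A short sketch of the extract / validate-structure / materialize sequence with default callbacks; the exact default fetch behavior is an assumption here:

```python
from bdbag import bdbag_api as bdb


def fetch_and_validate_bag(archive_path: str, output_dir: str) -> str:
    """Unpack a downloaded BDBag archive, then fetch its remote payload and validate it."""
    bag_path = bdb.extract_bag(archive_path, output_dir)  # unpack the .zip/.tgz archive
    bdb.validate_bag_structure(bag_path)                  # cheap structural check before fetching
    bdb.materialize(bag_path)                             # resolve fetch.txt entries and verify checksums
    return bag_path
```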
@@ -1210,7 +1209,7 @@ class Dataset:

     def _export_annotation(
         self,
-        snapshot_catalog: Optional[DerivaML] = None,
+        snapshot_catalog: DerivaML | None = None,
     ) -> list[dict[str, Any]]:
         """Return and output specification for the datasets in the provided model

@@ -1242,7 +1241,7 @@ class Dataset:
         )

     def _export_specification(
-        self, dataset: DatasetSpec, snapshot_catalog: Optional[DerivaML] = None
+        self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None
     ) -> list[dict[str, Any]]:
         """
         Generate a specification for export engine for specific dataset.
@@ -1258,14 +1257,10 @@ class Dataset:
                 "processor": "json",
                 "processor_params": {"query_path": "/schema", "output_path": "schema"},
             }
-        ] + self._dataset_specification(
-            self._export_specification_dataset_element, dataset, snapshot_catalog
-        )
+        ] + self._dataset_specification(self._export_specification_dataset_element, dataset, snapshot_catalog)

     @staticmethod
-    def _export_specification_dataset_element(
-        spath: str, dpath: str, table: Table
-    ) -> list[dict[str, Any]]:
+    def _export_specification_dataset_element(spath: str, dpath: str, table: Table) -> list[dict[str, Any]]:
         """Return the download specification for the data object indicated by a path through the data model.

         Args:
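For orientation, the export specification assembled by `_export_specification` is simply a list of processor entries; the hunk above shows its fixed first element, with per-table entries appended by `_dataset_specification`. A sketch of that shape (the list contents beyond the schema entry are produced elsewhere):

```python
# Fixed first entry of the export specification, mirroring the hunk above;
# subsequent entries describing dataset tables are appended by _dataset_specification().
export_spec: list[dict] = [
    {
        "processor": "json",
        "processor_params": {"query_path": "/schema", "output_path": "schema"},
    }
]
```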
@@ -1300,10 +1295,9 @@ class Dataset:
         )
         return exports

-    def _export_annotation_dataset_element(
-        self, spath: str, dpath: str, table: Table
-    ) -> list[dict[str, Any]]:
-        """Given a path in the data model, output an export specification for the path taken to get to the current table.
+    def _export_annotation_dataset_element(self, spath: str, dpath: str, table: Table) -> list[dict[str, Any]]:
+        """Given a path in the data model, output an export specification for the path taken to get to the
+        current table.

         Args:
             spath: Source path
@@ -1354,7 +1348,7 @@ class Dataset:
         return exports

     def _generate_dataset_download_spec(
-        self, dataset: DatasetSpec, snapshot_catalog: Optional[DerivaML] = None
+        self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None
     ) -> dict[str, Any]:
         """
         Generate a specification for downloading a specific dataset.
@@ -1457,9 +1451,7 @@ class Dataset:
             else {}
         )
         return {
-            deriva_tags.export_fragment_definitions: {
-                "dataset_export_outputs": self._export_annotation()
-            },
+            deriva_tags.export_fragment_definitions: {"dataset_export_outputs": self._export_annotation()},
             deriva_tags.visible_foreign_keys: self._dataset_visible_fkeys(),
             deriva_tags.export_2019: {
                 "detailed": {