deriva-ml 1.14.0__py3-none-any.whl → 1.14.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. deriva_ml/__init__.py +25 -30
  2. deriva_ml/core/__init__.py +39 -0
  3. deriva_ml/core/base.py +1489 -0
  4. deriva_ml/core/constants.py +36 -0
  5. deriva_ml/core/definitions.py +74 -0
  6. deriva_ml/core/enums.py +222 -0
  7. deriva_ml/core/ermrest.py +288 -0
  8. deriva_ml/core/exceptions.py +28 -0
  9. deriva_ml/core/filespec.py +116 -0
  10. deriva_ml/dataset/__init__.py +4 -0
  11. deriva_ml/{dataset_aux_classes.py → dataset/aux_classes.py} +16 -12
  12. deriva_ml/{dataset.py → dataset/dataset.py} +405 -428
  13. deriva_ml/{dataset_bag.py → dataset/dataset_bag.py} +137 -97
  14. deriva_ml/{history.py → dataset/history.py} +51 -33
  15. deriva_ml/{upload.py → dataset/upload.py} +48 -70
  16. deriva_ml/demo_catalog.py +233 -183
  17. deriva_ml/execution/environment.py +290 -0
  18. deriva_ml/{execution.py → execution/execution.py} +365 -252
  19. deriva_ml/execution/execution_configuration.py +163 -0
  20. deriva_ml/{execution_configuration.py → execution/workflow.py} +206 -218
  21. deriva_ml/feature.py +83 -46
  22. deriva_ml/model/__init__.py +0 -0
  23. deriva_ml/{deriva_model.py → model/catalog.py} +113 -132
  24. deriva_ml/{database_model.py → model/database.py} +52 -74
  25. deriva_ml/model/sql_mapper.py +44 -0
  26. deriva_ml/run_notebook.py +19 -11
  27. deriva_ml/schema/__init__.py +3 -0
  28. deriva_ml/{schema_setup → schema}/annotations.py +31 -22
  29. deriva_ml/schema/check_schema.py +104 -0
  30. deriva_ml/{schema_setup → schema}/create_schema.py +151 -104
  31. deriva_ml/schema/deriva-ml-reference.json +8525 -0
  32. deriva_ml/schema/table_comments_utils.py +57 -0
  33. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.26.dist-info}/METADATA +5 -4
  34. deriva_ml-1.14.26.dist-info/RECORD +40 -0
  35. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.26.dist-info}/entry_points.txt +1 -0
  36. deriva_ml/deriva_definitions.py +0 -391
  37. deriva_ml/deriva_ml_base.py +0 -1046
  38. deriva_ml/execution_environment.py +0 -139
  39. deriva_ml/schema_setup/table_comments_utils.py +0 -56
  40. deriva_ml/test-files/execution-parameters.json +0 -1
  41. deriva_ml/test-files/notebook-parameters.json +0 -5
  42. deriva_ml/test_functions.py +0 -141
  43. deriva_ml/test_notebook.ipynb +0 -197
  44. deriva_ml-1.14.0.dist-info/RECORD +0 -31
  45. /deriva_ml/{schema_setup → execution}/__init__.py +0 -0
  46. /deriva_ml/{schema_setup → schema}/policy.json +0 -0
  47. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.26.dist-info}/WHEEL +0 -0
  48. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.26.dist-info}/licenses/LICENSE +0 -0
  49. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.26.dist-info}/top_level.txt +0 -0
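The bulk of this release is a repackaging of the flat 1.14.0 module layout into `core`, `dataset`, `execution`, `model`, and `schema` subpackages. As a rough orientation before reading the diff below, here is a sketch of how import paths move between the two layouts. The old and new paths are taken from the removed and added import lines in the diff; the exact set of names re-exported from the package root (`deriva_ml/__init__.py`) is not shown here, so treat the fully qualified paths as illustrative rather than a complete migration guide.

# Sketch only: paths are taken from the import changes shown in the diff below.
#
# 1.14.0 layout (removed modules):
#   from deriva_ml.deriva_ml_base import DerivaML
#   from deriva_ml.deriva_definitions import ML_SCHEMA, MLVocab, Status, RID
#   from deriva_ml.dataset_aux_classes import DatasetSpec, DatasetVersion
#   from deriva_ml import DatasetBag

# 1.14.26 layout (new subpackages):
from deriva_ml.core.base import DerivaML
from deriva_ml.core.constants import RID
from deriva_ml.core.definitions import ML_SCHEMA, MLVocab, Status
from deriva_ml.core.exceptions import DerivaMLException
from deriva_ml.dataset.aux_classes import DatasetSpec, DatasetVersion
from deriva_ml.dataset.dataset_bag import DatasetBag
from deriva_ml.model.catalog import DerivaModel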
@@ -1,74 +1,110 @@
- """
- This module defines the DataSet class with is used to manipulate datasets in DerivaML.
- The intended use of this class is as a base class in DerivaML, so all the methods documented here are
- accessible via a DerivaML class instance.
-
+ """Dataset management for DerivaML.
+
+ This module provides functionality for managing datasets in DerivaML. A dataset represents a collection
+ of related data that can be versioned, downloaded, and tracked. The module includes:
+
+ - Dataset class: Core class for dataset operations
+ - Version management: Track and update dataset versions
+ - History tracking: Record dataset changes over time
+ - Download capabilities: Export datasets as BDBags
+ - Relationship management: Handle dataset dependencies and hierarchies
+
+ The Dataset class serves as a base class in DerivaML, making its methods accessible through
+ DerivaML class instances.
+
+ Typical usage example:
+ >>> ml = DerivaML('deriva.example.org', 'my_catalog')
+ >>> dataset_rid = ml.create_dataset('experiment', 'Experimental data')
+ >>> ml.add_dataset_members(dataset_rid=dataset_rid, members=['1-abc123', '1-def456'])
+ >>> ml.increment_dataset_version(datset_rid=dataset_rid, component=VersionPart.minor,
+ ... description='Added new samples')
  """

  from __future__ import annotations
- from bdbag import bdbag_api as bdb
- from bdbag.fetch.fetcher import fetch_single_file
- from collections import defaultdict
- from graphlib import TopologicalSorter
+
+ # Standard library imports
  import json
  import logging
+ from collections import defaultdict
+ from graphlib import TopologicalSorter
  from pathlib import Path
- from pydantic import (
- validate_call,
- ConfigDict,
- )
- import requests
  from tempfile import TemporaryDirectory
- from typing import Any, Callable, Optional, Iterable, Iterator, TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator

- from .history import iso_to_snap
- from deriva.core.ermrest_model import Table
- from deriva.core.utils.core_utils import tag as deriva_tags, format_exception
  import deriva.core.utils.hash_utils as hash_utils
- from deriva.transfer.download.deriva_export import DerivaExport
+ import requests
+
+ # Third-party imports
+ from bdbag import bdbag_api as bdb
+ from bdbag.fetch.fetcher import fetch_single_file
+
+ # Deriva imports
+ from deriva.core.ermrest_model import Table
+ from deriva.core.utils.core_utils import format_exception
+ from deriva.core.utils.core_utils import tag as deriva_tags
  from deriva.transfer.download.deriva_download import (
- DerivaDownloadConfigurationError,
- DerivaDownloadError,
  DerivaDownloadAuthenticationError,
  DerivaDownloadAuthorizationError,
+ DerivaDownloadConfigurationError,
+ DerivaDownloadError,
  DerivaDownloadTimeoutError,
  )
+ from deriva.transfer.download.deriva_export import DerivaExport
+ from pydantic import ConfigDict, validate_call

-
+ # Local imports
  try:
  from icecream import ic
+
+ ic.configureOutput(includeContext=True)
  except ImportError: # Graceful fallback if IceCream isn't installed.
  ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a) # noqa

- from deriva_ml import DatasetBag
- from .deriva_definitions import (
+ from deriva_ml.core.constants import RID
+ from deriva_ml.core.definitions import (
+ DRY_RUN_RID,
  ML_SCHEMA,
- DerivaMLException,
  MLVocab,
  Status,
- RID,
- DRY_RUN_RID,
  )
- from .deriva_model import DerivaModel
- from .database_model import DatabaseModel
- from .dataset_aux_classes import (
- DatasetVersion,
- DatasetMinid,
+ from deriva_ml.core.exceptions import DerivaMLException, DerivaMLTableTypeError
+ from deriva_ml.dataset.aux_classes import (
  DatasetHistory,
- VersionPart,
+ DatasetMinid,
  DatasetSpec,
+ DatasetVersion,
+ VersionPart,
  )
+ from deriva_ml.dataset.dataset_bag import DatasetBag
+ from deriva_ml.model.catalog import DerivaModel
+ from deriva_ml.model.database import DatabaseModel
+
+ from .history import iso_to_snap
+
+ # Stop pycharm from complaining about undefined reference in docstring....
+ ml: DerivaML

  if TYPE_CHECKING:
- from .deriva_ml_base import DerivaML
+ from deriva_ml.core.base import DerivaML


  class Dataset:
- """
- Class to manipulate a dataset.
+ """Manages dataset operations in a Deriva catalog.
+
+ The Dataset class provides functionality for creating, modifying, and tracking datasets
+ in a Deriva catalog. It handles versioning, relationships between datasets, and data export.

  Attributes:
- dataset_table (Table): ERMRest table holding dataset information.
+ dataset_table (Table): ERMrest table storing dataset information.
+ _model (DerivaModel): Catalog model instance.
+ _ml_schema (str): Schema name for ML-specific tables.
+ _cache_dir (Path): Directory for caching downloaded datasets.
+ _working_dir (Path): Directory for working data.
+ _use_minid (bool): Whether to use MINID service for dataset identification.
+
+ Note:
+ This class is typically used as a base class, with its methods accessed through
+ DerivaML class instances rather than directly.
  """

  _Logger = logging.getLogger("deriva_ml")
@@ -80,20 +116,31 @@ class Dataset:
  working_dir: Path,
  use_minid: bool = True,
  ):
+ """Initializes a Dataset instance.
+
+ Args:
+ model: DerivaModel instance representing the catalog.
+ cache_dir: Directory path for caching downloaded datasets.
+ working_dir: Directory path for working data.
+ use_minid: Whether to use MINID service for dataset identification.
+ """
  self._model = model
  self._ml_schema = ML_SCHEMA
- self.dataset_table = self._model.schemas[self._ml_schema].tables["Dataset"]
  self._cache_dir = cache_dir
  self._working_dir = working_dir
  self._logger = logging.getLogger("deriva_ml")
  self._use_minid = use_minid

+ @property
+ def _dataset_table(self):
+ return self._model.schemas[self._ml_schema].tables["Dataset"]
+
  def _is_dataset_rid(self, dataset_rid: RID, deleted: bool = False) -> bool:
  try:
  rid_info = self._model.catalog.resolve_rid(dataset_rid, self._model.model)
  except KeyError as _e:
  raise DerivaMLException(f"Invalid RID {dataset_rid}")
- if rid_info.table != self.dataset_table:
+ if rid_info.table != self._dataset_table:
  return False
  elif deleted:
  # Got a dataset rid. Now check to see if its deleted or not.
@@ -104,12 +151,12 @@ class Dataset:
  def _insert_dataset_versions(
  self,
  dataset_list: list[DatasetSpec],
- description: Optional[str] = "",
- execution_rid: Optional[RID] = None,
+ description: str | None = "",
+ execution_rid: RID | None = None,
  ) -> None:
  schema_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema]
  # determine snapshot after changes were made
- snap = self._model.catalog.get("/").json()["snaptime"]
+
  # Construct version records for insert
  version_records = schema_path.tables["Dataset_Version"].insert(
  [
@@ -118,16 +165,18 @@ class Dataset:
  "Version": str(dataset.version),
  "Description": description,
  "Execution": execution_rid,
- "Snapshot": snap,
  }
  for dataset in dataset_list
  ]
  )
+ version_records = list(version_records)
+ snap = self._model.catalog.get("/").json()["snaptime"]
+ schema_path.tables["Dataset_Version"].update(
+ [{"RID": v["RID"], "Dataset": v["Dataset"], "Snapshot": snap} for v in version_records]
+ )

  # And update the dataset records.
- schema_path.tables["Dataset"].update(
- [{"Version": v["RID"], "RID": v["Dataset"]} for v in version_records]
- )
+ schema_path.tables["Dataset"].update([{"Version": v["RID"], "RID": v["Dataset"]} for v in version_records])

  def _bootstrap_versions(self):
  datasets = [ds["RID"] for ds in self.find_datasets()]
@@ -143,9 +192,7 @@ class Dataset:
  version_path = schema_path.tables["Dataset_Version"]
  dataset_path = schema_path.tables["Dataset"]
  history = list(version_path.insert(ds_version))
- dataset_versions = [
- {"RID": h["Dataset"], "Version": h["Version"]} for h in history
- ]
+ dataset_versions = [{"RID": h["Dataset"], "Version": h["Version"]} for h in history]
  dataset_path.update(dataset_versions)

  def _synchronize_dataset_versions(self):
@@ -161,45 +208,46 @@ class Dataset:
  versions[v["Dataset"]] = v
  dataset_path = schema_path.tables["Dataset"]

- dataset_path.update(
- [
- {"RID": dataset, "Version": version["RID"]}
- for dataset, version in versions.items()
- ]
- )
+ dataset_path.update([{"RID": dataset, "Version": version["RID"]} for dataset, version in versions.items()])

  def _set_version_snapshot(self):
- dataset_version_path = (
- self._model.catalog.getPathBuilder()
- .schemas[self._ml_schema]
- .tables["Dataset_Version"]
- )
+ dataset_version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
  versions = dataset_version_path.entities().fetch()
  dataset_version_path.update(
- [
- {"RID": h["RID"], "Snapshot": iso_to_snap(h["RCT"])}
- for h in versions
- if not h["Snapshot"]
- ]
+ [{"RID": h["RID"], "Snapshot": iso_to_snap(h["RCT"])} for h in versions if not h["Snapshot"]]
  )

  def dataset_history(self, dataset_rid: RID) -> list[DatasetHistory]:
- """Return a list of DatasetHistory objects representing the dataset
+ """Retrieves the version history of a dataset.
+
+ Returns a chronological list of dataset versions, including their version numbers,
+ creation times, and associated metadata.

  Args:
- dataset_rid: A RID to the dataset for which history is to be fetched.
+ dataset_rid: Resource Identifier of the dataset.

  Returns:
- A list of DatasetHistory objects which indicate the version-number, creation time, and bag instantiation of the dataset.
+ list[DatasetHistory]: List of history entries, each containing:
+ - dataset_version: Version number (major.minor.patch)
+ - minid: Minimal Viable Identifier
+ - snapshot: Catalog snapshot time
+ - dataset_rid: Dataset Resource Identifier
+ - version_rid: Version Resource Identifier
+ - description: Version description
+ - execution_rid: Associated execution RID
+
+ Raises:
+ DerivaMLException: If dataset_rid is not a valid dataset RID.
+
+ Example:
+ >>> history = ml.dataset_history("1-abc123")
+ >>> for entry in history:
+ ... print(f"Version {entry.dataset_version}: {entry.description}")
  """

  if not self._is_dataset_rid(dataset_rid):
  raise DerivaMLException(f"RID is not for a data set: {dataset_rid}")
- version_path = (
- self._model.catalog.getPathBuilder()
- .schemas[self._ml_schema]
- .tables["Dataset_Version"]
- )
+ version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
  return [
  DatasetHistory(
  dataset_version=DatasetVersion.parse(v["Version"]),
@@ -210,9 +258,7 @@ class Dataset:
  description=v["Description"],
  execution_rid=v["Execution"],
  )
- for v in version_path.filter(version_path.Dataset == dataset_rid)
- .entities()
- .fetch()
+ for v in version_path.filter(version_path.Dataset == dataset_rid).entities().fetch()
  ]

  @validate_call
@@ -234,14 +280,16 @@ class Dataset:
  if not history:
  return DatasetVersion(0, 1, 0)
  else:
- return max([h.dataset_version for h in self.dataset_history(dataset_rid)])
+ # Ensure we return a DatasetVersion, not a string
+ versions = [h.dataset_version for h in history]
+ return max(versions) if versions else DatasetVersion(0, 1, 0)

  def _build_dataset_graph(self, dataset_rid: RID) -> Iterable[RID]:
- ts = TopologicalSorter()
+ ts: TopologicalSorter = TopologicalSorter()
  self._build_dataset_graph_1(dataset_rid, ts, set())
  return ts.static_order()

- def _build_dataset_graph_1(self, dataset_rid: RID, ts, visited) -> None:
+ def _build_dataset_graph_1(self, dataset_rid: RID, ts: TopologicalSorter, visited) -> None:
  """Use topological sort to return bottom up list of nested datasets"""
  ts.add(dataset_rid)
  if dataset_rid not in visited:
@@ -249,7 +297,8 @@ class Dataset:
  children = self.list_dataset_children(dataset_rid=dataset_rid)
  parents = self.list_dataset_parents(dataset_rid=dataset_rid)
  for parent in parents:
- self._build_dataset_graph_1(parent, ts, visited)
+ # Convert string to RID type
+ self._build_dataset_graph_1(RID(parent), ts, visited)
  for child in children:
  self._build_dataset_graph_1(child, ts, visited)

@@ -258,22 +307,34 @@ class Dataset:
  self,
  dataset_rid: RID,
  component: VersionPart,
- description: Optional[str] = "",
- execution_rid: Optional[RID] = None,
+ description: str | None = "",
+ execution_rid: RID | None = None,
  ) -> DatasetVersion:
- """Increment the version of the specified dataset_table.
+ """Increments a dataset's version number.
+
+ Creates a new version of the dataset by incrementing the specified version component
+ (major, minor, or patch). The new version is recorded with an optional description
+ and execution reference.

  Args:
- dataset_rid: RID of the dataset whose version is to be incremented.
- component: Which version of the dataset_table to increment. Major, Minor, or Patch
- description: Description of the version update of the dataset_table.
- execution_rid: Which execution is performing increment.
+ dataset_rid: Resource Identifier of the dataset to version.
+ component: Which version component to increment ('major', 'minor', or 'patch').
+ description: Optional description of the changes in this version.
+ execution_rid: Optional execution RID to associate with this version.

  Returns:
- new semantic version of the dataset_table as a 3-tuple
+ DatasetVersion: The new version number.

  Raises:
- DerivaMLException: if provided, RID is not to a dataset_table.
+ DerivaMLException: If dataset_rid is invalid or version increment fails.
+
+ Example:
+ >>> new_version = ml.increment_dataset_version(
+ ... dataset_rid="1-abc123",
+ ... component="minor",
+ ... description="Added new samples"
+ ... )
+ >>> print(f"New version: {new_version}") # e.g., "1.2.0"
  """

  # Find all the datasets that are reachable from this dataset and determine their new version numbers.
@@ -285,46 +346,51 @@ class Dataset:
  )
  for ds_rid in related_datasets
  ]
- self._insert_dataset_versions(
- version_update_list, description=description, execution_rid=execution_rid
- )
- return [d.version for d in version_update_list if d.rid == dataset_rid][0]
+ self._insert_dataset_versions(version_update_list, description=description, execution_rid=execution_rid)
+ return next((d.version for d in version_update_list if d.rid == dataset_rid))

  @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
  def create_dataset(
  self,
- dataset_types: str | list[str],
- description: str,
- execution_rid: Optional[RID] = None,
- version: Optional[DatasetVersion] = None,
+ dataset_types: str | list[str] | None = None,
+ description: str = "",
+ execution_rid: RID | None = None,
+ version: DatasetVersion | None = None,
  ) -> RID:
- """Create a new dataset_table from the specified list of RIDs.
+ """Creates a new dataset in the catalog.
+
+ Creates a dataset with specified types and description. The dataset can be associated
+ with an execution and initialized with a specific version.

  Args:
- dataset_types: One or more dataset_table types. Must be a term from the DatasetType controlled vocabulary.
- description: Description of the dataset_table.
- execution_rid: Execution under which the dataset_table will be created.
- version: Version of the dataset_table.
+ dataset_types: One or more dataset type terms from Dataset_Type vocabulary.
+ description: Description of the dataset's purpose and contents.
+ execution_rid: Optional execution RID to associate with dataset creation.
+ version: Optional initial version number. Defaults to 0.1.0.

  Returns:
- New dataset_table RID.
+ RID: Resource Identifier of the newly created dataset.

+ Raises:
+ DerivaMLException: If dataset_types are invalid or creation fails.
+
+ Example:
+ >>> rid = ml.create_dataset(
+ ... dataset_types=["experiment", "raw_data"],
+ ... description="RNA sequencing experiment data",
+ ... version=DatasetVersion(1, 0, 0)
+ ... )
  """

  version = version or DatasetVersion(0, 1, 0)
+ dataset_types = dataset_types or []

- type_path = (
- self._model.catalog.getPathBuilder()
- .schemas[self._ml_schema]
- .tables[MLVocab.dataset_type.value]
- )
+ type_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables[MLVocab.dataset_type.value]
  defined_types = list(type_path.entities().fetch())

  def check_dataset_type(dtype: str) -> bool:
  for term in defined_types:
- if dtype == term["Name"] or (
- term["Synonyms"] and ds_type in term["Synonyms"]
- ):
+ if dtype == term["Name"] or (term["Synonyms"] and ds_type in term["Synonyms"]):
  return True
  return False

@@ -334,9 +400,7 @@ class Dataset:
  for ds_type in ds_types:
  if not check_dataset_type(ds_type):
  raise DerivaMLException("Dataset type must be a vocabulary term.")
- dataset_table_path = pb.schemas[self.dataset_table.schema.name].tables[
- self.dataset_table.name
- ]
+ dataset_table_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
  dataset_rid = dataset_table_path.insert(
  [
  {
@@ -347,21 +411,13 @@ class Dataset:
  )[0]["RID"]

  # Get the name of the association table between dataset_table and dataset_type.
- atable = next(
- self._model.schemas[self._ml_schema]
- .tables[MLVocab.dataset_type]
- .find_associations()
- ).name
+ associations = list(self._model.schemas[self._ml_schema].tables[MLVocab.dataset_type].find_associations())
+ atable = associations[0].name if associations else None
  pb.schemas[self._ml_schema].tables[atable].insert(
- [
- {MLVocab.dataset_type: ds_type, "Dataset": dataset_rid}
- for ds_type in ds_types
- ]
+ [{MLVocab.dataset_type: ds_type, "Dataset": dataset_rid} for ds_type in ds_types]
  )
  if execution_rid is not None:
- pb.schemas[self._ml_schema].Dataset_Execution.insert(
- [{"Dataset": dataset_rid, "Execution": execution_rid}]
- )
+ pb.schemas[self._ml_schema].Dataset_Execution.insert([{"Dataset": dataset_rid, "Execution": execution_rid}])
  self._insert_dataset_versions(
  [DatasetSpec(rid=dataset_rid, version=version)],
  execution_rid=execution_rid,
@@ -383,18 +439,12 @@ class Dataset:
  raise DerivaMLException("Dataset_rid is not a dataset.")

  if parents := self.list_dataset_parents(dataset_rid):
- raise DerivaMLException(
- f'Dataset_rid "{dataset_rid}" is in a nested dataset: {parents}.'
- )
+ raise DerivaMLException(f'Dataset_rid "{dataset_rid}" is in a nested dataset: {parents}.')

  pb = self._model.catalog.getPathBuilder()
- dataset_path = pb.schemas[self.dataset_table.schema.name].tables[
- self.dataset_table.name
- ]
+ dataset_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]

- rid_list = [dataset_rid] + (
- self.list_dataset_children(dataset_rid) if recurse else []
- )
+ rid_list = [dataset_rid] + (self.list_dataset_children(dataset_rid=dataset_rid) if recurse else [])
  dataset_path.update([{"RID": r, "Deleted": True} for r in rid_list])

  def find_datasets(self, deleted: bool = False) -> Iterable[dict[str, Any]]:
@@ -408,14 +458,9 @@ class Dataset:
  """
  # Get datapath to all the tables we will need: Dataset, DatasetType and the association table.
  pb = self._model.catalog.getPathBuilder()
- dataset_path = pb.schemas[self.dataset_table.schema.name].tables[
- self.dataset_table.name
- ]
- atable = next(
- self._model.schemas[self._ml_schema]
- .tables[MLVocab.dataset_type]
- .find_associations()
- ).name
+ dataset_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
+ associations = list(self._model.schemas[self._ml_schema].tables[MLVocab.dataset_type].find_associations())
+ atable = associations[0].name if associations else None
  ml_path = pb.schemas[self._ml_schema]
  atable_path = ml_path.tables[atable]

@@ -423,21 +468,16 @@ class Dataset:
  filtered_path = dataset_path
  else:
  filtered_path = dataset_path.filter(
- (dataset_path.Deleted == False) | (dataset_path.Deleted == None) # noqa: E712
+ (dataset_path.Deleted == False) | (dataset_path.Deleted == None) # noqa: E711, E712
  )

  # Get a list of all the dataset_type values associated with this dataset_table.
  datasets = []
  for dataset in filtered_path.entities().fetch():
  ds_types = (
- atable_path.filter(atable_path.Dataset == dataset["RID"])
- .attributes(atable_path.Dataset_Type)
- .fetch()
- )
- datasets.append(
- dataset
- | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in ds_types]}
+ atable_path.filter(atable_path.Dataset == dataset["RID"]).attributes(atable_path.Dataset_Type).fetch()
  )
+ datasets.append(dataset | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in ds_types]})
  return datasets

  def list_dataset_element_types(self) -> Iterable[Table]:
@@ -448,16 +488,9 @@ class Dataset:
  """

  def domain_table(table: Table) -> bool:
- return (
- table.schema.name == self._model.domain_schema
- or table.name == self.dataset_table.name
- )
+ return table.schema.name == self._model.domain_schema or table.name == self._dataset_table.name

- return [
- t
- for a in self.dataset_table.find_associations()
- if domain_table(t := a.other_fkeys.pop().pk_table)
- ]
+ return [t for a in self._dataset_table.find_associations() if domain_table(t := a.other_fkeys.pop().pk_table)]

  @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
  def add_dataset_element_type(self, element: str | Table) -> Table:
@@ -472,31 +505,45 @@ class Dataset:
  """
  # Add table to map
  element_table = self._model.name_to_table(element)
- table = self._model.schemas[self._model.domain_schema].create_table(
- Table.define_association([self.dataset_table, element_table])
- )
+ atable_def = Table.define_association([self._dataset_table, element_table])
+ try:
+ table = self._model.schemas[self._model.domain_schema].create_table(atable_def)
+ except ValueError as e:
+ if "already exists" in str(e):
+ table = self._model.name_to_table(atable_def["table_name"])
+ else:
+ raise e

  # self.model = self.catalog.getCatalogModel()
- self.dataset_table.annotations.update(
- self._generate_dataset_download_annotations()
- )
+ self._dataset_table.annotations.update(self._generate_dataset_download_annotations())
  self._model.model.apply()
  return table

  # @validate_call
  def list_dataset_members(
- self, dataset_rid: RID, recurse: bool = False, limit: Optional[int] = None
+ self, dataset_rid: RID, recurse: bool = False, limit: int | None = None
  ) -> dict[str, list[dict[str, Any]]]:
- """Return a list of entities associated with a specific dataset_table.
+ """Lists members of a dataset.
+
+ Returns a dictionary mapping member types to lists of member records. Can optionally
+ recurse through nested datasets and limit the number of results.

  Args:
- dataset_rid: param recurse: If this is a nested dataset_table, list the members of the contained datasets
- recurse: (Default value = False)
- limit: If provided, the maximum number of members to return for each element type.
+ dataset_rid: Resource Identifier of the dataset.
+ recurse: Whether to include members of nested datasets. Defaults to False.
+ limit: Maximum number of members to return per type. None for no limit.

  Returns:
- Dictionary of entities associated with a specific dataset_table. Key is the table from which the elements
- were taken.
+ dict[str, list[dict[str, Any]]]: Dictionary mapping member types to lists of members.
+ Each member is a dictionary containing the record's attributes.
+
+ Raises:
+ DerivaMLException: If dataset_rid is invalid.
+
+ Example:
+ >>> members = ml.list_dataset_members("1-abc123", recurse=True)
+ >>> for type_name, records in members.items():
+ ... print(f"{type_name}: {len(records)} records")
  """

  if not self._is_dataset_rid(dataset_rid):
@@ -506,21 +553,18 @@ class Dataset:
  # the appropriate association table.
  members = defaultdict(list)
  pb = self._model.catalog.getPathBuilder()
- for assoc_table in self.dataset_table.find_associations():
+ for assoc_table in self._dataset_table.find_associations():
  other_fkey = assoc_table.other_fkeys.pop()
  target_table = other_fkey.pk_table
  member_table = assoc_table.table

  # Look at domain tables and nested datasets.
- if (
- target_table.schema.name != self._model.domain_schema
- and target_table != self.dataset_table
+ if target_table.schema.name != self._model.domain_schema and not (
+ target_table == self._dataset_table or target_table.name == "File"
  ):
  continue
  member_column = (
- "Nested_Dataset"
- if target_table == self.dataset_table
- else other_fkey.foreign_key_columns[0].name
+ "Nested_Dataset" if target_table == self._dataset_table else other_fkey.foreign_key_columns[0].name
  )

  target_path = pb.schemas[target_table.schema.name].tables[target_table.name]
@@ -530,15 +574,13 @@ class Dataset:
  target_path,
  on=(member_path.columns[member_column] == target_path.columns["RID"]),
  )
- target_entities = list(
- path.entities().fetch(limit=limit) if limit else path.entities().fetch()
- )
+ target_entities = list(path.entities().fetch(limit=limit) if limit else path.entities().fetch())
  members[target_table.name].extend(target_entities)
- if recurse and target_table == self.dataset_table:
+ if recurse and target_table == self._dataset_table:
  # Get the members for all the nested datasets and add to the member list.
  nested_datasets = [d["RID"] for d in target_entities]
  for ds in nested_datasets:
- for k, v in self.list_dataset_members(ds, recurse=False).items():
+ for k, v in self.list_dataset_members(ds, recurse=recurse).items():
  members[k].extend(v)
  return dict(members)

@@ -546,24 +588,38 @@ class Dataset:
  def add_dataset_members(
  self,
  dataset_rid: RID,
- members: list[RID],
+ members: list[RID] | dict[str, list[RID]],
  validate: bool = True,
- description: Optional[str] = "",
- execution_rid: Optional[RID] = None,
+ description: str | None = "",
+ execution_rid: RID | None = None,
  ) -> None:
- """Add additional elements to an existing dataset_table.
+ """Adds members to a dataset.

- Add new elements to an existing dataset. In addition to adding new members, the minor version number of the
- dataset is incremented and the description, if provide is applied to that new version.
+ Associates one or more records with a dataset. Can optionally validate member types
+ and create a new dataset version to track the changes.

  Args:
- dataset_rid: RID of dataset_table to extend or None if a new dataset_table is to be created.
- members: List of member RIDs to add to the dataset_table.
- validate: Check rid_list to make sure elements are not already in the dataset_table.
- description: Markdown description of the updated dataset.
- execution_rid: Optional RID of execution associated with this dataset.
+ dataset_rid: Resource Identifier of the dataset.
+ members: List of RIDs to add as dataset members. Can be orginized into a dictionary that indicates the
+ table that the member rids belong to.
+ validate: Whether to validate member types. Defaults to True.
+ description: Optional description of the member additions.
+ execution_rid: Optional execution RID to associate with changes.
+
+ Raises:
+ DerivaMLException: If:
+ - dataset_rid is invalid
+ - members are invalid or of wrong type
+ - adding members would create a cycle
+ - validation fails
+
+ Example:
+ >>> ml.add_dataset_members(
+ ... dataset_rid="1-abc123",
+ ... members=["1-def456", "1-ghi789"],
+ ... description="Added sample data"
+ ... )
  """
- members = set(members)
  description = description or "Updated dataset via add_dataset_members"

  def check_dataset_cycle(member_rid, path=None):
@@ -580,43 +636,37 @@ class Dataset:
  return member_rid in path

  if validate:
- existing_rids = set(
- m["RID"]
- for ms in self.list_dataset_members(dataset_rid).values()
- for m in ms
- )
+ existing_rids = set(m["RID"] for ms in self.list_dataset_members(dataset_rid).values() for m in ms)
  if overlap := set(existing_rids).intersection(members):
- raise DerivaMLException(
- f"Attempting to add existing member to dataset_table {dataset_rid}: {overlap}"
- )
+ raise DerivaMLException(f"Attempting to add existing member to dataset_table {dataset_rid}: {overlap}")

  # Now go through every rid to be added to the data set and sort them based on what association table entries
  # need to be made.
  dataset_elements = {}
  association_map = {
- a.other_fkeys.pop().pk_table.name: a.table.name
- for a in self.dataset_table.find_associations()
+ a.other_fkeys.pop().pk_table.name: a.table.name for a in self._dataset_table.find_associations()
  }
+
  # Get a list of all the object types that can be linked to a dataset_table.
- for m in members:
- try:
- rid_info = self._model.catalog.resolve_rid(m)
- except KeyError:
- raise DerivaMLException(f"Invalid RID: {m}")
- if rid_info.table.name not in association_map:
- raise DerivaMLException(
- f"RID table: {rid_info.table.name} not part of dataset_table"
- )
- if rid_info.table == self.dataset_table and check_dataset_cycle(
- rid_info.rid
- ):
- raise DerivaMLException("Creating cycle of datasets is not allowed")
- dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
+ if type(members) is list:
+ members = set(members)
+ for m in members:
+ try:
+ rid_info = self._model.catalog.resolve_rid(m)
+ except KeyError:
+ raise DerivaMLException(f"Invalid RID: {m}")
+ if rid_info.table.name not in association_map:
+ raise DerivaMLException(f"RID table: {rid_info.table.name} not part of dataset_table")
+ if rid_info.table == self._dataset_table and check_dataset_cycle(rid_info.rid):
+ raise DerivaMLException("Creating cycle of datasets is not allowed")
+ dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
+ else:
+ dataset_elements = {t: set(ms) for t, ms in members.items()}
  # Now make the entries into the association tables.
  pb = self._model.catalog.getPathBuilder()
  for table, elements in dataset_elements.items():
  schema_path = pb.schemas[
- self._ml_schema if table == "Dataset" else self._model.domain_schema
+ self._ml_schema if (table == "Dataset" or table == "File") else self._model.domain_schema
  ]
  fk_column = "Nested_Dataset" if table == "Dataset" else table
  if len(elements):
@@ -637,7 +687,7 @@ class Dataset:
  dataset_rid: RID,
  members: list[RID],
  description: str = "",
- execution_rid: Optional[RID] = None,
+ execution_rid: RID | None = None,
  ) -> None:
  """Remove elements to an existing dataset_table.

@@ -658,8 +708,7 @@ class Dataset:
  # need to be made.
  dataset_elements = {}
  association_map = {
- a.other_fkeys.pop().pk_table.name: a.table.name
- for a in self.dataset_table.find_associations()
+ a.other_fkeys.pop().pk_table.name: a.table.name for a in self._dataset_table.find_associations()
  }
  # Get a list of all the object types that can be linked to a dataset_table.
  for m in members:
@@ -668,16 +717,12 @@ class Dataset:
  except KeyError:
  raise DerivaMLException(f"Invalid RID: {m}")
  if rid_info.table.name not in association_map:
- raise DerivaMLException(
- f"RID table: {rid_info.table.name} not part of dataset_table"
- )
+ raise DerivaMLException(f"RID table: {rid_info.table.name} not part of dataset_table")
  dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
  # Now make the entries into the association tables.
  pb = self._model.catalog.getPathBuilder()
  for table, elements in dataset_elements.items():
- schema_path = pb.schemas[
- self._ml_schema if table == "Dataset" else self._model.domain_schema
- ]
+ schema_path = pb.schemas[self._ml_schema if table == "Dataset" else self._model.domain_schema]
  fk_column = "Nested_Dataset" if table == "Dataset" else table

  if len(elements):
@@ -685,8 +730,7 @@ class Dataset:
  # Find out the name of the column in the association table.
  for e in elements:
  entity = atable_path.filter(
- (atable_path.Dataset == dataset_rid)
- & (atable_path.columns[fk_column] == e),
+ (atable_path.Dataset == dataset_rid) & (atable_path.columns[fk_column] == e),
  )
  entity.delete()
  self.increment_dataset_version(
@@ -708,21 +752,14 @@ class Dataset:
  RID of the parent dataset_table.
  """
  if not self._is_dataset_rid(dataset_rid):
- raise DerivaMLException(
- f"RID: {dataset_rid} does not belong to dataset_table {self.dataset_table.name}"
- )
+ raise DerivaMLException(f"RID: {dataset_rid} does not belong to dataset_table {self._dataset_table.name}")
  # Get association table for nested datasets
  pb = self._model.catalog.getPathBuilder()
  atable_path = pb.schemas[self._ml_schema].Dataset_Dataset
- return [
- p["Dataset"]
- for p in atable_path.filter(atable_path.Nested_Dataset == dataset_rid)
- .entities()
- .fetch()
- ]
+ return [p["Dataset"] for p in atable_path.filter(atable_path.Nested_Dataset == dataset_rid).entities().fetch()]

  @validate_call
- def list_dataset_children(self, dataset_rid: RID, recurse=False) -> list[RID]:
+ def list_dataset_children(self, dataset_rid: RID, recurse: bool = False) -> list[RID]:
  """Given a dataset_table RID, return a list of RIDs for any nested datasets.

  Args:
@@ -733,19 +770,11 @@ class Dataset:
  list of nested dataset RIDs.

  """
- dataset_dataset_path = (
- self._model.catalog.getPathBuilder()
- .schemas[self._ml_schema]
- .tables["Dataset_Dataset"]
- )
+ dataset_dataset_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Dataset"]
  nested_datasets = list(dataset_dataset_path.entities().fetch())

  def find_children(rid: RID):
- children = [
- child["Nested_Dataset"]
- for child in nested_datasets
- if child["Dataset"] == rid
- ]
+ children = [child["Nested_Dataset"] for child in nested_datasets if child["Dataset"] == rid]
  if recurse:
  for child in children.copy():
  children.extend(find_children(child))
@@ -753,9 +782,7 @@ class Dataset:

  return find_children(dataset_rid)

- def _export_vocabulary(
- self, writer: Callable[[str, str, Table], list[dict[str, Any]]]
- ) -> list[dict[str, Any]]:
+ def _export_vocabulary(self, writer: Callable[[str, str, Table], list[dict[str, Any]]]) -> list[dict[str, Any]]:
  """

  Args:
@@ -770,16 +797,12 @@ class Dataset:
  for table in s.tables.values()
  if self._model.is_vocabulary(table)
  ]
- return [
- o
- for table in vocabs
- for o in writer(f"{table.schema.name}:{table.name}", table.name, table)
- ]
+ return [o for table in vocabs for o in writer(f"{table.schema.name}:{table.name}", table.name, table)]

  def _table_paths(
  self,
- dataset: Optional[DatasetSpec] = None,
- snapshot_catalog: Optional[DerivaML] = None,
+ dataset: DatasetSpec | None = None,
+ snapshot_catalog: DerivaML | None = None,
  ) -> Iterator[tuple[str, str, Table]]:
  paths = self._collect_paths(dataset and dataset.rid, snapshot_catalog)

@@ -805,25 +828,20 @@ class Dataset:

  def _collect_paths(
  self,
- dataset_rid: Optional[RID] = None,
- snapshot: Optional[Dataset] = None,
- dataset_nesting_depth: Optional[int] = None,
+ dataset_rid: RID | None = None,
+ snapshot: Dataset | None = None,
+ dataset_nesting_depth: int | None = None,
  ) -> set[tuple[Table, ...]]:
  snapshot_catalog = snapshot if snapshot else self

- dataset_table = snapshot_catalog._model.schemas[self._ml_schema].tables[
- "Dataset"
- ]
- dataset_dataset = snapshot_catalog._model.schemas[self._ml_schema].tables[
- "Dataset_Dataset"
- ]
+ dataset_table = snapshot_catalog._model.schemas[self._ml_schema].tables["Dataset"]
+ dataset_dataset = snapshot_catalog._model.schemas[self._ml_schema].tables["Dataset_Dataset"]

  # Figure out what types of elements the dataset contains.
  dataset_associations = [
  a
- for a in self.dataset_table.find_associations()
- if a.table.schema.name != self._ml_schema
- or a.table.name == "Dataset_Dataset"
+ for a in self._dataset_table.find_associations()
+ if a.table.schema.name != self._ml_schema or a.table.name == "Dataset_Dataset"
  ]
  if dataset_rid:
  # Get a list of the members of the dataset so we can figure out which tables to query.
@@ -835,9 +853,7 @@ class Dataset:
  if m
  ]
  included_associations = [
- a.table
- for a in dataset_table.find_associations()
- if a.other_fkeys.pop().pk_table in dataset_elements
+ a.table for a in dataset_table.find_associations() if a.other_fkeys.pop().pk_table in dataset_elements
  ]
  else:
  included_associations = dataset_associations
@@ -848,9 +864,7 @@ class Dataset:
  for p in snapshot_catalog._model._schema_to_paths()
  if (len(p) == 1)
  or (p[1] not in dataset_associations) # Tables in the domain schema
- or (
- p[1] in included_associations
- ) # Tables that include members of the dataset
+ or (p[1] in included_associations) # Tables that include members of the dataset
  }
  # Now get paths for nested datasets
  nested_paths = set()
@@ -860,56 +874,42 @@ class Dataset:
  else:
  # Initialize nesting depth if not already provided.
  dataset_nesting_depth = (
- self._dataset_nesting_depth()
- if dataset_nesting_depth is None
- else dataset_nesting_depth
+ self._dataset_nesting_depth() if dataset_nesting_depth is None else dataset_nesting_depth
  )
  if dataset_nesting_depth:
- nested_paths = self._collect_paths(
- dataset_nesting_depth=dataset_nesting_depth - 1
- )
+ nested_paths = self._collect_paths(dataset_nesting_depth=dataset_nesting_depth - 1)
  if nested_paths:
  paths |= {
  tuple([dataset_table]),
  (dataset_table, dataset_dataset),
  }
- paths |= {(self.dataset_table, dataset_dataset) + p for p in nested_paths}
+ paths |= {(self._dataset_table, dataset_dataset) + p for p in nested_paths}
  return paths

- def _dataset_nesting_depth(self, dataset_rid: Optional[RID] = None) -> int:
+ def _dataset_nesting_depth(self, dataset_rid: RID | None = None) -> int:
  """Determine the maximum dataset nesting depth in the current catalog.

  Returns:

  """

- def children_depth(
- dataset_rid: RID, nested_datasets: dict[str, list[str]]
- ) -> int:
+ def children_depth(dataset_rid: RID, nested_datasets: dict[str, list[str]]) -> int:
  """Return the number of nested datasets for the dataset_rid if provided, otherwise in the current catalog"""
  try:
  children = nested_datasets[dataset_rid]
- return (
- max(map(lambda x: children_depth(x, nested_datasets), children)) + 1
- if children
- else 1
- )
+ return max(map(lambda x: children_depth(x, nested_datasets), children)) + 1 if children else 1
  except KeyError:
  return 0

  # Build up the dataset_table nesting graph...
- pb = (
- self._model.catalog.getPathBuilder()
- .schemas[self._ml_schema]
- .tables["Dataset_Dataset"]
- )
+ pb = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Dataset"]
  dataset_children = (
  [
  {
  "Dataset": dataset_rid,
  "Nested_Dataset": c,
  } # Make uniform with return from datapath
- for c in self.list_dataset_children(dataset_rid)
+ for c in self.list_dataset_children(dataset_rid=dataset_rid)
  ]
  if dataset_rid
  else pb.entities().fetch()
@@ -917,30 +917,29 @@ class Dataset:
  nested_dataset = defaultdict(list)
  for ds in dataset_children:
  nested_dataset[ds["Dataset"]].append(ds["Nested_Dataset"])
- return (
- max(map(lambda d: children_depth(d, dict(nested_dataset)), nested_dataset))
- if nested_dataset
- else 0
- )
+ return max(map(lambda d: children_depth(d, dict(nested_dataset)), nested_dataset)) if nested_dataset else 0

  def _dataset_specification(
  self,
  writer: Callable[[str, str, Table], list[dict[str, Any]]],
- dataset: Optional[DatasetSpec] = None,
- snapshot_catalog: Optional[DerivaML] = None,
+ dataset: DatasetSpec | None = None,
+ snapshot_catalog: DerivaML | None = None,
  ) -> list[dict[str, Any]]:
- """Output a download/export specification for a dataset_table. Each element of the dataset_table will be placed in its own dir
- The top level data directory of the resulting BDBag will have one subdirectory for element type. The subdirectory
- will contain the CSV indicating which elements of that type are present in the dataset_table, and then there will be a
- subdirectory for each object that is reachable from the dataset_table members.
-
- To simplify reconstructing the relationship between tables, the CVS for each
- The top level data directory will also contain a subdirectory for any controlled vocabularies used in the dataset_table.
- All assets will be placed into a directory named asset in a subdirectory with the asset table name.
-
- For example, consider a dataset_table that consists of two element types, T1 and T2. T1 has foreign key relationships to
- objects in tables T3 and T4. There are also two controlled vocabularies, CV1 and CV2. T2 is an asset table
- which has two assets in it. The layout of the resulting bdbag would be:
+ """Output a download/export specification for a dataset_table. Each element of the dataset_table
+ will be placed in its own directory.
+ The top level data directory of the resulting BDBag will have one subdirectory for element type.
+ The subdirectory will contain the CSV indicating which elements of that type are present in the
+ dataset_table, and then there will be a subdirectory for each object that is reachable from the
+ dataset_table members.
+
+ To simplify reconstructing the relationship between tables, the CVS for each element is included.
+ The top level data directory will also contain a subdirectory for any controlled vocabularies used in
+ the dataset_table. All assets will be placed into a directory named asset in a subdirectory with the
+ asset table name.
+
+ For example, consider a dataset_table that consists of two element types, T1 and T2. T1 has foreign
+ key relationships to objects in tables T3 and T4. There are also two controlled vocabularies, CV1 and
+ CV2. T2 is an asset table which has two assets in it. The layout of the resulting bdbag would be:
  data
  CV1/
  cv1.csv
@@ -967,17 +966,15 @@ class Dataset:
  A dataset_table specification.
  """
  element_spec = self._export_vocabulary(writer)
- for path in self._table_paths(
- dataset=dataset, snapshot_catalog=snapshot_catalog
- ):
+ for path in self._table_paths(dataset=dataset, snapshot_catalog=snapshot_catalog):
  element_spec.extend(writer(*path))
  return element_spec

  def _download_dataset_bag(
  self,
  dataset: DatasetSpec,
- execution_rid: Optional[RID] = None,
- snapshot_catalog: Optional[DerivaML] = None,
+ execution_rid: RID | None = None,
+ snapshot_catalog: DerivaML | None = None,
  ) -> DatasetBag:
  """Download a dataset onto the local file system. Create a MINID for the dataset if one doesn't already exist.

@@ -1007,27 +1004,29 @@ class Dataset:

  def _version_snapshot(self, dataset: DatasetSpec) -> str:
  """Return a catalog with snapshot for the specified dataset version"""
- version_record = [
- h
- for h in self.dataset_history(dataset_rid=dataset.rid)
- if h.dataset_version == dataset.version
- ][0]
+ try:
+ version_record = next(
+ h for h in self.dataset_history(dataset_rid=dataset.rid) if h.dataset_version == dataset.version
+ )
+ except StopIteration:
+ raise DerivaMLException(f"Dataset version {dataset.version} not found for dataset {dataset.rid}")
  return f"{self._model.catalog.catalog_id}@{version_record.snapshot}"

- def _create_dataset_minid(
- self, dataset: DatasetSpec, snapshot_catalog: Optional[DerivaML] = None
- ) -> str:
+ def _create_dataset_minid(self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None) -> str:
  with TemporaryDirectory() as tmp_dir:
  # Generate a download specification file for the current catalog schema. By default, this spec
  # will generate a minid and place the bag into S3 storage.
- spec_file = f"{tmp_dir}/download_spec.json"
- with open(spec_file, "w", encoding="utf-8") as ds:
- json.dump(
- self._generate_dataset_download_spec(dataset, snapshot_catalog), ds
- )
+ spec_file = Path(tmp_dir) / "download_spec.json"
+ with spec_file.open("w", encoding="utf-8") as ds:
+ json.dump(self._generate_dataset_download_spec(dataset, snapshot_catalog), ds)
  try:
  self._logger.info(
- f"Downloading dataset {'minid' if self._use_minid else 'bag'} for catalog: {dataset.rid}@{str(dataset.version)}"
+ "Downloading dataset %s for catalog: %s@%s"
+ % (
+ 'minid' if self._use_minid else 'bag',
+ dataset.rid,
+ str(dataset.version),
+ )
  )
  # Generate the bag and put into S3 storage.
  exporter = DerivaExport(
@@ -1050,15 +1049,9 @@ class Dataset:
  raise DerivaMLException(format_exception(e))
  # Update version table with MINID.
  if self._use_minid:
- version_path = (
- self._model.catalog.getPathBuilder()
- .schemas[self._ml_schema]
- .tables["Dataset_Version"]
- )
+ version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
  version_rid = [
- h
- for h in self.dataset_history(dataset_rid=dataset.rid)
- if h.dataset_version == dataset.version
+ h for h in self.dataset_history(dataset_rid=dataset.rid) if h.dataset_version == dataset.version
  ][0].version_rid
  version_path.update([{"RID": version_rid, "Minid": minid_page_url}])
  return minid_page_url
@@ -1066,10 +1059,10 @@ class Dataset:
  def _get_dataset_minid(
  self,
  dataset: DatasetSpec,
- snapshot_catalog: Optional[DerivaML] = None,
+ snapshot_catalog: DerivaML | None = None,
  create: bool = True,
- ) -> DatasetMinid:
- """Return a MINID to the specified dataset. If no version is specified, use the latest.
+ ) -> DatasetMinid | None:
+ """Return a MINID for the specified dataset. If no version is specified, use the latest.

  Args:
  dataset: Specification of the dataset.
@@ -1079,50 +1072,53 @@ class Dataset:
1079
1072
  Returns:
1080
1073
  New or existing MINID for the dataset.
1081
1074
  """
1082
- if dataset.rid.startswith("minid"):
1083
- minid_url = f"https://identifiers.org/{dataset.rid}"
1084
- elif dataset.rid.startswith("http"):
1085
- minid_url = dataset.rid
1086
- else:
1087
- if not any([dataset.rid == ds["RID"] for ds in self.find_datasets()]):
1088
- raise DerivaMLException(f"RID {dataset.rid} is not a dataset_table")
1089
-
1090
- # Get the history record for the version we are looking for.
1091
- dataset_version_record = [
1092
- v
1093
- for v in self.dataset_history(dataset.rid)
1094
- if v.dataset_version == str(dataset.version)
1095
- ][0]
1096
- if not dataset_version_record:
1097
- raise DerivaMLException(
1098
- f"Version {str(dataset.version)} does not exist for RID {dataset.rid}"
1099
- )
1100
- minid_url = dataset_version_record.minid
1101
- if not minid_url:
1102
- if not create:
1103
- raise DerivaMLException(
1104
- f"Minid for dataset {dataset.rid} doesn't exist"
1105
- )
1106
- if self._use_minid:
1107
- self._logger.info("Creating new MINID for dataset %s", dataset.rid)
1108
- minid_url = self._create_dataset_minid(dataset, snapshot_catalog)
1109
- # If provided a MINID, use the MINID metadata to get the checksum and download the bag.
1075
+ rid = dataset.rid
1076
+
1077
+ # Case 1: RID is already a MINID or direct URL
1078
+ if rid.startswith("minid"):
1079
+ return self._fetch_minid_metadata(f"https://identifiers.org/{rid}", dataset.version)
1080
+ if rid.startswith("http"):
1081
+ return self._fetch_minid_metadata(rid, dataset.version)
1082
+
1083
+ # Case 2: RID is a dataset RID validate existence
1084
+ if not any(rid == ds["RID"] for ds in self.find_datasets()):
1085
+ raise DerivaMLTableTypeError("Dataset", rid)
1086
+
1087
+ # Find dataset version record
1088
+ version_str = str(dataset.version)
1089
+ history = self.dataset_history(rid)
1090
+ try:
1091
+ version_record = next(v for v in history if v.dataset_version == version_str)
1092
+ except StopIteration:
1093
+ raise DerivaMLException(f"Version {version_str} does not exist for RID {rid}")
1094
+
1095
+ # Check or create MINID
1096
+ minid_url = version_record.minid
1097
+ if not minid_url:
1098
+ if not create:
1099
+ raise DerivaMLException(f"Minid for dataset {rid} doesn't exist")
1110
1100
  if self._use_minid:
1111
- r = requests.get(minid_url, headers={"accept": "application/json"})
1112
- dataset_minid = DatasetMinid(
1113
- dataset_version=dataset.version, **r.json()
1114
- )
1115
- else:
1116
- dataset_minid = DatasetMinid(
1117
- dataset_version=dataset.version,
1118
- RID=f"{dataset.rid}@{dataset_version_record.snapshot}",
1119
- location=minid_url,
1120
- )
1121
- return dataset_minid
1101
+ self._logger.info("Creating new MINID for dataset %s", rid)
1102
+ minid_url = self._create_dataset_minid(dataset, snapshot_catalog)
1103
+
1104
+ # Return based on MINID usage
1105
+ if self._use_minid:
1106
+ return self._fetch_minid_metadata(minid_url, dataset.version)
1107
+
1108
+ return DatasetMinid(
1109
+ dataset_version=dataset.version,
1110
+ RID=f"{rid}@{version_record.snapshot}",
1111
+ location=minid_url,
1112
+ )
1113
+
1114
+ def _fetch_minid_metadata(self, url: str, version: DatasetVersion) -> DatasetMinid:
1115
+ r = requests.get(url, headers={"accept": "application/json"})
1116
+ r.raise_for_status()
1117
+ return DatasetMinid(dataset_version=version, **r.json())
1122
1118
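The new `_fetch_minid_metadata` helper centralizes the lookups that previously built a `DatasetMinid` inline: it resolves the landing page with an `application/json` Accept header and now fails fast via `raise_for_status()`. A standalone sketch of the same lookup, with a made-up identifier and no assumptions about which metadata fields come back:

```python
# Sketch of resolving a MINID landing page to its JSON metadata, as the helper
# above does. The identifier is a placeholder; field names in the response are
# whatever the resolver returns (DatasetMinid consumes them as **kwargs).
import requests

minid = "minid:exampleonly"  # hypothetical identifier
resp = requests.get(f"https://identifiers.org/{minid}", headers={"accept": "application/json"})
resp.raise_for_status()
metadata = resp.json()
print(sorted(metadata))  # inspect which fields the landing page exposes
```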
 
  def _download_dataset_minid(self, minid: DatasetMinid) -> Path:
- """Given a RID to a dataset_table, or a MINID to an existing bag, download the bag file, extract it, and validate
- that all the metadata is correct
+ """Given a RID to a dataset_table, or a MINID to an existing bag, download the bag file, extract it, and
+ validate that all the metadata is correct

  Args:
  minid: The RID of a dataset_table or a minid to an existing bag.
@@ -1134,9 +1130,7 @@ class Dataset:
  # it. If not, then we need to extract the contents of the archive into our cache directory.
  bag_dir = self._cache_dir / f"{minid.dataset_rid}_{minid.checksum}"
  if bag_dir.exists():
- self._logger.info(
- f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}"
- )
+ self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
  return Path(bag_dir / f"Dataset_{minid.dataset_rid}")

  # Either bag hasn't been downloaded yet, or we are not using a Minid, so we don't know the checksum yet.
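The cache check above keys each downloaded bag by dataset RID plus the bag archive's checksum, so the same dataset downloaded at different versions never collides in the cache. A small sketch of that layout with placeholder paths and values:

```python
# Sketch of the cache layout: <cache_dir>/<dataset_rid>_<checksum>/Dataset_<dataset_rid>.
# The cache root, RID, and checksum are placeholders, not values from this diff.
from pathlib import Path

cache_dir = Path("/tmp/deriva-ml-cache")
dataset_rid, checksum = "1-ABCD", "3a7f9e"  # hypothetical RID and sha256 prefix

bag_dir = cache_dir / f"{dataset_rid}_{checksum}"
if bag_dir.exists():
    print("cache hit:", bag_dir / f"Dataset_{dataset_rid}")
else:
    print("cache miss: download the archive and extract it into", bag_dir)
```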
@@ -1145,19 +1139,13 @@ class Dataset:
  # Get bag from S3
  archive_path = fetch_single_file(minid.bag_url)
  else:
- exporter = DerivaExport(
- host=self._model.catalog.deriva_server.server, output_dir=tmp_dir
- )
+ exporter = DerivaExport(host=self._model.catalog.deriva_server.server, output_dir=tmp_dir)
  archive_path = exporter.retrieve_file(minid.bag_url)
- hashes = hash_utils.compute_file_hashes(
- archive_path, hashes=["md5", "sha256"]
- )
+ hashes = hash_utils.compute_file_hashes(archive_path, hashes=["md5", "sha256"])
  checksum = hashes["sha256"][0]
  bag_dir = self._cache_dir / f"{minid.dataset_rid}_{checksum}"
  if bag_dir.exists():
- self._logger.info(
- f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}"
- )
+ self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
  return Path(bag_dir / f"Dataset_{minid.dataset_rid}")
  bag_path = bdb.extract_bag(archive_path, bag_dir.as_posix())
  bdb.validate_bag_structure(bag_path)
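`extract_bag` and `validate_bag_structure` (and `materialize` later in this diff) are BDBag operations; `bdb` is presumably `bdbag.bdbag_api`, although the import itself is outside this hunk. A sketch of the same sequence on a placeholder archive:

```python
# Sketch only: unpack a downloaded bag archive, check its structure, then fetch
# the remote files it references. Paths are placeholders; the import alias for
# `bdb` is assumed, since the import statement is not shown in this diff.
from bdbag import bdbag_api as bdb

archive_path = "/tmp/Dataset_1-ABCD.zip"
bag_path = bdb.extract_bag(archive_path, "/tmp/deriva-ml-cache/1-ABCD_3a7f9e")
bdb.validate_bag_structure(bag_path)  # structural check: manifests and payload layout
bdb.materialize(bag_path)             # resolve fetch.txt references and validate contents
```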
@@ -1166,7 +1154,7 @@ class Dataset:
  def _materialize_dataset_bag(
  self,
  minid: DatasetMinid,
- execution_rid: Optional[RID] = None,
+ execution_rid: RID | None = None,
  ) -> Path:
  """Materialize a dataset_table bag into a local directory

@@ -1180,9 +1168,7 @@ class Dataset:
  def update_status(status: Status, msg: str) -> None:
  """Update the current status for this execution in the catalog"""
  if execution_rid and execution_rid != DRY_RUN_RID:
- self._model.catalog.getPathBuilder().schemas[
- self._ml_schema
- ].Execution.update(
+ self._model.catalog.getPathBuilder().schemas[self._ml_schema].Execution.update(
  [
  {
  "RID": execution_rid,
@@ -1212,9 +1198,7 @@ class Dataset:

  # If this bag has already been validated, our work is done. Otherwise, materialize the bag.
  if not validated_check.exists():
- self._logger.info(
- f"Materializing bag {minid.dataset_rid} Version:{minid.dataset_version}"
- )
+ self._logger.info(f"Materializing bag {minid.dataset_rid} Version:{minid.dataset_version}")
  bdb.materialize(
  bag_path.as_posix(),
  fetch_callback=fetch_progress_callback,
@@ -1225,7 +1209,7 @@ class Dataset:

  def _export_annotation(
  self,
- snapshot_catalog: Optional[DerivaML] = None,
+ snapshot_catalog: DerivaML | None = None,
  ) -> list[dict[str, Any]]:
  """Return and output specification for the datasets in the provided model

@@ -1257,7 +1241,7 @@ class Dataset:
  )

  def _export_specification(
- self, dataset: DatasetSpec, snapshot_catalog: Optional[DerivaML] = None
+ self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None
  ) -> list[dict[str, Any]]:
  """
  Generate a specification for export engine for specific dataset.
@@ -1273,14 +1257,10 @@ class Dataset:
  "processor": "json",
  "processor_params": {"query_path": "/schema", "output_path": "schema"},
  }
- ] + self._dataset_specification(
- self._export_specification_dataset_element, dataset, snapshot_catalog
- )
+ ] + self._dataset_specification(self._export_specification_dataset_element, dataset, snapshot_catalog)

  @staticmethod
- def _export_specification_dataset_element(
- spath: str, dpath: str, table: Table
- ) -> list[dict[str, Any]]:
+ def _export_specification_dataset_element(spath: str, dpath: str, table: Table) -> list[dict[str, Any]]:
  """Return the download specification for the data object indicated by a path through the data model.

  Args:
@@ -1315,10 +1295,9 @@ class Dataset:
  )
  return exports
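The list assembled in `_export_specification` starts with the schema dump shown above and then appends one processor entry per element reached through `_dataset_specification`. A rough sketch of the resulting shape; only the first entry is taken from this diff, and the second is invented purely to illustrate what a per-table entry might look like:

```python
# Illustrative only: the first entry mirrors the schema dump in the diff; the
# second is a hypothetical per-table entry, not the actual output of
# _export_specification_dataset_element.
export_spec = [
    {
        "processor": "json",
        "processor_params": {"query_path": "/schema", "output_path": "schema"},
    },
    {
        "processor": "csv",
        "processor_params": {
            "query_path": "/entity/Dataset/RID=1-ABCD",  # placeholder query path
            "output_path": "Dataset",
        },
    },
]
```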
 
- def _export_annotation_dataset_element(
- self, spath: str, dpath: str, table: Table
- ) -> list[dict[str, Any]]:
- """Given a path in the data model, output an export specification for the path taken to get to the current table.
+ def _export_annotation_dataset_element(self, spath: str, dpath: str, table: Table) -> list[dict[str, Any]]:
+ """Given a path in the data model, output an export specification for the path taken to get to the
+ current table.

  Args:
  spath: Source path
@@ -1369,7 +1348,7 @@ class Dataset:
  return exports

  def _generate_dataset_download_spec(
- self, dataset: DatasetSpec, snapshot_catalog: Optional[DerivaML] = None
+ self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None
  ) -> dict[str, Any]:
  """
  Generate a specification for downloading a specific dataset.
@@ -1472,9 +1451,7 @@ class Dataset:
  else {}
  )
  return {
- deriva_tags.export_fragment_definitions: {
- "dataset_export_outputs": self._export_annotation()
- },
+ deriva_tags.export_fragment_definitions: {"dataset_export_outputs": self._export_annotation()},
  deriva_tags.visible_foreign_keys: self._dataset_visible_fkeys(),
  deriva_tags.export_2019: {
  "detailed": {