deriva-ml 1.14.0__py3-none-any.whl → 1.14.27__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (49)
  1. deriva_ml/__init__.py +25 -30
  2. deriva_ml/core/__init__.py +39 -0
  3. deriva_ml/core/base.py +1489 -0
  4. deriva_ml/core/constants.py +36 -0
  5. deriva_ml/core/definitions.py +74 -0
  6. deriva_ml/core/enums.py +222 -0
  7. deriva_ml/core/ermrest.py +288 -0
  8. deriva_ml/core/exceptions.py +28 -0
  9. deriva_ml/core/filespec.py +116 -0
  10. deriva_ml/dataset/__init__.py +4 -0
  11. deriva_ml/{dataset_aux_classes.py → dataset/aux_classes.py} +16 -12
  12. deriva_ml/{dataset.py → dataset/dataset.py} +406 -428
  13. deriva_ml/{dataset_bag.py → dataset/dataset_bag.py} +137 -97
  14. deriva_ml/{history.py → dataset/history.py} +51 -33
  15. deriva_ml/{upload.py → dataset/upload.py} +48 -70
  16. deriva_ml/demo_catalog.py +233 -183
  17. deriva_ml/execution/environment.py +290 -0
  18. deriva_ml/{execution.py → execution/execution.py} +365 -252
  19. deriva_ml/execution/execution_configuration.py +163 -0
  20. deriva_ml/{execution_configuration.py → execution/workflow.py} +212 -224
  21. deriva_ml/feature.py +83 -46
  22. deriva_ml/model/__init__.py +0 -0
  23. deriva_ml/{deriva_model.py → model/catalog.py} +113 -132
  24. deriva_ml/{database_model.py → model/database.py} +52 -74
  25. deriva_ml/model/sql_mapper.py +44 -0
  26. deriva_ml/run_notebook.py +19 -11
  27. deriva_ml/schema/__init__.py +3 -0
  28. deriva_ml/{schema_setup → schema}/annotations.py +31 -22
  29. deriva_ml/schema/check_schema.py +104 -0
  30. deriva_ml/{schema_setup → schema}/create_schema.py +151 -104
  31. deriva_ml/schema/deriva-ml-reference.json +8525 -0
  32. deriva_ml/schema/table_comments_utils.py +57 -0
  33. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.27.dist-info}/METADATA +5 -4
  34. deriva_ml-1.14.27.dist-info/RECORD +40 -0
  35. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.27.dist-info}/entry_points.txt +1 -0
  36. deriva_ml/deriva_definitions.py +0 -391
  37. deriva_ml/deriva_ml_base.py +0 -1046
  38. deriva_ml/execution_environment.py +0 -139
  39. deriva_ml/schema_setup/table_comments_utils.py +0 -56
  40. deriva_ml/test-files/execution-parameters.json +0 -1
  41. deriva_ml/test-files/notebook-parameters.json +0 -5
  42. deriva_ml/test_functions.py +0 -141
  43. deriva_ml/test_notebook.ipynb +0 -197
  44. deriva_ml-1.14.0.dist-info/RECORD +0 -31
  45. /deriva_ml/{schema_setup → execution}/__init__.py +0 -0
  46. /deriva_ml/{schema_setup → schema}/policy.json +0 -0
  47. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.27.dist-info}/WHEEL +0 -0
  48. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.27.dist-info}/licenses/LICENSE +0 -0
  49. {deriva_ml-1.14.0.dist-info → deriva_ml-1.14.27.dist-info}/top_level.txt +0 -0
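The renames above reorganize the package from a flat module layout into core, dataset, execution, model, and schema subpackages. The sketch below maps old import paths to new ones, based only on the imports visible later in this diff; whether deriva_ml/__init__.py still re-exports these names at the top level is not shown here, so the explicit subpackage paths are the safer choice.

# Import-path migration implied by the renames (Python, illustrative only).
# Old 1.14.0 layout (removed modules):
#   from deriva_ml.deriva_ml_base import DerivaML
#   from deriva_ml.deriva_definitions import ML_SCHEMA, MLVocab, Status, RID
#   from deriva_ml.dataset_aux_classes import DatasetSpec, DatasetVersion, VersionPart
#   from deriva_ml.deriva_model import DerivaModel
#   from deriva_ml.database_model import DatabaseModel
# New 1.14.27 layout (paths taken from the imports in dataset/dataset.py below):
from deriva_ml.core.base import DerivaML
from deriva_ml.core.constants import RID
from deriva_ml.core.definitions import ML_SCHEMA, MLVocab, Status
from deriva_ml.core.exceptions import DerivaMLException
from deriva_ml.dataset.aux_classes import DatasetSpec, DatasetVersion, VersionPart
from deriva_ml.dataset.dataset_bag import DatasetBag
from deriva_ml.model.catalog import DerivaModel
from deriva_ml.model.database import DatabaseModel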
@@ -1,74 +1,110 @@
1
- """
2
- This module defines the DataSet class with is used to manipulate datasets in DerivaML.
3
- The intended use of this class is as a base class in DerivaML, so all the methods documented here are
4
- accessible via a DerivaML class instance.
5
-
1
+ """Dataset management for DerivaML.
2
+
3
+ This module provides functionality for managing datasets in DerivaML. A dataset represents a collection
4
+ of related data that can be versioned, downloaded, and tracked. The module includes:
5
+
6
+ - Dataset class: Core class for dataset operations
7
+ - Version management: Track and update dataset versions
8
+ - History tracking: Record dataset changes over time
9
+ - Download capabilities: Export datasets as BDBags
10
+ - Relationship management: Handle dataset dependencies and hierarchies
11
+
12
+ The Dataset class serves as a base class in DerivaML, making its methods accessible through
13
+ DerivaML class instances.
14
+
15
+ Typical usage example:
16
+ >>> ml = DerivaML('deriva.example.org', 'my_catalog')
17
+ >>> dataset_rid = ml.create_dataset('experiment', 'Experimental data')
18
+ >>> ml.add_dataset_members(dataset_rid=dataset_rid, members=['1-abc123', '1-def456'])
19
+ >>> ml.increment_dataset_version(dataset_rid=dataset_rid, component=VersionPart.minor,
20
+ ... description='Added new samples')
6
21
  """
7
22
 
8
23
  from __future__ import annotations
9
- from bdbag import bdbag_api as bdb
10
- from bdbag.fetch.fetcher import fetch_single_file
11
- from collections import defaultdict
12
- from graphlib import TopologicalSorter
24
+
25
+ # Standard library imports
13
26
  import json
14
27
  import logging
28
+ from collections import defaultdict
29
+ from graphlib import TopologicalSorter
15
30
  from pathlib import Path
16
- from pydantic import (
17
- validate_call,
18
- ConfigDict,
19
- )
20
- import requests
21
31
  from tempfile import TemporaryDirectory
22
- from typing import Any, Callable, Optional, Iterable, Iterator, TYPE_CHECKING
32
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator
23
33
 
24
- from .history import iso_to_snap
25
- from deriva.core.ermrest_model import Table
26
- from deriva.core.utils.core_utils import tag as deriva_tags, format_exception
27
34
  import deriva.core.utils.hash_utils as hash_utils
28
- from deriva.transfer.download.deriva_export import DerivaExport
35
+ import requests
36
+
37
+ # Third-party imports
38
+ from bdbag import bdbag_api as bdb
39
+ from bdbag.fetch.fetcher import fetch_single_file
40
+
41
+ # Deriva imports
42
+ from deriva.core.ermrest_model import Table
43
+ from deriva.core.utils.core_utils import format_exception
44
+ from deriva.core.utils.core_utils import tag as deriva_tags
29
45
  from deriva.transfer.download.deriva_download import (
30
- DerivaDownloadConfigurationError,
31
- DerivaDownloadError,
32
46
  DerivaDownloadAuthenticationError,
33
47
  DerivaDownloadAuthorizationError,
48
+ DerivaDownloadConfigurationError,
49
+ DerivaDownloadError,
34
50
  DerivaDownloadTimeoutError,
35
51
  )
52
+ from deriva.transfer.download.deriva_export import DerivaExport
53
+ from pydantic import ConfigDict, validate_call
36
54
 
37
-
55
+ # Local imports
38
56
  try:
39
57
  from icecream import ic
58
+
59
+ ic.configureOutput(includeContext=True)
40
60
  except ImportError: # Graceful fallback if IceCream isn't installed.
41
61
  ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a) # noqa
42
62
 
43
- from deriva_ml import DatasetBag
44
- from .deriva_definitions import (
63
+ from deriva_ml.core.constants import RID
64
+ from deriva_ml.core.definitions import (
65
+ DRY_RUN_RID,
45
66
  ML_SCHEMA,
46
- DerivaMLException,
47
67
  MLVocab,
48
68
  Status,
49
- RID,
50
- DRY_RUN_RID,
51
69
  )
52
- from .deriva_model import DerivaModel
53
- from .database_model import DatabaseModel
54
- from .dataset_aux_classes import (
55
- DatasetVersion,
56
- DatasetMinid,
70
+ from deriva_ml.core.exceptions import DerivaMLException, DerivaMLTableTypeError
71
+ from deriva_ml.dataset.aux_classes import (
57
72
  DatasetHistory,
58
- VersionPart,
73
+ DatasetMinid,
59
74
  DatasetSpec,
75
+ DatasetVersion,
76
+ VersionPart,
60
77
  )
78
+ from deriva_ml.dataset.dataset_bag import DatasetBag
79
+ from deriva_ml.model.catalog import DerivaModel
80
+ from deriva_ml.model.database import DatabaseModel
81
+
82
+ from .history import iso_to_snap
83
+
84
+ # Stop pycharm from complaining about undefined reference in docstring....
85
+ ml: DerivaML
61
86
 
62
87
  if TYPE_CHECKING:
63
- from .deriva_ml_base import DerivaML
88
+ from deriva_ml.core.base import DerivaML
64
89
 
65
90
 
66
91
  class Dataset:
67
- """
68
- Class to manipulate a dataset.
92
+ """Manages dataset operations in a Deriva catalog.
93
+
94
+ The Dataset class provides functionality for creating, modifying, and tracking datasets
95
+ in a Deriva catalog. It handles versioning, relationships between datasets, and data export.
69
96
 
70
97
  Attributes:
71
- dataset_table (Table): ERMRest table holding dataset information.
98
+ dataset_table (Table): ERMrest table storing dataset information.
99
+ _model (DerivaModel): Catalog model instance.
100
+ _ml_schema (str): Schema name for ML-specific tables.
101
+ _cache_dir (Path): Directory for caching downloaded datasets.
102
+ _working_dir (Path): Directory for working data.
103
+ _use_minid (bool): Whether to use MINID service for dataset identification.
104
+
105
+ Note:
106
+ This class is typically used as a base class, with its methods accessed through
107
+ DerivaML class instances rather than directly.
72
108
  """
73
109
 
74
110
  _Logger = logging.getLogger("deriva_ml")
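The rewritten module docstring above documents a create → add members → increment version workflow. A minimal end-to-end sketch of that workflow follows; the host name, catalog id, and RIDs are placeholders, "experiment" is assumed to already exist in the Dataset_Type vocabulary, and the calls simply mirror the docstring examples rather than an independently verified API.

from deriva_ml.core.base import DerivaML
from deriva_ml.dataset.aux_classes import VersionPart

# Connect to a catalog (placeholder host and catalog id).
ml = DerivaML("deriva.example.org", "my_catalog")

# Create a dataset; "experiment" is assumed to be a defined Dataset_Type term.
dataset_rid = ml.create_dataset("experiment", "Experimental data")

# Add members by RID (placeholder RIDs), then record the change as a new minor version.
ml.add_dataset_members(dataset_rid=dataset_rid, members=["1-abc123", "1-def456"])
new_version = ml.increment_dataset_version(
    dataset_rid=dataset_rid,
    component=VersionPart.minor,
    description="Added new samples",
)
print(new_version)  # prints the new semantic version (major.minor.patch)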
@@ -80,20 +116,31 @@ class Dataset:
80
116
  working_dir: Path,
81
117
  use_minid: bool = True,
82
118
  ):
119
+ """Initializes a Dataset instance.
120
+
121
+ Args:
122
+ model: DerivaModel instance representing the catalog.
123
+ cache_dir: Directory path for caching downloaded datasets.
124
+ working_dir: Directory path for working data.
125
+ use_minid: Whether to use MINID service for dataset identification.
126
+ """
83
127
  self._model = model
84
128
  self._ml_schema = ML_SCHEMA
85
- self.dataset_table = self._model.schemas[self._ml_schema].tables["Dataset"]
86
129
  self._cache_dir = cache_dir
87
130
  self._working_dir = working_dir
88
131
  self._logger = logging.getLogger("deriva_ml")
89
132
  self._use_minid = use_minid
90
133
 
134
+ @property
135
+ def _dataset_table(self):
136
+ return self._model.schemas[self._ml_schema].tables["Dataset"]
137
+
91
138
  def _is_dataset_rid(self, dataset_rid: RID, deleted: bool = False) -> bool:
92
139
  try:
93
140
  rid_info = self._model.catalog.resolve_rid(dataset_rid, self._model.model)
94
141
  except KeyError as _e:
95
142
  raise DerivaMLException(f"Invalid RID {dataset_rid}")
96
- if rid_info.table != self.dataset_table:
143
+ if rid_info.table != self._dataset_table:
97
144
  return False
98
145
  elif deleted:
99
146
  # Got a dataset rid. Now check to see if its deleted or not.
@@ -104,12 +151,12 @@ class Dataset:
104
151
  def _insert_dataset_versions(
105
152
  self,
106
153
  dataset_list: list[DatasetSpec],
107
- description: Optional[str] = "",
108
- execution_rid: Optional[RID] = None,
154
+ description: str | None = "",
155
+ execution_rid: RID | None = None,
109
156
  ) -> None:
110
157
  schema_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema]
111
158
  # determine snapshot after changes were made
112
- snap = self._model.catalog.get("/").json()["snaptime"]
159
+
113
160
  # Construct version records for insert
114
161
  version_records = schema_path.tables["Dataset_Version"].insert(
115
162
  [
@@ -118,16 +165,18 @@ class Dataset:
118
165
  "Version": str(dataset.version),
119
166
  "Description": description,
120
167
  "Execution": execution_rid,
121
- "Snapshot": snap,
122
168
  }
123
169
  for dataset in dataset_list
124
170
  ]
125
171
  )
172
+ version_records = list(version_records)
173
+ snap = self._model.catalog.get("/").json()["snaptime"]
174
+ schema_path.tables["Dataset_Version"].update(
175
+ [{"RID": v["RID"], "Dataset": v["Dataset"], "Snapshot": snap} for v in version_records]
176
+ )
126
177
 
127
178
  # And update the dataset records.
128
- schema_path.tables["Dataset"].update(
129
- [{"Version": v["RID"], "RID": v["Dataset"]} for v in version_records]
130
- )
179
+ schema_path.tables["Dataset"].update([{"Version": v["RID"], "RID": v["Dataset"]} for v in version_records])
131
180
 
132
181
  def _bootstrap_versions(self):
133
182
  datasets = [ds["RID"] for ds in self.find_datasets()]
@@ -143,9 +192,7 @@ class Dataset:
143
192
  version_path = schema_path.tables["Dataset_Version"]
144
193
  dataset_path = schema_path.tables["Dataset"]
145
194
  history = list(version_path.insert(ds_version))
146
- dataset_versions = [
147
- {"RID": h["Dataset"], "Version": h["Version"]} for h in history
148
- ]
195
+ dataset_versions = [{"RID": h["Dataset"], "Version": h["Version"]} for h in history]
149
196
  dataset_path.update(dataset_versions)
150
197
 
151
198
  def _synchronize_dataset_versions(self):
@@ -161,45 +208,47 @@ class Dataset:
161
208
  versions[v["Dataset"]] = v
162
209
  dataset_path = schema_path.tables["Dataset"]
163
210
 
164
- dataset_path.update(
165
- [
166
- {"RID": dataset, "Version": version["RID"]}
167
- for dataset, version in versions.items()
168
- ]
169
- )
211
+ dataset_path.update([{"RID": dataset, "Version": version["RID"]} for dataset, version in versions.items()])
170
212
 
171
213
  def _set_version_snapshot(self):
172
- dataset_version_path = (
173
- self._model.catalog.getPathBuilder()
174
- .schemas[self._ml_schema]
175
- .tables["Dataset_Version"]
176
- )
214
+ """Update the Snapshot column of the Dataset_Version table to the correct time."""
215
+ dataset_version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
177
216
  versions = dataset_version_path.entities().fetch()
178
217
  dataset_version_path.update(
179
- [
180
- {"RID": h["RID"], "Snapshot": iso_to_snap(h["RCT"])}
181
- for h in versions
182
- if not h["Snapshot"]
183
- ]
218
+ [{"RID": h["RID"], "Snapshot": iso_to_snap(h["RCT"])} for h in versions if not h["Snapshot"]]
184
219
  )
185
220
 
186
221
  def dataset_history(self, dataset_rid: RID) -> list[DatasetHistory]:
187
- """Return a list of DatasetHistory objects representing the dataset
222
+ """Retrieves the version history of a dataset.
223
+
224
+ Returns a chronological list of dataset versions, including their version numbers,
225
+ creation times, and associated metadata.
188
226
 
189
227
  Args:
190
- dataset_rid: A RID to the dataset for which history is to be fetched.
228
+ dataset_rid: Resource Identifier of the dataset.
191
229
 
192
230
  Returns:
193
- A list of DatasetHistory objects which indicate the version-number, creation time, and bag instantiation of the dataset.
231
+ list[DatasetHistory]: List of history entries, each containing:
232
+ - dataset_version: Version number (major.minor.patch)
233
+ - minid: Minimal Viable Identifier
234
+ - snapshot: Catalog snapshot time
235
+ - dataset_rid: Dataset Resource Identifier
236
+ - version_rid: Version Resource Identifier
237
+ - description: Version description
238
+ - execution_rid: Associated execution RID
239
+
240
+ Raises:
241
+ DerivaMLException: If dataset_rid is not a valid dataset RID.
242
+
243
+ Example:
244
+ >>> history = ml.dataset_history("1-abc123")
245
+ >>> for entry in history:
246
+ ... print(f"Version {entry.dataset_version}: {entry.description}")
194
247
  """
195
248
 
196
249
  if not self._is_dataset_rid(dataset_rid):
197
250
  raise DerivaMLException(f"RID is not for a data set: {dataset_rid}")
198
- version_path = (
199
- self._model.catalog.getPathBuilder()
200
- .schemas[self._ml_schema]
201
- .tables["Dataset_Version"]
202
- )
251
+ version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
203
252
  return [
204
253
  DatasetHistory(
205
254
  dataset_version=DatasetVersion.parse(v["Version"]),
@@ -210,9 +259,7 @@ class Dataset:
210
259
  description=v["Description"],
211
260
  execution_rid=v["Execution"],
212
261
  )
213
- for v in version_path.filter(version_path.Dataset == dataset_rid)
214
- .entities()
215
- .fetch()
262
+ for v in version_path.filter(version_path.Dataset == dataset_rid).entities().fetch()
216
263
  ]
217
264
 
218
265
  @validate_call
@@ -234,14 +281,16 @@ class Dataset:
234
281
  if not history:
235
282
  return DatasetVersion(0, 1, 0)
236
283
  else:
237
- return max([h.dataset_version for h in self.dataset_history(dataset_rid)])
284
+ # Ensure we return a DatasetVersion, not a string
285
+ versions = [h.dataset_version for h in history]
286
+ return max(versions) if versions else DatasetVersion(0, 1, 0)
238
287
 
239
288
  def _build_dataset_graph(self, dataset_rid: RID) -> Iterable[RID]:
240
- ts = TopologicalSorter()
289
+ ts: TopologicalSorter = TopologicalSorter()
241
290
  self._build_dataset_graph_1(dataset_rid, ts, set())
242
291
  return ts.static_order()
243
292
 
244
- def _build_dataset_graph_1(self, dataset_rid: RID, ts, visited) -> None:
293
+ def _build_dataset_graph_1(self, dataset_rid: RID, ts: TopologicalSorter, visited) -> None:
245
294
  """Use topological sort to return bottom up list of nested datasets"""
246
295
  ts.add(dataset_rid)
247
296
  if dataset_rid not in visited:
@@ -249,7 +298,8 @@ class Dataset:
249
298
  children = self.list_dataset_children(dataset_rid=dataset_rid)
250
299
  parents = self.list_dataset_parents(dataset_rid=dataset_rid)
251
300
  for parent in parents:
252
- self._build_dataset_graph_1(parent, ts, visited)
301
+ # Convert string to RID type
302
+ self._build_dataset_graph_1(RID(parent), ts, visited)
253
303
  for child in children:
254
304
  self._build_dataset_graph_1(child, ts, visited)
255
305
 
@@ -258,22 +308,34 @@ class Dataset:
258
308
  self,
259
309
  dataset_rid: RID,
260
310
  component: VersionPart,
261
- description: Optional[str] = "",
262
- execution_rid: Optional[RID] = None,
311
+ description: str | None = "",
312
+ execution_rid: RID | None = None,
263
313
  ) -> DatasetVersion:
264
- """Increment the version of the specified dataset_table.
314
+ """Increments a dataset's version number.
315
+
316
+ Creates a new version of the dataset by incrementing the specified version component
317
+ (major, minor, or patch). The new version is recorded with an optional description
318
+ and execution reference.
265
319
 
266
320
  Args:
267
- dataset_rid: RID of the dataset whose version is to be incremented.
268
- component: Which version of the dataset_table to increment. Major, Minor, or Patch
269
- description: Description of the version update of the dataset_table.
270
- execution_rid: Which execution is performing increment.
321
+ dataset_rid: Resource Identifier of the dataset to version.
322
+ component: Which version component to increment ('major', 'minor', or 'patch').
323
+ description: Optional description of the changes in this version.
324
+ execution_rid: Optional execution RID to associate with this version.
271
325
 
272
326
  Returns:
273
- new semantic version of the dataset_table as a 3-tuple
327
+ DatasetVersion: The new version number.
274
328
 
275
329
  Raises:
276
- DerivaMLException: if provided, RID is not to a dataset_table.
330
+ DerivaMLException: If dataset_rid is invalid or version increment fails.
331
+
332
+ Example:
333
+ >>> new_version = ml.increment_dataset_version(
334
+ ... dataset_rid="1-abc123",
335
+ ... component="minor",
336
+ ... description="Added new samples"
337
+ ... )
338
+ >>> print(f"New version: {new_version}") # e.g., "1.2.0"
277
339
  """
278
340
 
279
341
  # Find all the datasets that are reachable from this dataset and determine their new version numbers.
@@ -285,46 +347,51 @@ class Dataset:
285
347
  )
286
348
  for ds_rid in related_datasets
287
349
  ]
288
- self._insert_dataset_versions(
289
- version_update_list, description=description, execution_rid=execution_rid
290
- )
291
- return [d.version for d in version_update_list if d.rid == dataset_rid][0]
350
+ self._insert_dataset_versions(version_update_list, description=description, execution_rid=execution_rid)
351
+ return next((d.version for d in version_update_list if d.rid == dataset_rid))
292
352
 
293
353
  @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
294
354
  def create_dataset(
295
355
  self,
296
- dataset_types: str | list[str],
297
- description: str,
298
- execution_rid: Optional[RID] = None,
299
- version: Optional[DatasetVersion] = None,
356
+ dataset_types: str | list[str] | None = None,
357
+ description: str = "",
358
+ execution_rid: RID | None = None,
359
+ version: DatasetVersion | None = None,
300
360
  ) -> RID:
301
- """Create a new dataset_table from the specified list of RIDs.
361
+ """Creates a new dataset in the catalog.
362
+
363
+ Creates a dataset with specified types and description. The dataset can be associated
364
+ with an execution and initialized with a specific version.
302
365
 
303
366
  Args:
304
- dataset_types: One or more dataset_table types. Must be a term from the DatasetType controlled vocabulary.
305
- description: Description of the dataset_table.
306
- execution_rid: Execution under which the dataset_table will be created.
307
- version: Version of the dataset_table.
367
+ dataset_types: One or more dataset type terms from Dataset_Type vocabulary.
368
+ description: Description of the dataset's purpose and contents.
369
+ execution_rid: Optional execution RID to associate with dataset creation.
370
+ version: Optional initial version number. Defaults to 0.1.0.
308
371
 
309
372
  Returns:
310
- New dataset_table RID.
373
+ RID: Resource Identifier of the newly created dataset.
311
374
 
375
+ Raises:
376
+ DerivaMLException: If dataset_types are invalid or creation fails.
377
+
378
+ Example:
379
+ >>> rid = ml.create_dataset(
380
+ ... dataset_types=["experiment", "raw_data"],
381
+ ... description="RNA sequencing experiment data",
382
+ ... version=DatasetVersion(1, 0, 0)
383
+ ... )
312
384
  """
313
385
 
314
386
  version = version or DatasetVersion(0, 1, 0)
387
+ dataset_types = dataset_types or []
315
388
 
316
- type_path = (
317
- self._model.catalog.getPathBuilder()
318
- .schemas[self._ml_schema]
319
- .tables[MLVocab.dataset_type.value]
320
- )
389
+ type_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables[MLVocab.dataset_type.value]
321
390
  defined_types = list(type_path.entities().fetch())
322
391
 
323
392
  def check_dataset_type(dtype: str) -> bool:
324
393
  for term in defined_types:
325
- if dtype == term["Name"] or (
326
- term["Synonyms"] and ds_type in term["Synonyms"]
327
- ):
394
+ if dtype == term["Name"] or (term["Synonyms"] and ds_type in term["Synonyms"]):
328
395
  return True
329
396
  return False
330
397
 
@@ -334,9 +401,7 @@ class Dataset:
334
401
  for ds_type in ds_types:
335
402
  if not check_dataset_type(ds_type):
336
403
  raise DerivaMLException("Dataset type must be a vocabulary term.")
337
- dataset_table_path = pb.schemas[self.dataset_table.schema.name].tables[
338
- self.dataset_table.name
339
- ]
404
+ dataset_table_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
340
405
  dataset_rid = dataset_table_path.insert(
341
406
  [
342
407
  {
@@ -347,21 +412,13 @@ class Dataset:
347
412
  )[0]["RID"]
348
413
 
349
414
  # Get the name of the association table between dataset_table and dataset_type.
350
- atable = next(
351
- self._model.schemas[self._ml_schema]
352
- .tables[MLVocab.dataset_type]
353
- .find_associations()
354
- ).name
415
+ associations = list(self._model.schemas[self._ml_schema].tables[MLVocab.dataset_type].find_associations())
416
+ atable = associations[0].name if associations else None
355
417
  pb.schemas[self._ml_schema].tables[atable].insert(
356
- [
357
- {MLVocab.dataset_type: ds_type, "Dataset": dataset_rid}
358
- for ds_type in ds_types
359
- ]
418
+ [{MLVocab.dataset_type: ds_type, "Dataset": dataset_rid} for ds_type in ds_types]
360
419
  )
361
420
  if execution_rid is not None:
362
- pb.schemas[self._ml_schema].Dataset_Execution.insert(
363
- [{"Dataset": dataset_rid, "Execution": execution_rid}]
364
- )
421
+ pb.schemas[self._ml_schema].Dataset_Execution.insert([{"Dataset": dataset_rid, "Execution": execution_rid}])
365
422
  self._insert_dataset_versions(
366
423
  [DatasetSpec(rid=dataset_rid, version=version)],
367
424
  execution_rid=execution_rid,
@@ -383,18 +440,12 @@ class Dataset:
383
440
  raise DerivaMLException("Dataset_rid is not a dataset.")
384
441
 
385
442
  if parents := self.list_dataset_parents(dataset_rid):
386
- raise DerivaMLException(
387
- f'Dataset_rid "{dataset_rid}" is in a nested dataset: {parents}.'
388
- )
443
+ raise DerivaMLException(f'Dataset_rid "{dataset_rid}" is in a nested dataset: {parents}.')
389
444
 
390
445
  pb = self._model.catalog.getPathBuilder()
391
- dataset_path = pb.schemas[self.dataset_table.schema.name].tables[
392
- self.dataset_table.name
393
- ]
446
+ dataset_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
394
447
 
395
- rid_list = [dataset_rid] + (
396
- self.list_dataset_children(dataset_rid) if recurse else []
397
- )
448
+ rid_list = [dataset_rid] + (self.list_dataset_children(dataset_rid=dataset_rid) if recurse else [])
398
449
  dataset_path.update([{"RID": r, "Deleted": True} for r in rid_list])
399
450
 
400
451
  def find_datasets(self, deleted: bool = False) -> Iterable[dict[str, Any]]:
@@ -408,14 +459,9 @@ class Dataset:
408
459
  """
409
460
  # Get datapath to all the tables we will need: Dataset, DatasetType and the association table.
410
461
  pb = self._model.catalog.getPathBuilder()
411
- dataset_path = pb.schemas[self.dataset_table.schema.name].tables[
412
- self.dataset_table.name
413
- ]
414
- atable = next(
415
- self._model.schemas[self._ml_schema]
416
- .tables[MLVocab.dataset_type]
417
- .find_associations()
418
- ).name
462
+ dataset_path = pb.schemas[self._dataset_table.schema.name].tables[self._dataset_table.name]
463
+ associations = list(self._model.schemas[self._ml_schema].tables[MLVocab.dataset_type].find_associations())
464
+ atable = associations[0].name if associations else None
419
465
  ml_path = pb.schemas[self._ml_schema]
420
466
  atable_path = ml_path.tables[atable]
421
467
 
@@ -423,21 +469,16 @@ class Dataset:
423
469
  filtered_path = dataset_path
424
470
  else:
425
471
  filtered_path = dataset_path.filter(
426
- (dataset_path.Deleted == False) | (dataset_path.Deleted == None) # noqa: E712
472
+ (dataset_path.Deleted == False) | (dataset_path.Deleted == None) # noqa: E711, E712
427
473
  )
428
474
 
429
475
  # Get a list of all the dataset_type values associated with this dataset_table.
430
476
  datasets = []
431
477
  for dataset in filtered_path.entities().fetch():
432
478
  ds_types = (
433
- atable_path.filter(atable_path.Dataset == dataset["RID"])
434
- .attributes(atable_path.Dataset_Type)
435
- .fetch()
436
- )
437
- datasets.append(
438
- dataset
439
- | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in ds_types]}
479
+ atable_path.filter(atable_path.Dataset == dataset["RID"]).attributes(atable_path.Dataset_Type).fetch()
440
480
  )
481
+ datasets.append(dataset | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in ds_types]})
441
482
  return datasets
442
483
 
443
484
  def list_dataset_element_types(self) -> Iterable[Table]:
@@ -448,16 +489,9 @@ class Dataset:
448
489
  """
449
490
 
450
491
  def domain_table(table: Table) -> bool:
451
- return (
452
- table.schema.name == self._model.domain_schema
453
- or table.name == self.dataset_table.name
454
- )
492
+ return table.schema.name == self._model.domain_schema or table.name == self._dataset_table.name
455
493
 
456
- return [
457
- t
458
- for a in self.dataset_table.find_associations()
459
- if domain_table(t := a.other_fkeys.pop().pk_table)
460
- ]
494
+ return [t for a in self._dataset_table.find_associations() if domain_table(t := a.other_fkeys.pop().pk_table)]
461
495
 
462
496
  @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
463
497
  def add_dataset_element_type(self, element: str | Table) -> Table:
@@ -472,31 +506,45 @@ class Dataset:
472
506
  """
473
507
  # Add table to map
474
508
  element_table = self._model.name_to_table(element)
475
- table = self._model.schemas[self._model.domain_schema].create_table(
476
- Table.define_association([self.dataset_table, element_table])
477
- )
509
+ atable_def = Table.define_association([self._dataset_table, element_table])
510
+ try:
511
+ table = self._model.schemas[self._model.domain_schema].create_table(atable_def)
512
+ except ValueError as e:
513
+ if "already exists" in str(e):
514
+ table = self._model.name_to_table(atable_def["table_name"])
515
+ else:
516
+ raise e
478
517
 
479
518
  # self.model = self.catalog.getCatalogModel()
480
- self.dataset_table.annotations.update(
481
- self._generate_dataset_download_annotations()
482
- )
519
+ self._dataset_table.annotations.update(self._generate_dataset_download_annotations())
483
520
  self._model.model.apply()
484
521
  return table
485
522
 
486
523
  # @validate_call
487
524
  def list_dataset_members(
488
- self, dataset_rid: RID, recurse: bool = False, limit: Optional[int] = None
525
+ self, dataset_rid: RID, recurse: bool = False, limit: int | None = None
489
526
  ) -> dict[str, list[dict[str, Any]]]:
490
- """Return a list of entities associated with a specific dataset_table.
527
+ """Lists members of a dataset.
528
+
529
+ Returns a dictionary mapping member types to lists of member records. Can optionally
530
+ recurse through nested datasets and limit the number of results.
491
531
 
492
532
  Args:
493
- dataset_rid: param recurse: If this is a nested dataset_table, list the members of the contained datasets
494
- recurse: (Default value = False)
495
- limit: If provided, the maximum number of members to return for each element type.
533
+ dataset_rid: Resource Identifier of the dataset.
534
+ recurse: Whether to include members of nested datasets. Defaults to False.
535
+ limit: Maximum number of members to return per type. None for no limit.
496
536
 
497
537
  Returns:
498
- Dictionary of entities associated with a specific dataset_table. Key is the table from which the elements
499
- were taken.
538
+ dict[str, list[dict[str, Any]]]: Dictionary mapping member types to lists of members.
539
+ Each member is a dictionary containing the record's attributes.
540
+
541
+ Raises:
542
+ DerivaMLException: If dataset_rid is invalid.
543
+
544
+ Example:
545
+ >>> members = ml.list_dataset_members("1-abc123", recurse=True)
546
+ >>> for type_name, records in members.items():
547
+ ... print(f"{type_name}: {len(records)} records")
500
548
  """
501
549
 
502
550
  if not self._is_dataset_rid(dataset_rid):
@@ -506,21 +554,18 @@ class Dataset:
506
554
  # the appropriate association table.
507
555
  members = defaultdict(list)
508
556
  pb = self._model.catalog.getPathBuilder()
509
- for assoc_table in self.dataset_table.find_associations():
557
+ for assoc_table in self._dataset_table.find_associations():
510
558
  other_fkey = assoc_table.other_fkeys.pop()
511
559
  target_table = other_fkey.pk_table
512
560
  member_table = assoc_table.table
513
561
 
514
562
  # Look at domain tables and nested datasets.
515
- if (
516
- target_table.schema.name != self._model.domain_schema
517
- and target_table != self.dataset_table
563
+ if target_table.schema.name != self._model.domain_schema and not (
564
+ target_table == self._dataset_table or target_table.name == "File"
518
565
  ):
519
566
  continue
520
567
  member_column = (
521
- "Nested_Dataset"
522
- if target_table == self.dataset_table
523
- else other_fkey.foreign_key_columns[0].name
568
+ "Nested_Dataset" if target_table == self._dataset_table else other_fkey.foreign_key_columns[0].name
524
569
  )
525
570
 
526
571
  target_path = pb.schemas[target_table.schema.name].tables[target_table.name]
@@ -530,15 +575,13 @@ class Dataset:
530
575
  target_path,
531
576
  on=(member_path.columns[member_column] == target_path.columns["RID"]),
532
577
  )
533
- target_entities = list(
534
- path.entities().fetch(limit=limit) if limit else path.entities().fetch()
535
- )
578
+ target_entities = list(path.entities().fetch(limit=limit) if limit else path.entities().fetch())
536
579
  members[target_table.name].extend(target_entities)
537
- if recurse and target_table == self.dataset_table:
580
+ if recurse and target_table == self._dataset_table:
538
581
  # Get the members for all the nested datasets and add to the member list.
539
582
  nested_datasets = [d["RID"] for d in target_entities]
540
583
  for ds in nested_datasets:
541
- for k, v in self.list_dataset_members(ds, recurse=False).items():
584
+ for k, v in self.list_dataset_members(ds, recurse=recurse).items():
542
585
  members[k].extend(v)
543
586
  return dict(members)
544
587
 
@@ -546,24 +589,38 @@ class Dataset:
546
589
  def add_dataset_members(
547
590
  self,
548
591
  dataset_rid: RID,
549
- members: list[RID],
592
+ members: list[RID] | dict[str, list[RID]],
550
593
  validate: bool = True,
551
- description: Optional[str] = "",
552
- execution_rid: Optional[RID] = None,
594
+ description: str | None = "",
595
+ execution_rid: RID | None = None,
553
596
  ) -> None:
554
- """Add additional elements to an existing dataset_table.
597
+ """Adds members to a dataset.
555
598
 
556
- Add new elements to an existing dataset. In addition to adding new members, the minor version number of the
557
- dataset is incremented and the description, if provide is applied to that new version.
599
+ Associates one or more records with a dataset. Can optionally validate member types
600
+ and create a new dataset version to track the changes.
558
601
 
559
602
  Args:
560
- dataset_rid: RID of dataset_table to extend or None if a new dataset_table is to be created.
561
- members: List of member RIDs to add to the dataset_table.
562
- validate: Check rid_list to make sure elements are not already in the dataset_table.
563
- description: Markdown description of the updated dataset.
564
- execution_rid: Optional RID of execution associated with this dataset.
603
+ dataset_rid: Resource Identifier of the dataset.
604
+ members: List of RIDs to add as dataset members. Can be organized as a dictionary that maps each
605
+ table name to the member RIDs that belong to it.
606
+ validate: Whether to validate member types. Defaults to True.
607
+ description: Optional description of the member additions.
608
+ execution_rid: Optional execution RID to associate with changes.
609
+
610
+ Raises:
611
+ DerivaMLException: If:
612
+ - dataset_rid is invalid
613
+ - members are invalid or of wrong type
614
+ - adding members would create a cycle
615
+ - validation fails
616
+
617
+ Example:
618
+ >>> ml.add_dataset_members(
619
+ ... dataset_rid="1-abc123",
620
+ ... members=["1-def456", "1-ghi789"],
621
+ ... description="Added sample data"
622
+ ... )
565
623
  """
566
- members = set(members)
567
624
  description = description or "Updated dataset via add_dataset_members"
568
625
 
569
626
  def check_dataset_cycle(member_rid, path=None):
@@ -580,43 +637,37 @@ class Dataset:
580
637
  return member_rid in path
581
638
 
582
639
  if validate:
583
- existing_rids = set(
584
- m["RID"]
585
- for ms in self.list_dataset_members(dataset_rid).values()
586
- for m in ms
587
- )
640
+ existing_rids = set(m["RID"] for ms in self.list_dataset_members(dataset_rid).values() for m in ms)
588
641
  if overlap := set(existing_rids).intersection(members):
589
- raise DerivaMLException(
590
- f"Attempting to add existing member to dataset_table {dataset_rid}: {overlap}"
591
- )
642
+ raise DerivaMLException(f"Attempting to add existing member to dataset_table {dataset_rid}: {overlap}")
592
643
 
593
644
  # Now go through every rid to be added to the data set and sort them based on what association table entries
594
645
  # need to be made.
595
646
  dataset_elements = {}
596
647
  association_map = {
597
- a.other_fkeys.pop().pk_table.name: a.table.name
598
- for a in self.dataset_table.find_associations()
648
+ a.other_fkeys.pop().pk_table.name: a.table.name for a in self._dataset_table.find_associations()
599
649
  }
650
+
600
651
  # Get a list of all the object types that can be linked to a dataset_table.
601
- for m in members:
602
- try:
603
- rid_info = self._model.catalog.resolve_rid(m)
604
- except KeyError:
605
- raise DerivaMLException(f"Invalid RID: {m}")
606
- if rid_info.table.name not in association_map:
607
- raise DerivaMLException(
608
- f"RID table: {rid_info.table.name} not part of dataset_table"
609
- )
610
- if rid_info.table == self.dataset_table and check_dataset_cycle(
611
- rid_info.rid
612
- ):
613
- raise DerivaMLException("Creating cycle of datasets is not allowed")
614
- dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
652
+ if type(members) is list:
653
+ members = set(members)
654
+ for m in members:
655
+ try:
656
+ rid_info = self._model.catalog.resolve_rid(m)
657
+ except KeyError:
658
+ raise DerivaMLException(f"Invalid RID: {m}")
659
+ if rid_info.table.name not in association_map:
660
+ raise DerivaMLException(f"RID table: {rid_info.table.name} not part of dataset_table")
661
+ if rid_info.table == self._dataset_table and check_dataset_cycle(rid_info.rid):
662
+ raise DerivaMLException("Creating cycle of datasets is not allowed")
663
+ dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
664
+ else:
665
+ dataset_elements = {t: set(ms) for t, ms in members.items()}
615
666
  # Now make the entries into the association tables.
616
667
  pb = self._model.catalog.getPathBuilder()
617
668
  for table, elements in dataset_elements.items():
618
669
  schema_path = pb.schemas[
619
- self._ml_schema if table == "Dataset" else self._model.domain_schema
670
+ self._ml_schema if (table == "Dataset" or table == "File") else self._model.domain_schema
620
671
  ]
621
672
  fk_column = "Nested_Dataset" if table == "Dataset" else table
622
673
  if len(elements):
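The changes above broaden add_dataset_members so that members can be either a flat list of RIDs or a dictionary keyed by table name; the dictionary form skips per-RID resolution against the catalog. A small sketch of both call styles, reusing the ml instance from the earlier sketch; the RIDs are placeholders and "Image" stands in for an arbitrary domain-schema table.

# Flat list: each RID is resolved against the catalog to find its table.
ml.add_dataset_members(dataset_rid="1-abc123", members=["1-def456", "1-ghi789"])

# Pre-sorted dict: keys name the member tables, values are the RIDs to add.
ml.add_dataset_members(
    dataset_rid="1-abc123",
    members={"Image": ["1-def456", "1-ghi789"], "Dataset": ["1-xyz111"]},
)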
@@ -637,7 +688,7 @@ class Dataset:
637
688
  dataset_rid: RID,
638
689
  members: list[RID],
639
690
  description: str = "",
640
- execution_rid: Optional[RID] = None,
691
+ execution_rid: RID | None = None,
641
692
  ) -> None:
642
693
  """Remove elements to an existing dataset_table.
643
694
 
@@ -658,8 +709,7 @@ class Dataset:
658
709
  # need to be made.
659
710
  dataset_elements = {}
660
711
  association_map = {
661
- a.other_fkeys.pop().pk_table.name: a.table.name
662
- for a in self.dataset_table.find_associations()
712
+ a.other_fkeys.pop().pk_table.name: a.table.name for a in self._dataset_table.find_associations()
663
713
  }
664
714
  # Get a list of all the object types that can be linked to a dataset_table.
665
715
  for m in members:
@@ -668,16 +718,12 @@ class Dataset:
668
718
  except KeyError:
669
719
  raise DerivaMLException(f"Invalid RID: {m}")
670
720
  if rid_info.table.name not in association_map:
671
- raise DerivaMLException(
672
- f"RID table: {rid_info.table.name} not part of dataset_table"
673
- )
721
+ raise DerivaMLException(f"RID table: {rid_info.table.name} not part of dataset_table")
674
722
  dataset_elements.setdefault(rid_info.table.name, []).append(rid_info.rid)
675
723
  # Now make the entries into the association tables.
676
724
  pb = self._model.catalog.getPathBuilder()
677
725
  for table, elements in dataset_elements.items():
678
- schema_path = pb.schemas[
679
- self._ml_schema if table == "Dataset" else self._model.domain_schema
680
- ]
726
+ schema_path = pb.schemas[self._ml_schema if table == "Dataset" else self._model.domain_schema]
681
727
  fk_column = "Nested_Dataset" if table == "Dataset" else table
682
728
 
683
729
  if len(elements):
@@ -685,8 +731,7 @@ class Dataset:
685
731
  # Find out the name of the column in the association table.
686
732
  for e in elements:
687
733
  entity = atable_path.filter(
688
- (atable_path.Dataset == dataset_rid)
689
- & (atable_path.columns[fk_column] == e),
734
+ (atable_path.Dataset == dataset_rid) & (atable_path.columns[fk_column] == e),
690
735
  )
691
736
  entity.delete()
692
737
  self.increment_dataset_version(
@@ -708,21 +753,14 @@ class Dataset:
708
753
  RID of the parent dataset_table.
709
754
  """
710
755
  if not self._is_dataset_rid(dataset_rid):
711
- raise DerivaMLException(
712
- f"RID: {dataset_rid} does not belong to dataset_table {self.dataset_table.name}"
713
- )
756
+ raise DerivaMLException(f"RID: {dataset_rid} does not belong to dataset_table {self._dataset_table.name}")
714
757
  # Get association table for nested datasets
715
758
  pb = self._model.catalog.getPathBuilder()
716
759
  atable_path = pb.schemas[self._ml_schema].Dataset_Dataset
717
- return [
718
- p["Dataset"]
719
- for p in atable_path.filter(atable_path.Nested_Dataset == dataset_rid)
720
- .entities()
721
- .fetch()
722
- ]
760
+ return [p["Dataset"] for p in atable_path.filter(atable_path.Nested_Dataset == dataset_rid).entities().fetch()]
723
761
 
724
762
  @validate_call
725
- def list_dataset_children(self, dataset_rid: RID, recurse=False) -> list[RID]:
763
+ def list_dataset_children(self, dataset_rid: RID, recurse: bool = False) -> list[RID]:
726
764
  """Given a dataset_table RID, return a list of RIDs for any nested datasets.
727
765
 
728
766
  Args:
@@ -733,19 +771,11 @@ class Dataset:
733
771
  list of nested dataset RIDs.
734
772
 
735
773
  """
736
- dataset_dataset_path = (
737
- self._model.catalog.getPathBuilder()
738
- .schemas[self._ml_schema]
739
- .tables["Dataset_Dataset"]
740
- )
774
+ dataset_dataset_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Dataset"]
741
775
  nested_datasets = list(dataset_dataset_path.entities().fetch())
742
776
 
743
777
  def find_children(rid: RID):
744
- children = [
745
- child["Nested_Dataset"]
746
- for child in nested_datasets
747
- if child["Dataset"] == rid
748
- ]
778
+ children = [child["Nested_Dataset"] for child in nested_datasets if child["Dataset"] == rid]
749
779
  if recurse:
750
780
  for child in children.copy():
751
781
  children.extend(find_children(child))
@@ -753,9 +783,7 @@ class Dataset:
753
783
 
754
784
  return find_children(dataset_rid)
755
785
 
756
- def _export_vocabulary(
757
- self, writer: Callable[[str, str, Table], list[dict[str, Any]]]
758
- ) -> list[dict[str, Any]]:
786
+ def _export_vocabulary(self, writer: Callable[[str, str, Table], list[dict[str, Any]]]) -> list[dict[str, Any]]:
759
787
  """
760
788
 
761
789
  Args:
@@ -770,16 +798,12 @@ class Dataset:
770
798
  for table in s.tables.values()
771
799
  if self._model.is_vocabulary(table)
772
800
  ]
773
- return [
774
- o
775
- for table in vocabs
776
- for o in writer(f"{table.schema.name}:{table.name}", table.name, table)
777
- ]
801
+ return [o for table in vocabs for o in writer(f"{table.schema.name}:{table.name}", table.name, table)]
778
802
 
779
803
  def _table_paths(
780
804
  self,
781
- dataset: Optional[DatasetSpec] = None,
782
- snapshot_catalog: Optional[DerivaML] = None,
805
+ dataset: DatasetSpec | None = None,
806
+ snapshot_catalog: DerivaML | None = None,
783
807
  ) -> Iterator[tuple[str, str, Table]]:
784
808
  paths = self._collect_paths(dataset and dataset.rid, snapshot_catalog)
785
809
 
@@ -805,25 +829,20 @@ class Dataset:
805
829
 
806
830
  def _collect_paths(
807
831
  self,
808
- dataset_rid: Optional[RID] = None,
809
- snapshot: Optional[Dataset] = None,
810
- dataset_nesting_depth: Optional[int] = None,
832
+ dataset_rid: RID | None = None,
833
+ snapshot: Dataset | None = None,
834
+ dataset_nesting_depth: int | None = None,
811
835
  ) -> set[tuple[Table, ...]]:
812
836
  snapshot_catalog = snapshot if snapshot else self
813
837
 
814
- dataset_table = snapshot_catalog._model.schemas[self._ml_schema].tables[
815
- "Dataset"
816
- ]
817
- dataset_dataset = snapshot_catalog._model.schemas[self._ml_schema].tables[
818
- "Dataset_Dataset"
819
- ]
838
+ dataset_table = snapshot_catalog._model.schemas[self._ml_schema].tables["Dataset"]
839
+ dataset_dataset = snapshot_catalog._model.schemas[self._ml_schema].tables["Dataset_Dataset"]
820
840
 
821
841
  # Figure out what types of elements the dataset contains.
822
842
  dataset_associations = [
823
843
  a
824
- for a in self.dataset_table.find_associations()
825
- if a.table.schema.name != self._ml_schema
826
- or a.table.name == "Dataset_Dataset"
844
+ for a in self._dataset_table.find_associations()
845
+ if a.table.schema.name != self._ml_schema or a.table.name == "Dataset_Dataset"
827
846
  ]
828
847
  if dataset_rid:
829
848
  # Get a list of the members of the dataset so we can figure out which tables to query.
@@ -835,9 +854,7 @@ class Dataset:
835
854
  if m
836
855
  ]
837
856
  included_associations = [
838
- a.table
839
- for a in dataset_table.find_associations()
840
- if a.other_fkeys.pop().pk_table in dataset_elements
857
+ a.table for a in dataset_table.find_associations() if a.other_fkeys.pop().pk_table in dataset_elements
841
858
  ]
842
859
  else:
843
860
  included_associations = dataset_associations
@@ -848,9 +865,7 @@ class Dataset:
848
865
  for p in snapshot_catalog._model._schema_to_paths()
849
866
  if (len(p) == 1)
850
867
  or (p[1] not in dataset_associations) # Tables in the domain schema
851
- or (
852
- p[1] in included_associations
853
- ) # Tables that include members of the dataset
868
+ or (p[1] in included_associations) # Tables that include members of the dataset
854
869
  }
855
870
  # Now get paths for nested datasets
856
871
  nested_paths = set()
@@ -860,56 +875,42 @@ class Dataset:
860
875
  else:
861
876
  # Initialize nesting depth if not already provided.
862
877
  dataset_nesting_depth = (
863
- self._dataset_nesting_depth()
864
- if dataset_nesting_depth is None
865
- else dataset_nesting_depth
878
+ self._dataset_nesting_depth() if dataset_nesting_depth is None else dataset_nesting_depth
866
879
  )
867
880
  if dataset_nesting_depth:
868
- nested_paths = self._collect_paths(
869
- dataset_nesting_depth=dataset_nesting_depth - 1
870
- )
881
+ nested_paths = self._collect_paths(dataset_nesting_depth=dataset_nesting_depth - 1)
871
882
  if nested_paths:
872
883
  paths |= {
873
884
  tuple([dataset_table]),
874
885
  (dataset_table, dataset_dataset),
875
886
  }
876
- paths |= {(self.dataset_table, dataset_dataset) + p for p in nested_paths}
887
+ paths |= {(self._dataset_table, dataset_dataset) + p for p in nested_paths}
877
888
  return paths
878
889
 
879
- def _dataset_nesting_depth(self, dataset_rid: Optional[RID] = None) -> int:
890
+ def _dataset_nesting_depth(self, dataset_rid: RID | None = None) -> int:
880
891
  """Determine the maximum dataset nesting depth in the current catalog.
881
892
 
882
893
  Returns:
883
894
 
884
895
  """
885
896
 
886
- def children_depth(
887
- dataset_rid: RID, nested_datasets: dict[str, list[str]]
888
- ) -> int:
897
+ def children_depth(dataset_rid: RID, nested_datasets: dict[str, list[str]]) -> int:
889
898
  """Return the number of nested datasets for the dataset_rid if provided, otherwise in the current catalog"""
890
899
  try:
891
900
  children = nested_datasets[dataset_rid]
892
- return (
893
- max(map(lambda x: children_depth(x, nested_datasets), children)) + 1
894
- if children
895
- else 1
896
- )
901
+ return max(map(lambda x: children_depth(x, nested_datasets), children)) + 1 if children else 1
897
902
  except KeyError:
898
903
  return 0
899
904
 
900
905
  # Build up the dataset_table nesting graph...
901
- pb = (
902
- self._model.catalog.getPathBuilder()
903
- .schemas[self._ml_schema]
904
- .tables["Dataset_Dataset"]
905
- )
906
+ pb = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Dataset"]
906
907
  dataset_children = (
907
908
  [
908
909
  {
909
910
  "Dataset": dataset_rid,
910
911
  "Nested_Dataset": c,
911
912
  } # Make uniform with return from datapath
912
- for c in self.list_dataset_children(dataset_rid)
913
+ for c in self.list_dataset_children(dataset_rid=dataset_rid)
913
914
  ]
914
915
  if dataset_rid
915
916
  else pb.entities().fetch()
@@ -917,30 +918,29 @@ class Dataset:
917
918
  nested_dataset = defaultdict(list)
918
919
  for ds in dataset_children:
919
920
  nested_dataset[ds["Dataset"]].append(ds["Nested_Dataset"])
920
- return (
921
- max(map(lambda d: children_depth(d, dict(nested_dataset)), nested_dataset))
922
- if nested_dataset
923
- else 0
924
- )
921
+ return max(map(lambda d: children_depth(d, dict(nested_dataset)), nested_dataset)) if nested_dataset else 0
925
922
 
926
923
  def _dataset_specification(
927
924
  self,
928
925
  writer: Callable[[str, str, Table], list[dict[str, Any]]],
929
- dataset: Optional[DatasetSpec] = None,
930
- snapshot_catalog: Optional[DerivaML] = None,
926
+ dataset: DatasetSpec | None = None,
927
+ snapshot_catalog: DerivaML | None = None,
931
928
  ) -> list[dict[str, Any]]:
932
- """Output a download/export specification for a dataset_table. Each element of the dataset_table will be placed in its own dir
933
- The top level data directory of the resulting BDBag will have one subdirectory for element type. The subdirectory
934
- will contain the CSV indicating which elements of that type are present in the dataset_table, and then there will be a
935
- subdirectory for each object that is reachable from the dataset_table members.
936
-
937
- To simplify reconstructing the relationship between tables, the CVS for each
938
- The top level data directory will also contain a subdirectory for any controlled vocabularies used in the dataset_table.
939
- All assets will be placed into a directory named asset in a subdirectory with the asset table name.
940
-
941
- For example, consider a dataset_table that consists of two element types, T1 and T2. T1 has foreign key relationships to
942
- objects in tables T3 and T4. There are also two controlled vocabularies, CV1 and CV2. T2 is an asset table
943
- which has two assets in it. The layout of the resulting bdbag would be:
929
+ """Output a download/export specification for a dataset_table. Each element of the dataset_table
930
+ will be placed in its own directory.
931
+ The top level data directory of the resulting BDBag will have one subdirectory per element type.
932
+ The subdirectory will contain the CSV indicating which elements of that type are present in the
933
+ dataset_table, and then there will be a subdirectory for each object that is reachable from the
934
+ dataset_table members.
935
+
936
+ To simplify reconstructing the relationship between tables, the CSV for each element is included.
937
+ The top level data directory will also contain a subdirectory for any controlled vocabularies used in
938
+ the dataset_table. All assets will be placed into a directory named asset in a subdirectory with the
939
+ asset table name.
940
+
941
+ For example, consider a dataset_table that consists of two element types, T1 and T2. T1 has foreign
942
+ key relationships to objects in tables T3 and T4. There are also two controlled vocabularies, CV1 and
943
+ CV2. T2 is an asset table which has two assets in it. The layout of the resulting bdbag would be:
944
944
  data
945
945
  CV1/
946
946
  cv1.csv
@@ -967,17 +967,15 @@ class Dataset:
967
967
  A dataset_table specification.
968
968
  """
969
969
  element_spec = self._export_vocabulary(writer)
970
- for path in self._table_paths(
971
- dataset=dataset, snapshot_catalog=snapshot_catalog
972
- ):
970
+ for path in self._table_paths(dataset=dataset, snapshot_catalog=snapshot_catalog):
973
971
  element_spec.extend(writer(*path))
974
972
  return element_spec
975
973
 
976
974
  def _download_dataset_bag(
977
975
  self,
978
976
  dataset: DatasetSpec,
979
- execution_rid: Optional[RID] = None,
980
- snapshot_catalog: Optional[DerivaML] = None,
977
+ execution_rid: RID | None = None,
978
+ snapshot_catalog: DerivaML | None = None,
981
979
  ) -> DatasetBag:
982
980
  """Download a dataset onto the local file system. Create a MINID for the dataset if one doesn't already exist.
983
981
 
@@ -1007,27 +1005,29 @@ class Dataset:
1007
1005
 
1008
1006
  def _version_snapshot(self, dataset: DatasetSpec) -> str:
1009
1007
  """Return a catalog with snapshot for the specified dataset version"""
1010
- version_record = [
1011
- h
1012
- for h in self.dataset_history(dataset_rid=dataset.rid)
1013
- if h.dataset_version == dataset.version
1014
- ][0]
1008
+ try:
1009
+ version_record = next(
1010
+ h for h in self.dataset_history(dataset_rid=dataset.rid) if h.dataset_version == dataset.version
1011
+ )
1012
+ except StopIteration:
1013
+ raise DerivaMLException(f"Dataset version {dataset.version} not found for dataset {dataset.rid}")
1015
1014
  return f"{self._model.catalog.catalog_id}@{version_record.snapshot}"
1016
1015
 
1017
- def _create_dataset_minid(
1018
- self, dataset: DatasetSpec, snapshot_catalog: Optional[DerivaML] = None
1019
- ) -> str:
1016
+ def _create_dataset_minid(self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None) -> str:
1020
1017
  with TemporaryDirectory() as tmp_dir:
1021
1018
  # Generate a download specification file for the current catalog schema. By default, this spec
1022
1019
  # will generate a minid and place the bag into S3 storage.
1023
- spec_file = f"{tmp_dir}/download_spec.json"
1024
- with open(spec_file, "w", encoding="utf-8") as ds:
1025
- json.dump(
1026
- self._generate_dataset_download_spec(dataset, snapshot_catalog), ds
1027
- )
1020
+ spec_file = Path(tmp_dir) / "download_spec.json"
1021
+ with spec_file.open("w", encoding="utf-8") as ds:
1022
+ json.dump(self._generate_dataset_download_spec(dataset, snapshot_catalog), ds)
1028
1023
  try:
1029
1024
  self._logger.info(
1030
- f"Downloading dataset {'minid' if self._use_minid else 'bag'} for catalog: {dataset.rid}@{str(dataset.version)}"
1025
+ "Downloading dataset %s for catalog: %s@%s"
1026
+ % (
1027
+ "minid" if self._use_minid else "bag",
1028
+ dataset.rid,
1029
+ str(dataset.version),
1030
+ )
1031
1031
  )
1032
1032
  # Generate the bag and put into S3 storage.
1033
1033
  exporter = DerivaExport(
@@ -1050,15 +1050,9 @@ class Dataset:
1050
1050
  raise DerivaMLException(format_exception(e))
1051
1051
  # Update version table with MINID.
1052
1052
  if self._use_minid:
1053
- version_path = (
1054
- self._model.catalog.getPathBuilder()
1055
- .schemas[self._ml_schema]
1056
- .tables["Dataset_Version"]
1057
- )
1053
+ version_path = self._model.catalog.getPathBuilder().schemas[self._ml_schema].tables["Dataset_Version"]
1058
1054
  version_rid = [
1059
- h
1060
- for h in self.dataset_history(dataset_rid=dataset.rid)
1061
- if h.dataset_version == dataset.version
1055
+ h for h in self.dataset_history(dataset_rid=dataset.rid) if h.dataset_version == dataset.version
1062
1056
  ][0].version_rid
1063
1057
  version_path.update([{"RID": version_rid, "Minid": minid_page_url}])
1064
1058
  return minid_page_url
@@ -1066,10 +1060,10 @@ class Dataset:
1066
1060
  def _get_dataset_minid(
1067
1061
  self,
1068
1062
  dataset: DatasetSpec,
1069
- snapshot_catalog: Optional[DerivaML] = None,
1063
+ snapshot_catalog: DerivaML | None = None,
1070
1064
  create: bool = True,
1071
- ) -> DatasetMinid:
1072
- """Return a MINID to the specified dataset. If no version is specified, use the latest.
1065
+ ) -> DatasetMinid | None:
1066
+ """Return a MINID for the specified dataset. If no version is specified, use the latest.
1073
1067
 
1074
1068
  Args:
1075
1069
  dataset: Specification of the dataset.
@@ -1079,50 +1073,53 @@ class Dataset:
1079
1073
  Returns:
1080
1074
  New or existing MINID for the dataset.
1081
1075
  """
1082
- if dataset.rid.startswith("minid"):
1083
- minid_url = f"https://identifiers.org/{dataset.rid}"
1084
- elif dataset.rid.startswith("http"):
1085
- minid_url = dataset.rid
1086
- else:
1087
- if not any([dataset.rid == ds["RID"] for ds in self.find_datasets()]):
1088
- raise DerivaMLException(f"RID {dataset.rid} is not a dataset_table")
1089
-
1090
- # Get the history record for the version we are looking for.
1091
- dataset_version_record = [
1092
- v
1093
- for v in self.dataset_history(dataset.rid)
1094
- if v.dataset_version == str(dataset.version)
1095
- ][0]
1096
- if not dataset_version_record:
1097
- raise DerivaMLException(
1098
- f"Version {str(dataset.version)} does not exist for RID {dataset.rid}"
1099
- )
1100
- minid_url = dataset_version_record.minid
1101
- if not minid_url:
1102
- if not create:
1103
- raise DerivaMLException(
1104
- f"Minid for dataset {dataset.rid} doesn't exist"
1105
- )
1106
- if self._use_minid:
1107
- self._logger.info("Creating new MINID for dataset %s", dataset.rid)
1108
- minid_url = self._create_dataset_minid(dataset, snapshot_catalog)
1109
- # If provided a MINID, use the MINID metadata to get the checksum and download the bag.
1076
+ rid = dataset.rid
1077
+
1078
+ # Case 1: RID is already a MINID or direct URL
1079
+ if rid.startswith("minid"):
1080
+ return self._fetch_minid_metadata(f"https://identifiers.org/{rid}", dataset.version)
1081
+ if rid.startswith("http"):
1082
+ return self._fetch_minid_metadata(rid, dataset.version)
1083
+
1084
+ # Case 2: RID is a dataset RID; validate existence
1085
+ if not any(rid == ds["RID"] for ds in self.find_datasets()):
1086
+ raise DerivaMLTableTypeError("Dataset", rid)
1087
+
1088
+ # Find dataset version record
1089
+ version_str = str(dataset.version)
1090
+ history = self.dataset_history(rid)
1091
+ try:
1092
+ version_record = next(v for v in history if v.dataset_version == version_str)
1093
+ except StopIteration:
1094
+ raise DerivaMLException(f"Version {version_str} does not exist for RID {rid}")
1095
+
1096
+ # Check or create MINID
1097
+ minid_url = version_record.minid
1098
+ if not minid_url:
1099
+ if not create:
1100
+ raise DerivaMLException(f"Minid for dataset {rid} doesn't exist")
1110
1101
  if self._use_minid:
1111
- r = requests.get(minid_url, headers={"accept": "application/json"})
1112
- dataset_minid = DatasetMinid(
1113
- dataset_version=dataset.version, **r.json()
1114
- )
1115
- else:
1116
- dataset_minid = DatasetMinid(
1117
- dataset_version=dataset.version,
1118
- RID=f"{dataset.rid}@{dataset_version_record.snapshot}",
1119
- location=minid_url,
1120
- )
1121
- return dataset_minid
1102
+ self._logger.info("Creating new MINID for dataset %s", rid)
1103
+ minid_url = self._create_dataset_minid(dataset, snapshot_catalog)
1104
+
1105
+ # Return based on MINID usage
1106
+ if self._use_minid:
1107
+ return self._fetch_minid_metadata(minid_url, dataset.version)
1108
+
1109
+ return DatasetMinid(
1110
+ dataset_version=dataset.version,
1111
+ RID=f"{rid}@{version_record.snapshot}",
1112
+ location=minid_url,
1113
+ )
1114
+
1115
+ def _fetch_minid_metadata(self, url: str, version: DatasetVersion) -> DatasetMinid:
1116
+ r = requests.get(url, headers={"accept": "application/json"})
1117
+ r.raise_for_status()
1118
+ return DatasetMinid(dataset_version=version, **r.json())
1122
1119
 
1123
1120
  def _download_dataset_minid(self, minid: DatasetMinid) -> Path:
1124
- """Given a RID to a dataset_table, or a MINID to an existing bag, download the bag file, extract it, and validate
1125
- that all the metadata is correct
1121
+ """Given a RID to a dataset_table, or a MINID to an existing bag, download the bag file, extract it, and
1122
+ validate that all the metadata is correct
1126
1123
 
1127
1124
  Args:
1128
1125
  minid: The RID of a dataset_table or a minid to an existing bag.
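
The new _fetch_minid_metadata helper added above resolves a MINID landing page to JSON metadata. A small sketch of that resolution step on its own, assuming a resolvable identifiers.org entry; the identifier value is hypothetical:

    import requests

    minid = "minid:b9j62h9e"  # hypothetical identifier
    resp = requests.get(
        f"https://identifiers.org/{minid}",
        headers={"accept": "application/json"},  # request metadata rather than the HTML landing page
    )
    resp.raise_for_status()
    metadata = resp.json()  # fields such as the checksum and bag location feed DatasetMinid above
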
@@ -1134,9 +1131,7 @@ class Dataset:
  # it. If not, then we need to extract the contents of the archive into our cache directory.
  bag_dir = self._cache_dir / f"{minid.dataset_rid}_{minid.checksum}"
  if bag_dir.exists():
- self._logger.info(
- f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}"
- )
+ self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
  return Path(bag_dir / f"Dataset_{minid.dataset_rid}")

  # Either bag hasn't been downloaded yet, or we are not using a Minid, so we don't know the checksum yet.
@@ -1145,19 +1140,13 @@ class Dataset:
  # Get bag from S3
  archive_path = fetch_single_file(minid.bag_url)
  else:
- exporter = DerivaExport(
- host=self._model.catalog.deriva_server.server, output_dir=tmp_dir
- )
+ exporter = DerivaExport(host=self._model.catalog.deriva_server.server, output_dir=tmp_dir)
  archive_path = exporter.retrieve_file(minid.bag_url)
- hashes = hash_utils.compute_file_hashes(
- archive_path, hashes=["md5", "sha256"]
- )
+ hashes = hash_utils.compute_file_hashes(archive_path, hashes=["md5", "sha256"])
  checksum = hashes["sha256"][0]
  bag_dir = self._cache_dir / f"{minid.dataset_rid}_{checksum}"
  if bag_dir.exists():
- self._logger.info(
- f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}"
- )
+ self._logger.info(f"Using cached bag for {minid.dataset_rid} Version:{minid.dataset_version}")
  return Path(bag_dir / f"Dataset_{minid.dataset_rid}")
  bag_path = bdb.extract_bag(archive_path, bag_dir.as_posix())
  bdb.validate_bag_structure(bag_path)
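
The cache lookup above keys bag directories by the archive's SHA-256 checksum. A standalone sketch of that content-addressed layout using only the standard library (paths and names are illustrative, not the package's actual cache location):

    import hashlib
    from pathlib import Path

    def sha256_of(path: Path) -> str:
        digest = hashlib.sha256()
        with path.open("rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()

    cache_dir = Path("~/.deriva-ml/cache").expanduser()   # hypothetical cache root
    archive = Path("Dataset_1-abcd.zip")                  # hypothetical downloaded bag archive
    bag_dir = cache_dir / f"1-abcd_{sha256_of(archive)}"  # same "<rid>_<checksum>" naming as above
    if not bag_dir.exists():
        print("extract archive into", bag_dir)
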
@@ -1166,7 +1155,7 @@ class Dataset:
  def _materialize_dataset_bag(
  self,
  minid: DatasetMinid,
- execution_rid: Optional[RID] = None,
+ execution_rid: RID | None = None,
  ) -> Path:
  """Materialize a dataset_table bag into a local directory

@@ -1180,9 +1169,7 @@ class Dataset:
  def update_status(status: Status, msg: str) -> None:
  """Update the current status for this execution in the catalog"""
  if execution_rid and execution_rid != DRY_RUN_RID:
- self._model.catalog.getPathBuilder().schemas[
- self._ml_schema
- ].Execution.update(
+ self._model.catalog.getPathBuilder().schemas[self._ml_schema].Execution.update(
  [
  {
  "RID": execution_rid,
@@ -1212,9 +1199,7 @@ class Dataset:

  # If this bag has already been validated, our work is done. Otherwise, materialize the bag.
  if not validated_check.exists():
- self._logger.info(
- f"Materializing bag {minid.dataset_rid} Version:{minid.dataset_version}"
- )
+ self._logger.info(f"Materializing bag {minid.dataset_rid} Version:{minid.dataset_version}")
  bdb.materialize(
  bag_path.as_posix(),
  fetch_callback=fetch_progress_callback,
@@ -1225,7 +1210,7 @@ class Dataset:

  def _export_annotation(
  self,
- snapshot_catalog: Optional[DerivaML] = None,
+ snapshot_catalog: DerivaML | None = None,
  ) -> list[dict[str, Any]]:
  """Return an output specification for the datasets in the provided model

@@ -1257,7 +1242,7 @@ class Dataset:
  )

  def _export_specification(
- self, dataset: DatasetSpec, snapshot_catalog: Optional[DerivaML] = None
+ self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None
  ) -> list[dict[str, Any]]:
  """
  Generate a specification for the export engine for a specific dataset.
@@ -1273,14 +1258,10 @@ class Dataset:
  "processor": "json",
  "processor_params": {"query_path": "/schema", "output_path": "schema"},
  }
- ] + self._dataset_specification(
- self._export_specification_dataset_element, dataset, snapshot_catalog
- )
+ ] + self._dataset_specification(self._export_specification_dataset_element, dataset, snapshot_catalog)

  @staticmethod
- def _export_specification_dataset_element(
- spath: str, dpath: str, table: Table
- ) -> list[dict[str, Any]]:
+ def _export_specification_dataset_element(spath: str, dpath: str, table: Table) -> list[dict[str, Any]]:
  """Return the download specification for the data object indicated by a path through the data model.

  Args:
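
The fragment above assembles deriva export specifications as plain dictionaries with a processor name and processor_params. A sketch of what one such output list might look like, based only on the keys visible in this hunk; the second entry and its entity query path are hypothetical:

    # Schema entry taken from the hunk above; the second entry stands in for an
    # element produced by _export_specification_dataset_element for some table.
    export_outputs = [
        {
            "processor": "json",
            "processor_params": {"query_path": "/schema", "output_path": "schema"},
        },
        {
            "processor": "csv",
            "processor_params": {
                "query_path": "/entity/M:=deriva-ml:Dataset",  # hypothetical ERMrest path
                "output_path": "Dataset",
            },
        },
    ]
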
@@ -1315,10 +1296,9 @@ class Dataset:
  )
  return exports

- def _export_annotation_dataset_element(
- self, spath: str, dpath: str, table: Table
- ) -> list[dict[str, Any]]:
- """Given a path in the data model, output an export specification for the path taken to get to the current table.
+ def _export_annotation_dataset_element(self, spath: str, dpath: str, table: Table) -> list[dict[str, Any]]:
+ """Given a path in the data model, output an export specification for the path taken to get to the
+ current table.

  Args:
  spath: Source path
@@ -1369,7 +1349,7 @@ class Dataset:
  return exports

  def _generate_dataset_download_spec(
- self, dataset: DatasetSpec, snapshot_catalog: Optional[DerivaML] = None
+ self, dataset: DatasetSpec, snapshot_catalog: DerivaML | None = None
  ) -> dict[str, Any]:
  """
  Generate a specification for downloading a specific dataset.
@@ -1472,9 +1452,7 @@ class Dataset:
  else {}
  )
  return {
- deriva_tags.export_fragment_definitions: {
- "dataset_export_outputs": self._export_annotation()
- },
+ deriva_tags.export_fragment_definitions: {"dataset_export_outputs": self._export_annotation()},
  deriva_tags.visible_foreign_keys: self._dataset_visible_fkeys(),
  deriva_tags.export_2019: {
  "detailed": {