deriva-ml 1.17.9__py3-none-any.whl → 1.17.11__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
Files changed (74)
  1. deriva_ml/__init__.py +43 -1
  2. deriva_ml/asset/__init__.py +17 -0
  3. deriva_ml/asset/asset.py +357 -0
  4. deriva_ml/asset/aux_classes.py +100 -0
  5. deriva_ml/bump_version.py +254 -11
  6. deriva_ml/catalog/__init__.py +21 -0
  7. deriva_ml/catalog/clone.py +1199 -0
  8. deriva_ml/catalog/localize.py +426 -0
  9. deriva_ml/core/__init__.py +29 -0
  10. deriva_ml/core/base.py +817 -1067
  11. deriva_ml/core/config.py +169 -21
  12. deriva_ml/core/constants.py +120 -19
  13. deriva_ml/core/definitions.py +123 -13
  14. deriva_ml/core/enums.py +47 -73
  15. deriva_ml/core/ermrest.py +226 -193
  16. deriva_ml/core/exceptions.py +297 -14
  17. deriva_ml/core/filespec.py +99 -28
  18. deriva_ml/core/logging_config.py +225 -0
  19. deriva_ml/core/mixins/__init__.py +42 -0
  20. deriva_ml/core/mixins/annotation.py +915 -0
  21. deriva_ml/core/mixins/asset.py +384 -0
  22. deriva_ml/core/mixins/dataset.py +237 -0
  23. deriva_ml/core/mixins/execution.py +408 -0
  24. deriva_ml/core/mixins/feature.py +365 -0
  25. deriva_ml/core/mixins/file.py +263 -0
  26. deriva_ml/core/mixins/path_builder.py +145 -0
  27. deriva_ml/core/mixins/rid_resolution.py +204 -0
  28. deriva_ml/core/mixins/vocabulary.py +400 -0
  29. deriva_ml/core/mixins/workflow.py +322 -0
  30. deriva_ml/core/validation.py +389 -0
  31. deriva_ml/dataset/__init__.py +2 -1
  32. deriva_ml/dataset/aux_classes.py +20 -4
  33. deriva_ml/dataset/catalog_graph.py +575 -0
  34. deriva_ml/dataset/dataset.py +1242 -1008
  35. deriva_ml/dataset/dataset_bag.py +1311 -182
  36. deriva_ml/dataset/history.py +27 -14
  37. deriva_ml/dataset/upload.py +225 -38
  38. deriva_ml/demo_catalog.py +186 -105
  39. deriva_ml/execution/__init__.py +46 -2
  40. deriva_ml/execution/base_config.py +639 -0
  41. deriva_ml/execution/execution.py +545 -244
  42. deriva_ml/execution/execution_configuration.py +26 -11
  43. deriva_ml/execution/execution_record.py +592 -0
  44. deriva_ml/execution/find_caller.py +298 -0
  45. deriva_ml/execution/model_protocol.py +175 -0
  46. deriva_ml/execution/multirun_config.py +153 -0
  47. deriva_ml/execution/runner.py +595 -0
  48. deriva_ml/execution/workflow.py +224 -35
  49. deriva_ml/experiment/__init__.py +8 -0
  50. deriva_ml/experiment/experiment.py +411 -0
  51. deriva_ml/feature.py +6 -1
  52. deriva_ml/install_kernel.py +143 -6
  53. deriva_ml/interfaces.py +862 -0
  54. deriva_ml/model/__init__.py +99 -0
  55. deriva_ml/model/annotations.py +1278 -0
  56. deriva_ml/model/catalog.py +286 -60
  57. deriva_ml/model/database.py +144 -649
  58. deriva_ml/model/deriva_ml_database.py +308 -0
  59. deriva_ml/model/handles.py +14 -0
  60. deriva_ml/run_model.py +319 -0
  61. deriva_ml/run_notebook.py +507 -38
  62. deriva_ml/schema/__init__.py +18 -2
  63. deriva_ml/schema/annotations.py +62 -33
  64. deriva_ml/schema/create_schema.py +169 -69
  65. deriva_ml/schema/validation.py +601 -0
  66. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/METADATA +4 -5
  67. deriva_ml-1.17.11.dist-info/RECORD +77 -0
  68. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/WHEEL +1 -1
  69. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/entry_points.txt +2 -0
  70. deriva_ml/protocols/dataset.py +0 -19
  71. deriva_ml/test.py +0 -94
  72. deriva_ml-1.17.9.dist-info/RECORD +0 -45
  73. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/licenses/LICENSE +0 -0
  74. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/top_level.txt +0 -0
deriva_ml/model/database.py
@@ -1,719 +1,214 @@
-"""This module contains the definition of the DatabaseModel class. The role of this class is to provide an interface
-between the BDBag representation of a dataset and a sqlite database in which the contents of the bag are stored.
+"""DerivaML-specific database model for downloaded BDBags.
+
+This module provides the DatabaseModel class which extends the generic BagDatabase
+from deriva-py with DerivaML-specific functionality:
+
+- Dataset version tracking
+- Dataset RID resolution
+- Integration with DerivaModel for schema analysis
+
+For schema-independent BDBag operations, see deriva.core.bag_database.BagDatabase.
 """

 from __future__ import annotations

-import json
 import logging
-from csv import reader
 from pathlib import Path
-from typing import Any, Generator, Optional, Type
-from urllib.parse import urlparse
+from typing import Any, Generator, Optional

-from dateutil import parser
-from deriva.core.ermrest_model import Column as DerivaColumn
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from deriva.core.bag_database import BagDatabase
 from deriva.core.ermrest_model import Model
 from deriva.core.ermrest_model import Table as DerivaTable
-from deriva.core.ermrest_model import Type as DerivaType
-from pydantic import ConfigDict, validate_call
-from sqlalchemy import (
-    JSON,
-    Boolean,
-    Date,
-    DateTime,
-    Float,
-    Integer,
-    MetaData,
-    String,
-    create_engine,
-    event,
-    inspect,
-    select,
-)
-from sqlalchemy import Column as SQLColumn
-from sqlalchemy import ForeignKeyConstraint as SQLForeignKeyConstraint
-from sqlalchemy import Table as SQLTable
-from sqlalchemy import UniqueConstraint as SQLUniqueConstraint
-from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-from sqlalchemy.ext.automap import automap_base
-from sqlalchemy.orm import backref, configure_mappers, foreign, relationship
-from sqlalchemy.sql.type_api import TypeEngine
-from sqlalchemy.types import TypeDecorator
-
-from deriva_ml.core.definitions import ML_SCHEMA, RID, MLVocab
+
+from deriva_ml.core.definitions import ML_SCHEMA, RID, get_domain_schemas
 from deriva_ml.core.exceptions import DerivaMLException
 from deriva_ml.dataset.aux_classes import DatasetMinid, DatasetVersion
-from deriva_ml.dataset.dataset_bag import DatasetBag
 from deriva_ml.model.catalog import DerivaModel

-try:
-    from icecream import ic
-except ImportError:  # Graceful fallback if IceCream isn't installed.
-    ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a)  # noqa
-

-class ERMRestBoolean(TypeDecorator):
-    impl = Boolean
-    cache_ok = True
+class DatabaseModel(BagDatabase, DerivaModel):
+    """DerivaML database model for downloaded BDBags.

-    def process_bind_param(self, value, dialect):
-        if value in ("Y", "y", 1, True, "t", "T"):
-            return True
-        elif value in ("N", "n", 0, False, "f", "F"):
-            return False
-        elif value is None:
-            return None
-        raise ValueError(f"Invalid boolean value: {value!r}")
-
-
-class StringToFloat(TypeDecorator):
-    impl = Float
-    cache_ok = True
-
-    def process_bind_param(self, value, dialect):
-        if value == "" or value is None:
-            return None
-        else:
-            return float(value)
-
-class StringToInteger(TypeDecorator):
-    impl = Integer
-    cache_ok = True
-
-    def process_bind_param(self, value, dialect):
-        if value == "" or value is None:
-            return None
-        else:
-            return int(value)
-
-
-class StringToDateTime(TypeDecorator):
-    impl = DateTime
-    cache_ok = True
-
-    def process_bind_param(self, value, dialect):
-        if value == "" or value is None:
-            return None
-        else:
-            return parser.parse(value)
-
-class StringToDate(TypeDecorator):
-    impl = Date
-    cache_ok = True
-
-    def process_bind_param(self, value, dialect):
-        if value == "" or value is None:
-            return None
-        else:
-            return parser.parse(value).date()
+    This class combines the generic BagDatabase functionality with DerivaML-specific
+    features like dataset versioning and the DerivaModel schema utilities.

-class DatabaseModelMeta(type):
-    """Use metaclass to ensure that there is only one instance of a database model per path"""
-
-    _paths_loaded: dict[Path, "DatabaseModel"] = {}
-
-    def __call__(cls, *args, **kwargs):
-        logger = logging.getLogger("deriva_ml")
-        bag_path: Path = args[1]
-        if bag_path.as_posix() not in cls._paths_loaded:
-            logger.info(f"Loading {bag_path}")
-            cls._paths_loaded[bag_path] = super().__call__(*args, **kwargs)
-        return cls._paths_loaded[bag_path]
-
-
-class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
-    """Read in the contents of a BDBag and create a local SQLite database.
-
-    As part of its initialization, this routine will create a sqlite database that has the contents of all the
-    tables in the dataset_table. In addition, any asset tables will the `Filename` column remapped to have the path
-    of the local copy of the file. In addition, a local version of the ERMRest model that as used to generate the
-    dataset_table is available.
-
-    The sqlite database will not have any foreign key constraints applied, however, foreign-key relationships can be
-    found by looking in the ERMRest model. In addition, as sqlite doesn't support schema, Ermrest schema are added
-    to the table name using the convention SchemaName:TableName. Methods in DatasetBag that have table names as the
-    argument will perform the appropriate name mappings.
-
-    Because of nested datasets, it's possible that more than one dataset rid is in a bag, or that a dataset rid might
-    appear in more than one database. To help manage this, a global list of all the datasets that have been loaded
-    into DatabaseModels, is kept in the class variable `_rid_map`.
-
-    Because you can load different versions of a dataset simultaneously, the dataset RID and version number are tracked,
-    and a new sqlite instance is created for every new dataset version present.
+    It reads a BDBag and creates a SQLite database, then provides:
+    - All BagDatabase query methods (list_tables, get_table_contents, etc.)
+    - All DerivaModel schema methods (find_features, is_asset, etc.)
+    - Dataset version tracking (bag_rids, dataset_version)
+    - Dataset RID validation (rid_lookup)

     Attributes:
-        bag_path (Path): path to the local copy of the BDBag
-        minid (DatasetMinid): Minid for the specified bag
-        dataset_rid (RID): RID for the specified dataset
-        engine (Connection): connection to the sqlalchemy database holding table values
-        domain_schema (str): Name of the domain schema
-        dataset_table (Table): the dataset table in the ERMRest model.
+        bag_path: Path to the BDBag directory.
+        minid: DatasetMinid for the downloaded bag.
+        dataset_rid: Primary dataset RID in this bag.
+        bag_rids: Dictionary mapping all dataset RIDs to their versions.
+        dataset_table: The Dataset table from the ERMrest model.
+
+    Example:
+        >>> db = DatabaseModel(minid, bag_path, working_dir)
+        >>> version = db.dataset_version("ABC123")
+        >>> for row in db.get_table_contents("Image"):
+        ...     print(row["Filename"])
     """

-    # Maintain a global map of RIDS to versions and databases.
-    _rid_map: dict[RID, list[tuple[DatasetVersion, "DatabaseModel"]]] = {}
-
     def __init__(self, minid: DatasetMinid, bag_path: Path, dbase_path: Path):
-        """Create a new DatabaseModel.
+        """Create a DerivaML database from a BDBag.

         Args:
-            minid: Minid for the specified bag.
-            bag_path: Path to the local copy of the BDBag.
+            minid: DatasetMinid containing bag metadata (RID, version, etc.).
+            bag_path: Path to the BDBag directory.
+            dbase_path: Base directory for SQLite database files.
         """
-
-        super().__init__(Model.fromfile("file-system", bag_path / "data/schema.json"))
-
-        self.bag_path = bag_path
+        self._logger = logging.getLogger("deriva_ml")
         self.minid = minid
         self.dataset_rid = minid.dataset_rid
-        self.dbase_path = dbase_path / f"{minid.version_rid}"
-        self.dbase_path.mkdir(parents=True, exist_ok=True)
-
-        self.engine = create_engine(f"sqlite:///{(self.dbase_path / 'main.db').resolve()}", future=True)
-        self.metadata = MetaData()
-        self.Base = automap_base(metadata=self.metadata)

-        # Attach event listener for *this instance's* engine
-        event.listen(self.engine, "connect", self._attach_schemas)
+        # Load the model first to determine schema names
+        schema_file = bag_path / "data/schema.json"
+        temp_model = Model.fromfile("file-system", schema_file)

-        schema_file = self.bag_path / "data/schema.json"
-        with schema_file.open("r") as f:
-            self.snaptime = json.load(f)["snaptime"]
+        # Determine domain schemas using schema classification
+        ml_schema = ML_SCHEMA
+        domain_schemas = get_domain_schemas(temp_model.schemas.keys(), ml_schema)

-        self._logger = logging.getLogger("deriva_ml")
-        self._load_model()
-        self.ml_schema = ML_SCHEMA
-        self._load_database()
-        self._logger.info(
-            "Creating new database for dataset: %s in %s",
-            self.dataset_rid,
-            self.dbase_path,
+        # Initialize BagDatabase (creates SQLite DB)
+        BagDatabase.__init__(
+            self,
+            bag_path=bag_path,
+            database_dir=dbase_path,
+            schemas=[*domain_schemas, ml_schema],
         )
-        self.dataset_table = self.model.schemas[self.ml_schema].tables["Dataset"]

-        # Now go through the database and pick out all the dataset_table RIDS, along with their versions.
-        with self.engine.connect() as conn:
-            dataset_version = self.metadata.tables[f"{self.ml_schema}.Dataset_Version"]
-            result = conn.execute(select(dataset_version.c.Dataset, dataset_version.c.Version))
-            dataset_versions = [t for t in result]
-
-        dataset_versions = [(v[0], DatasetVersion.parse(v[1])) for v in dataset_versions]
-        # Get most current version of each rid
-        self.bag_rids = {}
-        for rid, version in dataset_versions:
-            self.bag_rids[rid] = max(self.bag_rids.get(rid, DatasetVersion(0, 1, 0)), version)
-
-        for dataset_rid, dataset_version in self.bag_rids.items():
-            version_list = DatabaseModel._rid_map.setdefault(dataset_rid, [])
-            version_list.append((dataset_version, self))
-
-    def _attach_schemas(self, dbapi_conn, _conn_record):
-        cur = dbapi_conn.cursor()
-        for schema in [self.domain_schema, self.ml_schema]:
-            schema_file = (self.dbase_path / f"{schema}.db").resolve()
-            cur.execute(f"ATTACH DATABASE '{schema_file}' AS '{schema}'")
-        cur.close()
-
-    @staticmethod
-    def _sql_type(type: DerivaType) -> TypeEngine:
-        """Return the SQL type for a Deriva column."""
-        return {
-            "boolean": ERMRestBoolean,
-            "date": StringToDate,
-            "float4": StringToFloat,
-            "float8": StringToFloat,
-            "int2": StringToInteger,
-            "int4": StringToInteger,
-            "int8": StringToInteger,
-            "json": JSON,
-            "jsonb": JSON,
-            "timestamptz": StringToDateTime,
-            "timestamp": StringToDateTime,
-        }.get(type.typename, String)
-
-    def _load_model(self) -> None:
-        """Create a sqlite database schema that contains all the tables within the catalog from which the BDBag
-        was created."""
-
-        def is_key(column: DerivaColumn, table: DerivaTable) -> bool:
-            return column in [key.unique_columns[0] for key in table.keys] and column.name == "RID"
-
-        def col(model, name: str):
-            # try ORM attribute first
-            try:
-                return getattr(model, name).property.columns[0]
-            except AttributeError:
-                # fall back to exact DB column key on the Table
-                return model.__table__.c[name]
-
-        def guess_attr_name(col_name: str) -> str:
-            return col_name[:-3] if col_name.lower().endswith("_id") else col_name
-
-        database_tables: list[SQLTable] = []
-        for schema_name in [self.domain_schema, self.ml_schema]:
-            for table in self.model.schemas[schema_name].tables.values():
-                database_columns: list[SQLColumn] = []
-                for c in table.columns:
-                    # clone column (type, nullability, PK, defaults, unique)
-                    database_column = SQLColumn(
-                        name=c.name,
-                        type_=self._sql_type(c.type),  # SQLAlchemy type object is reusable
-                        comment=c.comment,
-                        default=c.default,
-                        primary_key=is_key(c, table),
-                        nullable=c.nullok,
-                        # NOTE: server_onupdate, computed, etc. can be added if you use them
-                    )
-                    database_columns.append(database_column)
-                database_table = SQLTable(table.name, self.metadata, *database_columns, schema=schema_name)
-                for key in table.keys:
-                    key_columns = [c.name for c in key.unique_columns]
-                    # if key.name[0] == "RID":
-                    #     continue
-                    database_table.append_constraint(
-                        SQLUniqueConstraint(
-                            *key_columns,
-                            name=key.name[1],
-                        )
-                    )
-                for fk in table.foreign_keys:
-                    if fk.pk_table.schema.name not in [self.domain_schema, self.ml_schema]:
-                        continue
-                    if fk.pk_table.schema.name != schema_name:
-                        continue
-                    # Attach FK to the chosen column
-                    database_table.append_constraint(
-                        SQLForeignKeyConstraint(
-                            columns=[f"{c.name}" for c in fk.foreign_key_columns],
-                            refcolumns=[f"{schema_name}.{c.table.name}.{c.name}" for c in fk.referenced_columns],
-                            name=fk.name[1],
-                            comment=fk.comment,
-                        )
-                    )
-                database_tables.append(database_table)
-        with self.engine.begin() as conn:
-            self.metadata.create_all(conn, tables=database_tables)
-
-        def name_for_scalar_relationship(_base, local_cls, referred_cls, constraint):
-            cols = list(constraint.columns) if constraint is not None else []
-            if len(cols) == 1:
-                name = cols[0].key
-                if name in {c.key for c in local_cls.__table__.columns}:
-                    name += "_rel"
-                return name
-            return constraint.name or referred_cls.__name__.lower()
-
-        def name_for_collection_relationship(_base, local_cls, referred_cls, constraint):
-            backref_name = constraint.name.replace("_fkey", "_collection")
-            return backref_name or (referred_cls.__name__.lower() + "_collection")
-
-        # Now build ORM mappings for the tables.
-        self.Base.prepare(
-            self.engine,
-            name_for_scalar_relationship=name_for_scalar_relationship,
-            name_for_collection_relationship=name_for_collection_relationship,
-            reflect=True,
+        # Initialize DerivaModel (provides schema analysis methods)
+        # Note: We pass self.model which was set by BagDatabase
+        DerivaModel.__init__(
+            self,
+            model=self.model,
+            ml_schema=ml_schema,
+            domain_schemas=domain_schemas,
         )

-        for schema in [self.domain_schema, self.ml_schema]:
-            for table in self.model.schemas[schema].tables.values():
-                for fk in table.foreign_keys:
-                    if fk.pk_table.schema.name not in [self.domain_schema, self.ml_schema]:
-                        continue
-                    if fk.pk_table.schema.name == schema:
-                        continue
-                    table_name = f"{schema}.{table.name}"
-                    table_class = self.get_orm_class_by_name(table_name)
-                    foreign_key_column_name = fk.foreign_key_columns[0].name
-                    foreign_key_column = col(table_class, foreign_key_column_name)
-
-                    referenced_table_name = f"{fk.pk_table.schema.name}.{fk.pk_table.name}"
-                    referenced_class = self.get_orm_class_by_name(referenced_table_name)
-                    referenced_column = col(referenced_class, fk.referenced_columns[0].name)
-
-                    relationship_attr = guess_attr_name(foreign_key_column_name)
-                    backref_attr = fk.name[1].replace("_fkey", "_collection")
-                    setattr(
-                        table_class,
-                        relationship_attr,
-                        relationship(
-                            referenced_class,
-                            foreign_keys=[foreign_key_column],
-                            primaryjoin=foreign(foreign_key_column) == referenced_column,
-                            backref=backref(backref_attr, viewonly=True),
-                            viewonly=True,  # set False for write behavior, but best with proper FKs
-                        ),
-                    )
-
-        # Reflect won't pick up the second FK in the dataset_dataset table. We need to do it manually
-        # dataset_dataset_class = self.get_orm_class_by_name("deriva-ml.Dataset_Dataset")
-        # dataset_class = self.get_orm_class_by_name("deriva-ml.Dataset")
-        # dataset_dataset_class.Nested_Dataset = relationship(
-        #     dataset_class,
-        #     primaryjoin=foreign(dataset_dataset_class.__table__.c["Nested_Dataset"])
-        #     == dataset_class.__table__.c["RID"],
-        #     foreign_keys=[dataset_dataset_class.__table__.c["Nested_Dataset"]],
-        #     backref=backref("nested_dataset_collection", viewonly=True),  # pick a distinct name
-        #     viewonly=True,  # keep it read-only unless you truly want writes
-        #     overlaps="Dataset,dataset_dataset_collection",  # optional: silence overlap warnings
-        # )
-        configure_mappers()
-
-    def _load_database(self) -> None:
-        """Load a SQLite database from a bdbag. THis is done by looking for all the CSV files in the bdbag directory.
-
-        If the file is for an asset table, update the FileName column of the table to have the local file path for
-        the materialized file. Then load into the sqlite database.
-        Note: none of the foreign key constraints are included in the database.
-        """
-        dpath = self.bag_path / "data"
-        asset_map = self._localize_asset_table()  # Map of remote to local assets.
-
-        # Find all the CSV files in the subdirectory and load each file into the database.
-        for csv_file in Path(dpath).rglob("*.csv"):
-            table = csv_file.stem
-            schema = self.domain_schema if table in self.model.schemas[self.domain_schema].tables else self.ml_schema
-            sql_table = self.metadata.tables[f"{schema}.{table}"]
-
-            with csv_file.open(newline="") as csvfile:
-                csv_reader = reader(csvfile)
-                column_names = next(csv_reader)
-
-                # Determine which columns in the table has the Filename and the URL
-                asset_indexes = (
-                    (column_names.index("Filename"), column_names.index("URL")) if self._is_asset(table) else None
-                )
-
-                with self.engine.begin() as conn:
-                    object_table = [
-                        self._localize_asset(o, asset_indexes, asset_map, table == "Dataset")
-                        for o in csv_reader
-                    ]
-                    conn.execute(
-                        sqlite_insert(sql_table).on_conflict_do_nothing(),
-                        [dict(zip(column_names, row)) for row in object_table],
-                    )
-
-    def _localize_asset_table(self) -> dict[str, str]:
-        """Use the fetch.txt file in a bdbag to create a map from a URL to a local file path.
+        self.dataset_table = self.model.schemas[self.ml_schema].tables["Dataset"]

-        Returns:
-            Dictionary that maps a URL to a local file path.
+        # Build dataset RID -> version mapping from Dataset_Version table
+        self._build_bag_rids()

-        """
-        fetch_map = {}
-        try:
-            with Path.open(self.bag_path / "fetch.txt", newline="\n") as fetch_file:
-                for row in fetch_file:
-                    # Rows in fetch.text are tab seperated with URL filename.
-                    fields = row.split("\t")
-                    local_file = fields[2].replace("\n", "")
-                    local_path = f"{self.bag_path}/{local_file}"
-                    fetch_map[urlparse(fields[0]).path] = local_path
-        except FileNotFoundError:
-            dataset_rid = self.bag_path.name.replace("Dataset_", "")
-            logging.info(f"No downloaded assets in bag {dataset_rid}")
-        return fetch_map
-
-    def _is_asset(self, table_name: str) -> bool:
-        """
+        self._logger.info(
+            "Created DerivaML database for dataset %s in %s",
+            self.dataset_rid,
+            self.database_dir,
+        )

-        Args:
-            table_name: str:
+    def _build_bag_rids(self) -> None:
+        """Build mapping of dataset RIDs to their versions in this bag."""
+        self.bag_rids: dict[RID, DatasetVersion] = {}

-        Returns:
-            Boolean that is true if the table looks like an asset table.
-        """
-        asset_columns = {"Filename", "URL", "Length", "MD5", "Description"}
-        sname = self.domain_schema if table_name in self.model.schemas[self.domain_schema].tables else self.ml_schema
-        asset_table = self.model.schemas[sname].tables[table_name]
-        return asset_columns.issubset({c.name for c in asset_table.columns})
+        dataset_version_table = self.metadata.tables.get(f"{self.ml_schema}.Dataset_Version")
+        if dataset_version_table is None:
+            return

-    @staticmethod
-    def _localize_asset(o: list, indexes: tuple[int, int], asset_map: dict[str, str], debug: bool = False) -> tuple:
-        """Given a list of column values for a table, replace the FileName column with the local file name based on
-        the URL value.
+        with self.engine.connect() as conn:
+            result = conn.execute(
+                select(dataset_version_table.c.Dataset, dataset_version_table.c.Version)
+            )
+            for rid, version_str in result:
+                version = DatasetVersion.parse(version_str)
+                # Keep the highest version for each RID
+                if rid not in self.bag_rids or version > self.bag_rids[rid]:
+                    self.bag_rids[rid] = version
+
+    def dataset_version(self, dataset_rid: Optional[RID] = None) -> DatasetVersion:
+        """Get the version of a dataset in this bag.

         Args:
-            o: List of values for each column in a table row.
-            indexes: A tuple whose first element is the column index of the file name and whose second element
-                is the index of the URL in an asset table. Tuple is None if table is not an asset table.
-            o: list:
-            indexes: Optional[tuple[int, int]]:
+            dataset_rid: Dataset RID to look up. If None, uses the primary dataset.

         Returns:
-            Tuple of updated column values.
-
-        """
-        if indexes:
-            file_column, url_column = indexes
-            o[file_column] = asset_map[o[url_column]] if o[url_column] else ""
-        return tuple(o)
-
-    def find_table(self, table_name: str) -> SQLTable:
-        """Find a table in the catalog."""
-        # We will look across ml and domain schema to find a table whose name matches.
-        table = [t for t in self.metadata.tables if t == table_name or t.split(".")[1] == table_name][0]
-        return self.metadata.tables[table]
-
-    def list_tables(self) -> list[str]:
-        """List the names of the tables in the catalog
+            DatasetVersion for the specified dataset.

-        Returns:
-            A list of table names. These names are all qualified with the Deriva schema name.
+        Raises:
+            DerivaMLException: If the RID is not in this bag.
         """
-        tables = list(self.metadata.tables.keys())
-        tables.sort()
-        return tables
+        rid = dataset_rid or self.dataset_rid
+        if rid not in self.bag_rids:
+            raise DerivaMLException(f"Dataset RID {rid} is not in this bag")
+        return self.bag_rids[rid]

-    def get_dataset(self, dataset_rid: Optional[RID] = None) -> DatasetBag:
-        """Get a dataset, or nested dataset from the bag database
+    def rid_lookup(self, dataset_rid: RID) -> DatasetVersion | None:
+        """Check if a dataset RID exists in this bag.

         Args:
-            dataset_rid: Optional. If not provided, use the main RID for the bag. If a value is given, it must
-                be the RID for a nested dataset.
+            dataset_rid: RID to look up.

         Returns:
-            DatasetBag object for the specified dataset.
-        """
-        if dataset_rid and dataset_rid not in self.bag_rids:
-            raise DerivaMLException(f"Dataset RID {dataset_rid} is not in model.")
-        return DatasetBag(self, dataset_rid or self.dataset_rid)
-
-    def dataset_version(self, dataset_rid: Optional[RID] = None) -> DatasetVersion:
-        """Return the version of the specified dataset."""
-        if dataset_rid and dataset_rid not in self.bag_rids:
-            DerivaMLException(f"Dataset RID {dataset_rid} is not in model.")
-        return self.bag_rids[dataset_rid]
+            DatasetVersion if found.

-    def find_datasets(self) -> list[dict[str, Any]]:
-        """Returns a list of currently available datasets.
-
-        Returns:
-            list of currently available datasets.
+        Raises:
+            DerivaMLException: If the RID is not found in this bag.
         """
-        atable = next(self.model.schemas[ML_SCHEMA].tables[MLVocab.dataset_type].find_associations()).name
-
-        # Get a list of all the dataset_type values associated with this dataset_table.
-        datasets = []
-        ds_types = list(self._get_table_contents(atable))
-        for dataset in self._get_table_contents("Dataset"):
-            my_types = [t for t in ds_types if t["Dataset"] == dataset["RID"]]
-            datasets.append(dataset | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in my_types]})
-        return datasets
-
-    def list_dataset_members(self, dataset_rid: RID) -> dict[str, Any]:
-        """Returns a list of all the dataset_table entries associated with a dataset."""
-        return self.get_dataset(dataset_rid).list_dataset_members()
+        if dataset_rid in self.bag_rids:
+            return self.bag_rids[dataset_rid]
+        raise DerivaMLException(f"Dataset {dataset_rid} not found in this bag")

     def _get_table_contents(self, table: str) -> Generator[dict[str, Any], None, None]:
-        """Retrieve the contents of the specified table as a dictionary.
+        """Retrieve table contents as dictionaries.
+
+        This method provides compatibility with existing code that uses
+        _get_table_contents. New code should use get_table_contents instead.

         Args:
-            table: Table to retrieve data from. If schema is not provided as part of the table name,
-                the method will attempt to locate the schema for the table.
+            table: Table name.

-        Returns:
-            A generator producing dictionaries containing the contents of the specified table as name/value pairs.
+        Yields:
+            Dictionary for each row.
         """
+        yield from self.get_table_contents(table)

-        with self.engine.connect() as conn:
-            result = conn.execute(select(self.find_table(table)))
-            for row in result.mappings():
-                yield dict(row)
+    def _get_dataset_execution(self, dataset_rid: str) -> dict[str, Any] | None:
+        """Get the execution associated with a dataset version.

-    @staticmethod
-    def rid_lookup(dataset_rid: RID) -> list[tuple[DatasetVersion, "DatabaseModel"]]:
-        """Return a list of DatasetVersion/DatabaseModel instances corresponding to the given RID.
+        Looks up the Dataset_Version record for the dataset's version in this bag
+        and returns the associated execution information.

         Args:
-            dataset_rid: Rit to be looked up.
+            dataset_rid: Dataset RID to look up.

         Returns:
-            List of DatasetVersion/DatabaseModel instances corresponding to the given RID.
-
-        Raises:
-            Raise a DerivaMLException if the given RID is not found.
+            Dataset_Version row as dict, or None if not found.
+            The 'Execution' field contains the execution RID (may be None).
         """
-        try:
-            return DatabaseModel._rid_map[dataset_rid]
-        except KeyError:
-            raise DerivaMLException(f"Dataset {dataset_rid} not found")
-
-    def get_orm_class_by_name(self, table: str) -> Any | None:
-        sql_table = self.find_table(table)
-        if sql_table is None:
-            raise DerivaMLException(f"Table {table} not found")
-        return self.get_orm_class_for_table(sql_table)
-
-    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
-    def get_orm_class_for_table(self, table: SQLTable | DerivaTable | str) -> Any | None:
-        if isinstance(table, DerivaTable):
-            table = self.metadata.tables[f"{table.schema.name}.{table.name}"]
-        if isinstance(table, str):
-            table = self.find_table(table)
-        for mapper in self.Base.registry.mappers:
-            if mapper.persist_selectable is table or table in mapper.tables:
-                return mapper.class_
-        return None
-
-    @staticmethod
-    def _is_association(
-        table_class, min_arity=2, max_arity=2, unqualified=True, pure=True, no_overlap=True, return_fkeys=False
-    ):
-        """Return (truthy) integer arity if self is a matching association, else False.
-
-        min_arity: minimum number of associated fkeys (default 2)
-        max_arity: maximum number of associated fkeys (default 2) or None
-        unqualified: reject qualified associations when True (default True)
-        pure: reject impure associations when True (default True)
-        no_overlap: reject overlapping associations when True (default True)
-        return_fkeys: return the set of N associated ForeignKeys if True
-
-        The default behavior with no arguments is to test for pure,
-        unqualified, non-overlapping, binary associations.
-
-        An association is comprised of several foreign keys which are
-        covered by a non-nullable composite row key. This allows
-        specific combinations of foreign keys to appear at most once.
-
-        The arity of an association is the number of foreign keys
-        being associated. A typical binary association has arity=2.
-
-        An unqualified association contains *only* the foreign key
-        material in its row key. Conversely, a qualified association
-        mixes in other material which means that a specific
-        combination of foreign keys may repeat with different
-        qualifiers.
-
-        A pure association contains *only* row key
-        material. Conversely, an impure association includes
-        additional metadata columns not covered by the row key. Unlike
-        qualifiers, impure metadata merely decorates an association
-        without augmenting its identifying characteristics.
-
-        A non-overlapping association does not share any columns
-        between multiple foreign keys. This means that all
-        combinations of foreign keys are possible. Conversely, an
-        overlapping association shares some columns between multiple
-        foreign keys, potentially limiting the combinations which can
-        be represented in an association row.
-
-        These tests ignore the five ERMrest system columns and any
-        corresponding constraints.
+        version = self.bag_rids.get(dataset_rid)
+        if not version:
+            return None

-        """
-        if min_arity < 2:
-            raise ValueError("An assocation cannot have arity < 2")
-        if max_arity is not None and max_arity < min_arity:
-            raise ValueError("max_arity cannot be less than min_arity")
-
-        mapper = inspect(table_class).mapper
-
-        # TODO: revisit whether there are any other cases we might
-        # care about where system columns are involved?
-        non_sys_cols = {col.name for col in mapper.columns if col.name not in {"RID", "RCT", "RMT", "RCB", "RMB"}}
-        unique_columns = [
-            {c.name for c in constraint.columns}
-            for constraint in inspect(table_class).local_table.constraints
-            if isinstance(constraint, SQLUniqueConstraint)
-        ]
-
-        non_sys_key_colsets = {
-            frozenset(unique_column_set)
-            for unique_column_set in unique_columns
-            if unique_column_set.issubset(non_sys_cols) and len(unique_column_set) > 1
-        }
-
-        if not non_sys_key_colsets:
-            # reject: not association
-            return False
-
-        # choose longest compound key (arbitrary choice with ties!)
-        row_key = sorted(non_sys_key_colsets, key=lambda s: len(s), reverse=True)[0]
-        foreign_keys = [constraint for constraint in inspect(table_class).relationships.values()]
-
-        covered_fkeys = {fkey for fkey in foreign_keys if {c.name for c in fkey.local_columns}.issubset(row_key)}
-        covered_fkey_cols = set()
-
-        if len(covered_fkeys) < min_arity:
-            # reject: not enough fkeys in association
-            return False
-        elif max_arity is not None and len(covered_fkeys) > max_arity:
-            # reject: too many fkeys in association
-            return False
-
-        for fkey in covered_fkeys:
-            fkcols = {c.name for c in fkey.local_columns}
-            if no_overlap and fkcols.intersection(covered_fkey_cols):
-                # reject: overlapping fkeys in association
-                return False
-            covered_fkey_cols.update(fkcols)
-
-        if unqualified and row_key.difference(covered_fkey_cols):
-            # reject: qualified association
-            return False
-
-        if pure and non_sys_cols.difference(row_key):
-            # reject: impure association
-            return False
-
-        # return (truthy) arity or fkeys
-        if return_fkeys:
-            return covered_fkeys
-        else:
-            return len(covered_fkeys)
-
-    def get_orm_association_class(
-        self,
-        left_cls: Type[Any],
-        right_cls: Type[Any],
-        min_arity=2,
-        max_arity=2,
-        unqualified=True,
-        pure=True,
-        no_overlap=True,
-    ):
-        """
-        Find an association class C by: (1) walking rels on left_cls to a mid class C,
-        (2) verifying C also relates to right_cls. Returns (C, C->left, C->right) or None.
+        dataset_version_table = self.find_table("Dataset_Version")
+        cmd = select(dataset_version_table).where(
+            dataset_version_table.columns.Dataset == dataset_rid,
+            dataset_version_table.columns.Version == str(version),
+        )

-        """
-        for _, left_rel in inspect(left_cls).relationships.items():
-            mid_cls = left_rel.mapper.class_
-            is_assoc = self._is_association(mid_cls, return_fkeys=True)
-            if not is_assoc:
-                continue
-            assoc_local_columns_left = list(is_assoc)[0].local_columns
-            assoc_local_columns_right = list(is_assoc)[1].local_columns
-
-            found_left = found_right = False
-            for r in inspect(left_cls).relationships.values():
-                remote_side = list(r.remote_side)[0]
-                if remote_side in assoc_local_columns_left:
-                    found_left = r
-                if remote_side in assoc_local_columns_right:
-                    found_left = r
-                    # We have left and right backwards from the assocation, so swap them.
-                    assoc_local_columns_left, assoc_local_columns_right = (
-                        assoc_local_columns_right,
-                        assoc_local_columns_left,
-                    )
-            for r in inspect(right_cls).relationships.values():
-                remote_side = list(r.remote_side)[0]
-                if remote_side in assoc_local_columns_right:
-                    found_right = r
-            if found_left != False and found_right != False:
-                return mid_cls, found_left.class_attribute, found_right.class_attribute
-        return None
-
-    def delete_database(self):
-        """
+        with Session(self.engine) as session:
+            result = session.execute(cmd).mappings().first()
+            return dict(result) if result else None

-        Args:
+    # Compatibility aliases for methods that have different names in BagDatabase
+    def get_orm_association_class(self, left_cls, right_cls, **kwargs):
+        """Find association class between two ORM classes.

-        Returns:
+        Wrapper around BagDatabase.get_association_class for compatibility.
+        """
+        return self.get_association_class(left_cls, right_cls)
+
+    def delete_database(self) -> None:
+        """Delete the database files.
+
+        Note: This method is deprecated. Use dispose() and manually remove
+        the database directory if needed.
         """
-        self.dbase_file.unlink()
+        self.dispose()
+        # Note: We don't actually delete files here to avoid data loss.
+        # The caller should handle file deletion if needed.
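
For orientation, the refactored DatabaseModel above is used roughly as the Example section of its new class docstring shows. The following is a minimal sketch of that post-1.17.11 call pattern, not an authoritative recipe: it assumes DatabaseModel is importable from the diffed deriva_ml/model/database.py module, the summarize_bag helper is hypothetical, and "Image" is the illustrative table name taken from the docstring's own example.

    from pathlib import Path

    from deriva_ml.dataset.aux_classes import DatasetMinid
    from deriva_ml.model.database import DatabaseModel


    def summarize_bag(minid: DatasetMinid, bag_path: Path, working_dir: Path) -> None:
        """Hypothetical helper: print dataset versions and asset paths in a bag."""
        # Builds the SQLite database from the bag (BagDatabase.__init__) and
        # attaches the DerivaML schema utilities (DerivaModel.__init__).
        db = DatabaseModel(minid, bag_path, working_dir)

        # Version bookkeeping is now per-instance (bag_rids) rather than the old
        # class-level _rid_map, so only RIDs present in this bag resolve.
        for rid, version in db.bag_rids.items():
            print(f"{rid}: version {version}")

        # rid_lookup raises DerivaMLException for RIDs absent from the bag.
        print(db.rid_lookup(db.dataset_rid))

        # Generic row access is inherited from BagDatabase; "Image" is the
        # hypothetical asset table from the docstring example.
        for row in db.get_table_contents("Image"):
            print(row["Filename"])

Note that _get_table_contents and get_orm_association_class survive only as thin compatibility shims over the BagDatabase names, so existing call sites keep working while new code can target get_table_contents and get_association_class directly.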