deriva-ml 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deriva_ml/.DS_Store +0 -0
- deriva_ml/__init__.py +0 -10
- deriva_ml/core/base.py +18 -6
- deriva_ml/dataset/__init__.py +2 -7
- deriva_ml/dataset/aux_classes.py +21 -11
- deriva_ml/dataset/dataset.py +5 -4
- deriva_ml/dataset/dataset_bag.py +144 -151
- deriva_ml/dataset/upload.py +6 -4
- deriva_ml/demo_catalog.py +16 -2
- deriva_ml/execution/__init__.py +2 -1
- deriva_ml/execution/execution.py +4 -2
- deriva_ml/execution/execution_configuration.py +28 -9
- deriva_ml/execution/workflow.py +8 -0
- deriva_ml/model/catalog.py +55 -50
- deriva_ml/model/database.py +455 -81
- deriva_ml/test.py +94 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/METADATA +9 -7
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/RECORD +22 -21
- deriva_ml/model/sql_mapper.py +0 -44
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/WHEEL +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/entry_points.txt +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.16.0.dist-info → deriva_ml-1.17.0.dist-info}/top_level.txt +0 -0
deriva_ml/model/database.py
CHANGED
@@ -6,20 +6,46 @@ from __future__ import annotations
 
 import json
 import logging
-import sqlite3
 from csv import reader
 from pathlib import Path
-from typing import Any, Generator, Optional
+from typing import Any, Generator, Optional, Type
 from urllib.parse import urlparse
 
+from dateutil import parser
+from deriva.core.ermrest_model import Column as DerivaColumn
 from deriva.core.ermrest_model import Model
+from deriva.core.ermrest_model import Table as DerivaTable
+from deriva.core.ermrest_model import Type as DerivaType
+from pydantic import ConfigDict, validate_call
+from sqlalchemy import (
+    JSON,
+    Boolean,
+    Date,
+    DateTime,
+    Float,
+    Integer,
+    MetaData,
+    String,
+    create_engine,
+    event,
+    inspect,
+    select,
+)
+from sqlalchemy import Column as SQLColumn
+from sqlalchemy import ForeignKeyConstraint as SQLForeignKeyConstraint
+from sqlalchemy import Table as SQLTable
+from sqlalchemy import UniqueConstraint as SQLUniqueConstraint
+from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+from sqlalchemy.ext.automap import automap_base
+from sqlalchemy.orm import backref, configure_mappers, foreign, relationship
+from sqlalchemy.sql.type_api import TypeEngine
+from sqlalchemy.types import TypeDecorator
 
 from deriva_ml.core.definitions import ML_SCHEMA, RID, MLVocab
 from deriva_ml.core.exceptions import DerivaMLException
 from deriva_ml.dataset.aux_classes import DatasetMinid, DatasetVersion
 from deriva_ml.dataset.dataset_bag import DatasetBag
 from deriva_ml.model.catalog import DerivaModel
-from deriva_ml.model.sql_mapper import SQLMapper
 
 try:
     from icecream import ic
@@ -27,8 +53,63 @@ except ImportError:  # Graceful fallback if IceCream isn't installed.
     ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a)  # noqa
 
 
+class ERMRestBoolean(TypeDecorator):
+    impl = Boolean
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value in ("Y", "y", 1, True, "t", "T"):
+            return True
+        elif value in ("N", "n", 0, False, "f", "F"):
+            return False
+        elif value is None:
+            return None
+        raise ValueError(f"Invalid boolean value: {value!r}")
+
+
+class StringToFloat(TypeDecorator):
+    impl = Float
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value == "" or value is None:
+            return None
+        else:
+            return float(value)
+
+class StringToInteger(TypeDecorator):
+    impl = Integer
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value == "" or value is None:
+            return None
+        else:
+            return int(value)
+
+
+class StringToDateTime(TypeDecorator):
+    impl = DateTime
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value == "" or value is None:
+            return None
+        else:
+            return parser.parse(value)
+
+class StringToDate(TypeDecorator):
+    impl = Date
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value == "" or value is None:
+            return None
+        else:
+            return parser.parse(value).date()
+
 class DatabaseModelMeta(type):
-    """Use metaclass to ensure that there is
+    """Use metaclass to ensure that there is only one instance of a database model per path"""
 
     _paths_loaded: dict[Path, "DatabaseModel"] = {}
 
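The TypeDecorator subclasses added above exist because BDBag CSV files deliver every value as a string, while ERMrest column types such as boolean, int4, or timestamptz need typed SQLite columns. Below is a minimal, self-contained sketch of the same bind-time coercion pattern; the demo table and column names are illustrative, not part of deriva-ml, and the class body is copied from the hunk above.

# Sketch: CSV-style string values coerced to typed columns via TypeDecorator.
from sqlalchemy import Boolean, Column, Integer, MetaData, Table, create_engine, insert, select
from sqlalchemy.types import TypeDecorator


class ERMRestBoolean(TypeDecorator):
    impl = Boolean
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value in ("Y", "y", 1, True, "t", "T"):
            return True
        elif value in ("N", "n", 0, False, "f", "F"):
            return False
        elif value is None:
            return None
        raise ValueError(f"Invalid boolean value: {value!r}")


metadata = MetaData()
# Hypothetical table, not a deriva-ml table.
demo = Table("demo", metadata, Column("flag", ERMRestBoolean), Column("count", Integer))
engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

with engine.begin() as conn:
    # "t" arrives as a CSV string but is stored and read back as a real boolean.
    conn.execute(insert(demo), {"flag": "t", "count": 3})
    print(conn.execute(select(demo)).all())  # [(True, 3)]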
@@ -65,7 +146,7 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
         bag_path (Path): path to the local copy of the BDBag
         minid (DatasetMinid): Minid for the specified bag
         dataset_rid (RID): RID for the specified dataset
-
+        engine (Connection): connection to the sqlalchemy database holding table values
         domain_schema (str): Name of the domain schema
         dataset_table (Table): the dataset table in the ERMRest model.
     """
@@ -73,24 +154,6 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
     # Maintain a global map of RIDS to versions and databases.
     _rid_map: dict[RID, list[tuple[DatasetVersion, "DatabaseModel"]]] = {}
 
-    @staticmethod
-    def rid_lookup(dataset_rid: RID) -> list[tuple[DatasetVersion, "DatabaseModel"]]:
-        """Return a list of DatasetVersion/DatabaseModel instances corresponding to the given RID.
-
-        Args:
-            dataset_rid: Rit to be looked up.
-
-        Returns:
-            List of DatasetVersion/DatabaseModel instances corresponding to the given RID.
-
-        Raises:
-            Raise a DerivaMLException if the given RID is not found.
-        """
-        try:
-            return DatabaseModel._rid_map[dataset_rid]
-        except KeyError:
-            raise DerivaMLException(f"Dataset {dataset_rid} not found")
-
     def __init__(self, minid: DatasetMinid, bag_path: Path, dbase_path: Path):
         """Create a new DatabaseModel.
 
@@ -99,33 +162,41 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
             bag_path: Path to the local copy of the BDBag.
         """
 
+        super().__init__(Model.fromfile("file-system", bag_path / "data/schema.json"))
+
         self.bag_path = bag_path
         self.minid = minid
        self.dataset_rid = minid.dataset_rid
-        self.
-        self.
+        self.dbase_path = dbase_path / f"{minid.version_rid}"
+        self.dbase_path.mkdir(parents=True, exist_ok=True)
+
+        self.engine = create_engine(f"sqlite:///{(self.dbase_path / 'main.db').resolve()}", future=True)
+        self.metadata = MetaData()
+        self.Base = automap_base(metadata=self.metadata)
+
+        # Attach event listener for *this instance's* engine
+        event.listen(self.engine, "connect", self._attach_schemas)
 
         schema_file = self.bag_path / "data/schema.json"
         with schema_file.open("r") as f:
             self.snaptime = json.load(f)["snaptime"]
 
-        super().__init__(Model.fromfile("file-system", self.bag_path / "data/schema.json"))
         self._logger = logging.getLogger("deriva_ml")
         self._load_model()
         self.ml_schema = ML_SCHEMA
-        self.
+        self._load_database()
         self._logger.info(
             "Creating new database for dataset: %s in %s",
             self.dataset_rid,
-            self.
+            self.dbase_path,
         )
         self.dataset_table = self.model.schemas[self.ml_schema].tables["Dataset"]
+
         # Now go through the database and pick out all the dataset_table RIDS, along with their versions.
-
-
-
-
-        ]
+        with self.engine.connect() as conn:
+            dataset_version = self.metadata.tables[f"{self.ml_schema}.Dataset_Version"]
+            result = conn.execute(select(dataset_version.c.Dataset, dataset_version.c.Version))
+            dataset_versions = [t for t in result]
 
         dataset_versions = [(v[0], DatasetVersion.parse(v[1])) for v in dataset_versions]
         # Get most current version of each rid
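The rewritten constructor above replaces the raw sqlite3 connection with a SQLAlchemy engine and registers a "connect" listener so every new DBAPI connection attaches the per-schema SQLite files before use. A standalone sketch of that listener pattern follows, using an in-memory database and a hypothetical schema name rather than the real per-dataset .db files.

# Sketch: per-connection ATTACH DATABASE via a SQLAlchemy "connect" event listener.
from sqlalchemy import create_engine, event, text

engine = create_engine("sqlite:///:memory:", future=True)

def attach_schemas(dbapi_conn, _conn_record):
    # Every new DBAPI connection gets the extra schema attached under a name,
    # so schema-qualified tables like "deriva-ml".Dataset resolve in SQLite.
    cur = dbapi_conn.cursor()
    cur.execute("ATTACH DATABASE ':memory:' AS 'deriva-ml'")
    cur.close()

event.listen(engine, "connect", attach_schemas)

with engine.connect() as conn:
    conn.execute(text('CREATE TABLE "deriva-ml".Dataset (RID TEXT PRIMARY KEY)'))
    print(conn.execute(text('SELECT * FROM "deriva-ml".Dataset')).all())  # []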
@@ -137,16 +208,159 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
         version_list = DatabaseModel._rid_map.setdefault(dataset_rid, [])
         version_list.append((dataset_version, self))
 
+    def _attach_schemas(self, dbapi_conn, _conn_record):
+        cur = dbapi_conn.cursor()
+        for schema in [self.domain_schema, self.ml_schema]:
+            schema_file = (self.dbase_path / f"{schema}.db").resolve()
+            cur.execute(f"ATTACH DATABASE '{schema_file}' AS '{schema}'")
+        cur.close()
+
+    @staticmethod
+    def _sql_type(type: DerivaType) -> TypeEngine:
+        """Return the SQL type for a Deriva column."""
+        return {
+            "boolean": ERMRestBoolean,
+            "date": StringToDate,
+            "float4": StringToFloat,
+            "float8": StringToFloat,
+            "int2": StringToInteger,
+            "int4": StringToInteger,
+            "int8": StringToInteger,
+            "json": JSON,
+            "jsonb": JSON,
+            "timestamptz": StringToDateTime,
+            "timestamp": StringToDateTime,
+        }.get(type.typename, String)
+
     def _load_model(self) -> None:
         """Create a sqlite database schema that contains all the tables within the catalog from which the BDBag
         was created."""
-        with self.dbase:
-            for t in self.model.schemas[self.domain_schema].tables.values():
-                self.dbase.execute(t.sqlite3_ddl())
-            for t in self.model.schemas["deriva-ml"].tables.values():
-                self.dbase.execute(t.sqlite3_ddl())
 
-
+        def is_key(column: DerivaColumn, table: DerivaTable) -> bool:
+            return column in [key.unique_columns[0] for key in table.keys] and column.name == "RID"
+
+        def col(model, name: str):
+            # try ORM attribute first
+            try:
+                return getattr(model, name).property.columns[0]
+            except AttributeError:
+                # fall back to exact DB column key on the Table
+                return model.__table__.c[name]
+
+        def guess_attr_name(col_name: str) -> str:
+            return col_name[:-3] if col_name.lower().endswith("_id") else col_name
+
+        database_tables: list[SQLTable] = []
+        for schema_name in [self.domain_schema, self.ml_schema]:
+            for table in self.model.schemas[schema_name].tables.values():
+                database_columns: list[SQLColumn] = []
+                for c in table.columns:
+                    # clone column (type, nullability, PK, defaults, unique)
+                    database_column = SQLColumn(
+                        name=c.name,
+                        type_=self._sql_type(c.type),  # SQLAlchemy type object is reusable
+                        comment=c.comment,
+                        default=c.default,
+                        primary_key=is_key(c, table),
+                        nullable=c.nullok,
+                        # NOTE: server_onupdate, computed, etc. can be added if you use them
+                    )
+                    database_columns.append(database_column)
+                database_table = SQLTable(table.name, self.metadata, *database_columns, schema=schema_name)
+                for key in table.keys:
+                    key_columns = [c.name for c in key.unique_columns]
+                    # if key.name[0] == "RID":
+                    #     continue
+                    database_table.append_constraint(
+                        SQLUniqueConstraint(
+                            *key_columns,
+                            name=key.name[1],
+                        )
+                    )
+                for fk in table.foreign_keys:
+                    if fk.pk_table.schema.name not in [self.domain_schema, self.ml_schema]:
+                        continue
+                    if fk.pk_table.schema.name != schema_name:
+                        continue
+                    # Attach FK to the chosen column
+                    database_table.append_constraint(
+                        SQLForeignKeyConstraint(
+                            columns=[f"{c.name}" for c in fk.foreign_key_columns],
+                            refcolumns=[f"{schema_name}.{c.table.name}.{c.name}" for c in fk.referenced_columns],
+                            name=fk.name[1],
+                            comment=fk.comment,
+                        )
+                    )
+                database_tables.append(database_table)
+        with self.engine.begin() as conn:
+            self.metadata.create_all(conn, tables=database_tables)
+
+        def name_for_scalar_relationship(_base, local_cls, referred_cls, constraint):
+            cols = list(constraint.columns) if constraint is not None else []
+            if len(cols) == 1:
+                name = cols[0].key
+                if name in {c.key for c in local_cls.__table__.columns}:
+                    name += "_rel"
+                return name
+            return constraint.name or referred_cls.__name__.lower()
+
+        def name_for_collection_relationship(_base, local_cls, referred_cls, constraint):
+            backref_name = constraint.name.replace("_fkey", "_collection")
+            return backref_name or (referred_cls.__name__.lower() + "_collection")
+
+        # Now build ORM mappings for the tables.
+        self.Base.prepare(
+            self.engine,
+            name_for_scalar_relationship=name_for_scalar_relationship,
+            name_for_collection_relationship=name_for_collection_relationship,
+            reflect=True,
+        )
+
+        for schema in [self.domain_schema, self.ml_schema]:
+            for table in self.model.schemas[schema].tables.values():
+                for fk in table.foreign_keys:
+                    if fk.pk_table.schema.name not in [self.domain_schema, self.ml_schema]:
+                        continue
+                    if fk.pk_table.schema.name == schema:
+                        continue
+                    table_name = f"{schema}.{table.name}"
+                    table_class = self.get_orm_class_by_name(table_name)
+                    foreign_key_column_name = fk.foreign_key_columns[0].name
+                    foreign_key_column = col(table_class, foreign_key_column_name)
+
+                    referenced_table_name = f"{fk.pk_table.schema.name}.{fk.pk_table.name}"
+                    referenced_class = self.get_orm_class_by_name(referenced_table_name)
+                    referenced_column = col(referenced_class, fk.referenced_columns[0].name)
+
+                    relationship_attr = guess_attr_name(foreign_key_column_name)
+                    backref_attr = fk.name[1].replace("_fkey", "_collection")
+                    setattr(
+                        table_class,
+                        relationship_attr,
+                        relationship(
+                            referenced_class,
+                            foreign_keys=[foreign_key_column],
+                            primaryjoin=foreign(foreign_key_column) == referenced_column,
+                            backref=backref(backref_attr, viewonly=True),
+                            viewonly=True,  # set False for write behavior, but best with proper FKs
+                        ),
+                    )
+
+        # Reflect won't pick up the second FK in the dataset_dataset table. We need to do it manually
+        # dataset_dataset_class = self.get_orm_class_by_name("deriva-ml.Dataset_Dataset")
+        # dataset_class = self.get_orm_class_by_name("deriva-ml.Dataset")
+        # dataset_dataset_class.Nested_Dataset = relationship(
+        #     dataset_class,
+        #     primaryjoin=foreign(dataset_dataset_class.__table__.c["Nested_Dataset"])
+        #     == dataset_class.__table__.c["RID"],
+        #     foreign_keys=[dataset_dataset_class.__table__.c["Nested_Dataset"]],
+        #     backref=backref("nested_dataset_collection", viewonly=True),  # pick a distinct name
+        #     viewonly=True,  # keep it read-only unless you truly want writes
+        #     overlaps="Dataset,dataset_dataset_collection",  # optional: silence overlap warnings
+        # )
+        configure_mappers()
+
+    def _load_database(self) -> None:
         """Load a SQLite database from a bdbag. THis is done by looking for all the CSV files in the bdbag directory.
 
         If the file is for an asset table, update the FileName column of the table to have the local file path for
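_load_model() now declares SQLAlchemy Table objects into a shared MetaData, creates them, and then lets automap_base() generate ORM classes from that same metadata. A reduced sketch of that declare-then-automap flow is below; the parent/child tables are hypothetical toy names, not deriva-ml tables, and default automap relationship naming is used instead of the custom naming functions in the hunk above.

# Sketch: build Core tables, create them, then automap ORM classes over them.
from sqlalchemy import Column, ForeignKey, MetaData, String, Table, create_engine
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session

metadata = MetaData()
Table("parent", metadata, Column("RID", String, primary_key=True))
Table(
    "child",
    metadata,
    Column("RID", String, primary_key=True),
    Column("Parent", String, ForeignKey("parent.RID")),
)

engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

Base = automap_base(metadata=metadata)
Base.prepare()  # generate mapped classes and relationships from the declared tables
Parent, Child = Base.classes.parent, Base.classes.child

with Session(engine) as session:
    session.add(Parent(RID="p1"))
    session.add(Child(RID="c1", Parent="p1"))
    session.commit()
    # The FK produces an automapped scalar relationship named after the referred class.
    print(session.query(Child).one().parent.RID)  # "p1"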
@@ -160,6 +374,7 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
         for csv_file in Path(dpath).rglob("*.csv"):
             table = csv_file.stem
             schema = self.domain_schema if table in self.model.schemas[self.domain_schema].tables else self.ml_schema
+            sql_table = self.metadata.tables[f"{schema}.{table}"]
 
             with csv_file.open(newline="") as csvfile:
                 csv_reader = reader(csvfile)
@@ -170,15 +385,14 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
                     (column_names.index("Filename"), column_names.index("URL")) if self._is_asset(table) else None
                 )
 
-
-
-
-
-
-
-
-
-                        object_table,
+                with self.engine.begin() as conn:
+                    object_table = [
+                        self._localize_asset(o, asset_indexes, asset_map, table == "Dataset")
+                        for o in csv_reader
+                    ]
+                    conn.execute(
+                        sqlite_insert(sql_table).on_conflict_do_nothing(),
+                        [dict(zip(column_names, row)) for row in object_table],
                     )
 
     def _localize_asset_table(self) -> dict[str, str]:
@@ -237,19 +451,21 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
             o[file_column] = asset_map[o[url_column]] if o[url_column] else ""
         return tuple(o)
 
+    def find_table(self, table_name: str) -> SQLTable:
+        """Find a table in the catalog."""
+        # We will look across ml and domain schema to find a table whose name matches.
+        table = [t for t in self.metadata.tables if t == table_name or t.split(".")[1] == table_name][0]
+        return self.metadata.tables[table]
+
     def list_tables(self) -> list[str]:
         """List the names of the tables in the catalog
 
         Returns:
             A list of table names. These names are all qualified with the Deriva schema name.
         """
-
-
-
-            for t in self.dbase.execute(
-                "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name;"
-            ).fetchall()
-        ]
+        tables = list(self.metadata.tables.keys())
+        tables.sort()
+        return tables
 
     def get_dataset(self, dataset_rid: Optional[RID] = None) -> DatasetBag:
         """Get a dataset, or nested dataset from the bag database
@@ -281,8 +497,8 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
 
         # Get a list of all the dataset_type values associated with this dataset_table.
         datasets = []
-        ds_types = list(self.
-        for dataset in self.
+        ds_types = list(self._get_table_contents(atable))
+        for dataset in self._get_table_contents("Dataset"):
             my_types = [t for t in ds_types if t["Dataset"] == dataset["RID"]]
             datasets.append(dataset | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in my_types]})
         return datasets
@@ -291,48 +507,206 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
         """Returns a list of all the dataset_table entries associated with a dataset."""
         return self.get_dataset(dataset_rid).list_dataset_members()
 
-    def
+    def _get_table_contents(self, table: str) -> Generator[dict[str, Any], None, None]:
         """Retrieve the contents of the specified table as a dictionary.
 
         Args:
-            table: Table to retrieve data from.
+            table: Table to retrieve data from. If schema is not provided as part of the table name,
                 the method will attempt to locate the schema for the table.
 
         Returns:
             A generator producing dictionaries containing the contents of the specified table as name/value pairs.
         """
-        table_name = self.normalize_table_name(table)
-        table = self.name_to_table(table)
 
-        with self.
-
-
+        with self.engine.connect() as conn:
+            result = conn.execute(select(self.find_table(table)))
+            for row in result.mappings():
+                yield dict(row)
 
-
-
-
-    def normalize_table_name(self, table: str) -> str:
-        """Attempt to insert the schema into a table name if it's not provided.
+    @staticmethod
+    def rid_lookup(dataset_rid: RID) -> list[tuple[DatasetVersion, "DatabaseModel"]]:
+        """Return a list of DatasetVersion/DatabaseModel instances corresponding to the given RID.
 
         Args:
-
+            dataset_rid: Rit to be looked up.
 
         Returns:
-
+            List of DatasetVersion/DatabaseModel instances corresponding to the given RID.
 
+        Raises:
+            Raise a DerivaMLException if the given RID is not found.
         """
         try:
-            [
-        except ValueError:
-            tname = table
-            for sname in [self.domain_schema, self.ml_schema, "WWW"]:  # Be careful of File table.
-                if sname in self.model.schemas and table in self.model.schemas[sname].tables:
-                    break
-            try:
-                _ = self.model.schemas[sname].tables[tname]
-                return f"{sname}:{tname}"
+            return DatabaseModel._rid_map[dataset_rid]
         except KeyError:
-            raise DerivaMLException(f
+            raise DerivaMLException(f"Dataset {dataset_rid} not found")
+
+    def get_orm_class_by_name(self, table: str) -> Any | None:
+        sql_table = self.find_table(table)
+        if sql_table is None:
+            raise DerivaMLException(f"Table {table} not found")
+        return self.get_orm_class_for_table(sql_table)
+
+    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+    def get_orm_class_for_table(self, table: SQLTable | DerivaTable | str) -> Any | None:
+        if isinstance(table, DerivaTable):
+            table = self.metadata.tables[f"{table.schema.name}.{table.name}"]
+        if isinstance(table, str):
+            table = self.find_table(table)
+        for mapper in self.Base.registry.mappers:
+            if mapper.persist_selectable is table or table in mapper.tables:
+                return mapper.class_
+        return None
+
+    @staticmethod
+    def _is_association(
+        table_class, min_arity=2, max_arity=2, unqualified=True, pure=True, no_overlap=True, return_fkeys=False
+    ):
+        """Return (truthy) integer arity if self is a matching association, else False.
+
+        min_arity: minimum number of associated fkeys (default 2)
+        max_arity: maximum number of associated fkeys (default 2) or None
+        unqualified: reject qualified associations when True (default True)
+        pure: reject impure associations when True (default True)
+        no_overlap: reject overlapping associations when True (default True)
+        return_fkeys: return the set of N associated ForeignKeys if True
+
+        The default behavior with no arguments is to test for pure,
+        unqualified, non-overlapping, binary associations.
+
+        An association is comprised of several foreign keys which are
+        covered by a non-nullable composite row key. This allows
+        specific combinations of foreign keys to appear at most once.
+
+        The arity of an association is the number of foreign keys
+        being associated. A typical binary association has arity=2.
+
+        An unqualified association contains *only* the foreign key
+        material in its row key. Conversely, a qualified association
+        mixes in other material which means that a specific
+        combination of foreign keys may repeat with different
+        qualifiers.
+
+        A pure association contains *only* row key
+        material. Conversely, an impure association includes
+        additional metadata columns not covered by the row key. Unlike
+        qualifiers, impure metadata merely decorates an association
+        without augmenting its identifying characteristics.
+
+        A non-overlapping association does not share any columns
+        between multiple foreign keys. This means that all
+        combinations of foreign keys are possible. Conversely, an
+        overlapping association shares some columns between multiple
+        foreign keys, potentially limiting the combinations which can
+        be represented in an association row.
+
+        These tests ignore the five ERMrest system columns and any
+        corresponding constraints.
+
+        """
+        if min_arity < 2:
+            raise ValueError("An assocation cannot have arity < 2")
+        if max_arity is not None and max_arity < min_arity:
+            raise ValueError("max_arity cannot be less than min_arity")
+
+        mapper = inspect(table_class).mapper
+
+        # TODO: revisit whether there are any other cases we might
+        # care about where system columns are involved?
+        non_sys_cols = {col.name for col in mapper.columns if col.name not in {"RID", "RCT", "RMT", "RCB", "RMB"}}
+        unique_columns = [
+            {c.name for c in constraint.columns}
+            for constraint in inspect(table_class).local_table.constraints
+            if isinstance(constraint, SQLUniqueConstraint)
+        ]
+
+        non_sys_key_colsets = {
+            frozenset(unique_column_set)
+            for unique_column_set in unique_columns
+            if unique_column_set.issubset(non_sys_cols) and len(unique_column_set) > 1
+        }
+
+        if not non_sys_key_colsets:
+            # reject: not association
+            return False
+
+        # choose longest compound key (arbitrary choice with ties!)
+        row_key = sorted(non_sys_key_colsets, key=lambda s: len(s), reverse=True)[0]
+        foreign_keys = [constraint for constraint in inspect(table_class).relationships.values()]
+
+        covered_fkeys = {fkey for fkey in foreign_keys if {c.name for c in fkey.local_columns}.issubset(row_key)}
+        covered_fkey_cols = set()
+
+        if len(covered_fkeys) < min_arity:
+            # reject: not enough fkeys in association
+            return False
+        elif max_arity is not None and len(covered_fkeys) > max_arity:
+            # reject: too many fkeys in association
+            return False
+
+        for fkey in covered_fkeys:
+            fkcols = {c.name for c in fkey.local_columns}
+            if no_overlap and fkcols.intersection(covered_fkey_cols):
+                # reject: overlapping fkeys in association
+                return False
+            covered_fkey_cols.update(fkcols)
+
+        if unqualified and row_key.difference(covered_fkey_cols):
+            # reject: qualified association
+            return False
+
+        if pure and non_sys_cols.difference(row_key):
+            # reject: impure association
+            return False
+
+        # return (truthy) arity or fkeys
+        if return_fkeys:
+            return covered_fkeys
+        else:
+            return len(covered_fkeys)
+
+    def get_orm_association_class(
+        self,
+        left_cls: Type[Any],
+        right_cls: Type[Any],
+        min_arity=2,
+        max_arity=2,
+        unqualified=True,
+        pure=True,
+        no_overlap=True,
+    ):
+        """
+        Find an association class C by: (1) walking rels on left_cls to a mid class C,
+        (2) verifying C also relates to right_cls. Returns (C, C->left, C->right) or None.
+
+        """
+        for _, left_rel in inspect(left_cls).relationships.items():
+            mid_cls = left_rel.mapper.class_
+            is_assoc = self._is_association(mid_cls, return_fkeys=True)
+            if not is_assoc:
+                continue
+            assoc_local_columns_left = list(is_assoc)[0].local_columns
+            assoc_local_columns_right = list(is_assoc)[1].local_columns
+
+            found_left = found_right = False
+            for r in inspect(left_cls).relationships.values():
+                remote_side = list(r.remote_side)[0]
+                if remote_side in assoc_local_columns_left:
+                    found_left = r
+                if remote_side in assoc_local_columns_right:
+                    found_left = r
+                    # We have left and right backwards from the assocation, so swap them.
+                    assoc_local_columns_left, assoc_local_columns_right = (
+                        assoc_local_columns_right,
+                        assoc_local_columns_left,
+                    )
+            for r in inspect(right_cls).relationships.values():
+                remote_side = list(r.remote_side)[0]
+                if remote_side in assoc_local_columns_right:
+                    found_right = r
+            if found_left != False and found_right != False:
+                return mid_cls, found_left.class_attribute, found_right.class_attribute
+        return None
 
     def delete_database(self):
         """