deriva-ml 1.16.0__py3-none-any.whl → 1.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,20 +6,46 @@ from __future__ import annotations
 
 import json
 import logging
-import sqlite3
 from csv import reader
 from pathlib import Path
-from typing import Any, Generator, Optional
+from typing import Any, Generator, Optional, Type
 from urllib.parse import urlparse
 
+from dateutil import parser
+from deriva.core.ermrest_model import Column as DerivaColumn
 from deriva.core.ermrest_model import Model
+from deriva.core.ermrest_model import Table as DerivaTable
+from deriva.core.ermrest_model import Type as DerivaType
+from pydantic import ConfigDict, validate_call
+from sqlalchemy import (
+    JSON,
+    Boolean,
+    Date,
+    DateTime,
+    Float,
+    Integer,
+    MetaData,
+    String,
+    create_engine,
+    event,
+    inspect,
+    select,
+)
+from sqlalchemy import Column as SQLColumn
+from sqlalchemy import ForeignKeyConstraint as SQLForeignKeyConstraint
+from sqlalchemy import Table as SQLTable
+from sqlalchemy import UniqueConstraint as SQLUniqueConstraint
+from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+from sqlalchemy.ext.automap import automap_base
+from sqlalchemy.orm import backref, configure_mappers, foreign, relationship
+from sqlalchemy.sql.type_api import TypeEngine
+from sqlalchemy.types import TypeDecorator
 
 from deriva_ml.core.definitions import ML_SCHEMA, RID, MLVocab
 from deriva_ml.core.exceptions import DerivaMLException
 from deriva_ml.dataset.aux_classes import DatasetMinid, DatasetVersion
 from deriva_ml.dataset.dataset_bag import DatasetBag
 from deriva_ml.model.catalog import DerivaModel
-from deriva_ml.model.sql_mapper import SQLMapper
 
 try:
     from icecream import ic
@@ -27,8 +53,63 @@ except ImportError:  # Graceful fallback if IceCream isn't installed.
     ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a)  # noqa
 
 
+class ERMRestBoolean(TypeDecorator):
+    impl = Boolean
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value in ("Y", "y", 1, True, "t", "T"):
+            return True
+        elif value in ("N", "n", 0, False, "f", "F"):
+            return False
+        elif value is None:
+            return None
+        raise ValueError(f"Invalid boolean value: {value!r}")
+
+
+class StringToFloat(TypeDecorator):
+    impl = Float
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value == "" or value is None:
+            return None
+        else:
+            return float(value)
+
+class StringToInteger(TypeDecorator):
+    impl = Integer
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value == "" or value is None:
+            return None
+        else:
+            return int(value)
+
+
+class StringToDateTime(TypeDecorator):
+    impl = DateTime
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value == "" or value is None:
+            return None
+        else:
+            return parser.parse(value)
+
+class StringToDate(TypeDecorator):
+    impl = Date
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value == "" or value is None:
+            return None
+        else:
+            return parser.parse(value).date()
+
 class DatabaseModelMeta(type):
-    """Use metaclass to ensure that there is onl one instance per path"""
+    """Use metaclass to ensure that there is only one instance of a database model per path"""
 
     _paths_loaded: dict[Path, "DatabaseModel"] = {}
 
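The new TypeDecorator subclasses above exist because a bag's CSV files deliver every value as a string, so bind parameters have to be coerced before they reach SQLite. A minimal, self-contained sketch of how such a decorator behaves (the `demo` table and its rows are invented for illustration; only `ERMRestBoolean` mirrors the class above):

```python
from sqlalchemy import Column, MetaData, String, Table, create_engine, insert, select
from sqlalchemy.types import Boolean, TypeDecorator

class ERMRestBoolean(TypeDecorator):
    """Coerce ERMrest's string-encoded booleans ("t", "N", ...) on the way in."""
    impl = Boolean
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value in ("Y", "y", 1, True, "t", "T"):
            return True
        elif value in ("N", "n", 0, False, "f", "F"):
            return False
        elif value is None:
            return None
        raise ValueError(f"Invalid boolean value: {value!r}")

engine = create_engine("sqlite://")
metadata = MetaData()
demo = Table("demo", metadata, Column("RID", String, primary_key=True), Column("flag", ERMRestBoolean))
metadata.create_all(engine)

with engine.begin() as conn:
    # Strings from a CSV row are converted by process_bind_param before storage.
    conn.execute(insert(demo), [{"RID": "1-A", "flag": "t"}, {"RID": "1-B", "flag": "N"}])
    print(conn.execute(select(demo)).all())  # [('1-A', True), ('1-B', False)]
```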
@@ -65,7 +146,7 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
         bag_path (Path): path to the local copy of the BDBag
         minid (DatasetMinid): Minid for the specified bag
         dataset_rid (RID): RID for the specified dataset
-        dbase (Connection): connection to the sqlite database holding table values
+        engine (Engine): SQLAlchemy engine for the database holding table values
         domain_schema (str): Name of the domain schema
         dataset_table (Table): the dataset table in the ERMRest model.
     """
@@ -73,24 +154,6 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
     # Maintain a global map of RIDS to versions and databases.
     _rid_map: dict[RID, list[tuple[DatasetVersion, "DatabaseModel"]]] = {}
 
-    @staticmethod
-    def rid_lookup(dataset_rid: RID) -> list[tuple[DatasetVersion, "DatabaseModel"]]:
-        """Return a list of DatasetVersion/DatabaseModel instances corresponding to the given RID.
-
-        Args:
-            dataset_rid: Rit to be looked up.
-
-        Returns:
-            List of DatasetVersion/DatabaseModel instances corresponding to the given RID.
-
-        Raises:
-            Raise a DerivaMLException if the given RID is not found.
-        """
-        try:
-            return DatabaseModel._rid_map[dataset_rid]
-        except KeyError:
-            raise DerivaMLException(f"Dataset {dataset_rid} not found")
-
     def __init__(self, minid: DatasetMinid, bag_path: Path, dbase_path: Path):
         """Create a new DatabaseModel.
 
@@ -99,33 +162,41 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
             bag_path: Path to the local copy of the BDBag.
         """
 
+        super().__init__(Model.fromfile("file-system", bag_path / "data/schema.json"))
+
         self.bag_path = bag_path
         self.minid = minid
         self.dataset_rid = minid.dataset_rid
-        self.dbase_file = dbase_path / f"{minid.version_rid}.db"
-        self.dbase = sqlite3.connect(self.dbase_file)
+        self.dbase_path = dbase_path / f"{minid.version_rid}"
+        self.dbase_path.mkdir(parents=True, exist_ok=True)
+
+        self.engine = create_engine(f"sqlite:///{(self.dbase_path / 'main.db').resolve()}", future=True)
+        self.metadata = MetaData()
+        self.Base = automap_base(metadata=self.metadata)
+
+        # Attach event listener for *this instance's* engine
+        event.listen(self.engine, "connect", self._attach_schemas)
 
         schema_file = self.bag_path / "data/schema.json"
        with schema_file.open("r") as f:
             self.snaptime = json.load(f)["snaptime"]
 
-        super().__init__(Model.fromfile("file-system", self.bag_path / "data/schema.json"))
         self._logger = logging.getLogger("deriva_ml")
         self._load_model()
         self.ml_schema = ML_SCHEMA
-        self._load_sqlite()
+        self._load_database()
         self._logger.info(
             "Creating new database for dataset: %s in %s",
             self.dataset_rid,
-            self.dbase_file,
+            self.dbase_path,
         )
         self.dataset_table = self.model.schemas[self.ml_schema].tables["Dataset"]
+
         # Now go through the database and pick out all the dataset_table RIDS, along with their versions.
-        sql_dataset = self.normalize_table_name("Dataset_Version")
-        with self.dbase:
-            dataset_versions = [
-                t for t in self.dbase.execute(f'SELECT "Dataset", "Version" FROM "{sql_dataset}"').fetchall()
-            ]
+        with self.engine.connect() as conn:
+            dataset_version = self.metadata.tables[f"{self.ml_schema}.Dataset_Version"]
+            result = conn.execute(select(dataset_version.c.Dataset, dataset_version.c.Version))
+            dataset_versions = [t for t in result]
 
         dataset_versions = [(v[0], DatasetVersion.parse(v[1])) for v in dataset_versions]
         # Get most current version of each rid
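Note the engine setup in this hunk: instead of one monolithic `.db` file per bag, each catalog schema now gets its own SQLite file, ATTACHed under the schema name whenever a new DBAPI connection is opened. That is what lets the MetaData address tables as `schema.table`. A rough sketch of that attach-on-connect pattern, with made-up file and schema names:

```python
from sqlalchemy import create_engine, event, text

engine = create_engine("sqlite:///main.db", future=True)

def attach_schemas(dbapi_conn, _conn_record):
    # Runs for every new DBAPI connection, so every pooled
    # connection sees the same attached databases.
    cur = dbapi_conn.cursor()
    for schema in ["deriva-ml", "my_domain"]:  # illustrative schema names
        cur.execute(f"ATTACH DATABASE '{schema}.db' AS '{schema}'")
    cur.close()

event.listen(engine, "connect", attach_schemas)

with engine.connect() as conn:
    # "main" plus one entry per attached schema file
    print(conn.execute(text("PRAGMA database_list")).all())
```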
@@ -137,16 +208,159 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
             version_list = DatabaseModel._rid_map.setdefault(dataset_rid, [])
             version_list.append((dataset_version, self))
 
+    def _attach_schemas(self, dbapi_conn, _conn_record):
+        cur = dbapi_conn.cursor()
+        for schema in [self.domain_schema, self.ml_schema]:
+            schema_file = (self.dbase_path / f"{schema}.db").resolve()
+            cur.execute(f"ATTACH DATABASE '{schema_file}' AS '{schema}'")
+        cur.close()
+
+    @staticmethod
+    def _sql_type(type: DerivaType) -> TypeEngine:
+        """Return the SQL type for a Deriva column."""
+        return {
+            "boolean": ERMRestBoolean,
+            "date": StringToDate,
+            "float4": StringToFloat,
+            "float8": StringToFloat,
+            "int2": StringToInteger,
+            "int4": StringToInteger,
+            "int8": StringToInteger,
+            "json": JSON,
+            "jsonb": JSON,
+            "timestamptz": StringToDateTime,
+            "timestamp": StringToDateTime,
+        }.get(type.typename, String)
+
     def _load_model(self) -> None:
         """Create a sqlite database schema that contains all the tables within the catalog from which the BDBag
         was created."""
-        with self.dbase:
-            for t in self.model.schemas[self.domain_schema].tables.values():
-                self.dbase.execute(t.sqlite3_ddl())
-            for t in self.model.schemas["deriva-ml"].tables.values():
-                self.dbase.execute(t.sqlite3_ddl())
 
-    def _load_sqlite(self) -> None:
+        def is_key(column: DerivaColumn, table: DerivaTable) -> bool:
+            return column in [key.unique_columns[0] for key in table.keys] and column.name == "RID"
+
+        def col(model, name: str):
+            # try ORM attribute first
+            try:
+                return getattr(model, name).property.columns[0]
+            except AttributeError:
+                # fall back to exact DB column key on the Table
+                return model.__table__.c[name]
+
+        def guess_attr_name(col_name: str) -> str:
+            return col_name[:-3] if col_name.lower().endswith("_id") else col_name
+
+        database_tables: list[SQLTable] = []
+        for schema_name in [self.domain_schema, self.ml_schema]:
+            for table in self.model.schemas[schema_name].tables.values():
+                database_columns: list[SQLColumn] = []
+                for c in table.columns:
+                    # clone column (type, nullability, PK, defaults, unique)
+                    database_column = SQLColumn(
+                        name=c.name,
+                        type_=self._sql_type(c.type),  # SQLAlchemy type object is reusable
+                        comment=c.comment,
+                        default=c.default,
+                        primary_key=is_key(c, table),
+                        nullable=c.nullok,
+                        # NOTE: server_onupdate, computed, etc. can be added if you use them
+                    )
+                    database_columns.append(database_column)
+                database_table = SQLTable(table.name, self.metadata, *database_columns, schema=schema_name)
+                for key in table.keys:
+                    key_columns = [c.name for c in key.unique_columns]
+                    # if key.name[0] == "RID":
+                    #     continue
+                    database_table.append_constraint(
+                        SQLUniqueConstraint(
+                            *key_columns,
+                            name=key.name[1],
+                        )
+                    )
+                for fk in table.foreign_keys:
+                    if fk.pk_table.schema.name not in [self.domain_schema, self.ml_schema]:
+                        continue
+                    if fk.pk_table.schema.name != schema_name:
+                        continue
+                    # Attach FK to the chosen column
+                    database_table.append_constraint(
+                        SQLForeignKeyConstraint(
+                            columns=[f"{c.name}" for c in fk.foreign_key_columns],
+                            refcolumns=[f"{schema_name}.{c.table.name}.{c.name}" for c in fk.referenced_columns],
+                            name=fk.name[1],
+                            comment=fk.comment,
+                        )
+                    )
+                database_tables.append(database_table)
+        with self.engine.begin() as conn:
+            self.metadata.create_all(conn, tables=database_tables)
+
+        def name_for_scalar_relationship(_base, local_cls, referred_cls, constraint):
+            cols = list(constraint.columns) if constraint is not None else []
+            if len(cols) == 1:
+                name = cols[0].key
+                if name in {c.key for c in local_cls.__table__.columns}:
+                    name += "_rel"
+                return name
+            return constraint.name or referred_cls.__name__.lower()
+
+        def name_for_collection_relationship(_base, local_cls, referred_cls, constraint):
+            backref_name = constraint.name.replace("_fkey", "_collection")
+            return backref_name or (referred_cls.__name__.lower() + "_collection")
+
+        # Now build ORM mappings for the tables.
+        self.Base.prepare(
+            self.engine,
+            name_for_scalar_relationship=name_for_scalar_relationship,
+            name_for_collection_relationship=name_for_collection_relationship,
+            reflect=True,
+        )
+
+        for schema in [self.domain_schema, self.ml_schema]:
+            for table in self.model.schemas[schema].tables.values():
+                for fk in table.foreign_keys:
+                    if fk.pk_table.schema.name not in [self.domain_schema, self.ml_schema]:
+                        continue
+                    if fk.pk_table.schema.name == schema:
+                        continue
+                    table_name = f"{schema}.{table.name}"
+                    table_class = self.get_orm_class_by_name(table_name)
+                    foreign_key_column_name = fk.foreign_key_columns[0].name
+                    foreign_key_column = col(table_class, foreign_key_column_name)
+
+                    referenced_table_name = f"{fk.pk_table.schema.name}.{fk.pk_table.name}"
+                    referenced_class = self.get_orm_class_by_name(referenced_table_name)
+                    referenced_column = col(referenced_class, fk.referenced_columns[0].name)
+
+                    relationship_attr = guess_attr_name(foreign_key_column_name)
+                    backref_attr = fk.name[1].replace("_fkey", "_collection")
+                    setattr(
+                        table_class,
+                        relationship_attr,
+                        relationship(
+                            referenced_class,
+                            foreign_keys=[foreign_key_column],
+                            primaryjoin=foreign(foreign_key_column) == referenced_column,
+                            backref=backref(backref_attr, viewonly=True),
+                            viewonly=True,  # set False for write behavior, but best with proper FKs
+                        ),
+                    )
+
+        # Reflect won't pick up the second FK in the dataset_dataset table. We need to do it manually
+        # dataset_dataset_class = self.get_orm_class_by_name("deriva-ml.Dataset_Dataset")
+        # dataset_class = self.get_orm_class_by_name("deriva-ml.Dataset")
+        # dataset_dataset_class.Nested_Dataset = relationship(
+        #     dataset_class,
+        #     primaryjoin=foreign(dataset_dataset_class.__table__.c["Nested_Dataset"])
+        #     == dataset_class.__table__.c["RID"],
+        #     foreign_keys=[dataset_dataset_class.__table__.c["Nested_Dataset"]],
+        #     backref=backref("nested_dataset_collection", viewonly=True),  # pick a distinct name
+        #     viewonly=True,  # keep it read-only unless you truly want writes
+        #     overlaps="Dataset,dataset_dataset_collection",  # optional: silence overlap warnings
+        # )
+        configure_mappers()
+
+    def _load_database(self) -> None:
         """Load a SQLite database from a bdbag. This is done by looking for all the CSV files in the bdbag directory.
 
 
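The rewritten `_load_model` replaces the old `sqlite3_ddl()` dump with a two-phase build: it clones the ERMrest model into SQLAlchemy `Table` objects, then lets automap generate ORM classes over them, with the naming callbacks controlling how relationship attributes are spelled. A toy version of the automap step, on an invented two-table schema (only the SQLAlchemy calls themselves are real API):

```python
from sqlalchemy import Column, ForeignKey, MetaData, String, Table, create_engine
from sqlalchemy.ext.automap import automap_base

engine = create_engine("sqlite://")
metadata = MetaData()
Table("dataset", metadata, Column("rid", String, primary_key=True))
Table(
    "dataset_version",
    metadata,
    Column("rid", String, primary_key=True),
    Column("dataset_id", String, ForeignKey("dataset.rid")),
)
metadata.create_all(engine)

Base = automap_base(metadata=metadata)

def name_for_collection_relationship(base, local_cls, referred_cls, constraint):
    # Mirrors the diff's convention: derive the attribute from the constraint
    # name when one exists, otherwise fall back to the class name.
    return (constraint.name or referred_cls.__name__.lower()) + "_collection"

# The tables are already in metadata, so no reflection is needed here.
Base.prepare(name_for_collection_relationship=name_for_collection_relationship)

Dataset = Base.classes.dataset
DatasetVersion = Base.classes.dataset_version
print(Dataset.dataset_version_collection, DatasetVersion.dataset)
```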
@@ -160,6 +374,7 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
         for csv_file in Path(dpath).rglob("*.csv"):
             table = csv_file.stem
             schema = self.domain_schema if table in self.model.schemas[self.domain_schema].tables else self.ml_schema
+            sql_table = self.metadata.tables[f"{schema}.{table}"]
 
             with csv_file.open(newline="") as csvfile:
                 csv_reader = reader(csvfile)
@@ -170,15 +385,14 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
                     (column_names.index("Filename"), column_names.index("URL")) if self._is_asset(table) else None
                 )
 
-                value_template = ",".join(["?"] * len(column_names))  # SQL placeholder for row (?,?..)
-                column_list = ",".join([f'"{c}"' for c in column_names])
-                with self.dbase:
-                    object_table = (
-                        self._localize_asset(o, asset_indexes, asset_map, table == "Dataset") for o in csv_reader
-                    )
-                    self.dbase.executemany(
-                        f'INSERT OR REPLACE INTO "{schema}:{table}" ({column_list}) VALUES ({value_template})',
-                        object_table,
+                with self.engine.begin() as conn:
+                    object_table = [
+                        self._localize_asset(o, asset_indexes, asset_map, table == "Dataset")
+                        for o in csv_reader
+                    ]
+                    conn.execute(
+                        sqlite_insert(sql_table).on_conflict_do_nothing(),
+                        [dict(zip(column_names, row)) for row in object_table],
                     )
 
     def _localize_asset_table(self) -> dict[str, str]:
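The CSV loader now goes through the SQLite dialect's `INSERT ... ON CONFLICT DO NOTHING` instead of sqlite3's `INSERT OR REPLACE`, so re-loading a bag skips rows that are already present rather than overwriting them. A small, invented example of the same call pattern:

```python
from sqlalchemy import Column, MetaData, String, Table, create_engine, select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

engine = create_engine("sqlite://")
metadata = MetaData()
dataset = Table("Dataset", metadata, Column("RID", String, primary_key=True), Column("Description", String))
metadata.create_all(engine)

rows = [
    {"RID": "1-A", "Description": "first copy"},
    {"RID": "1-A", "Description": "second copy"},  # conflicts on the primary key
]
with engine.begin() as conn:
    # The conflicting second row is dropped rather than replacing the first.
    conn.execute(sqlite_insert(dataset).on_conflict_do_nothing(), rows)
    print(conn.execute(select(dataset)).all())  # [('1-A', 'first copy')]
```

Note the behavioral difference from the old `INSERT OR REPLACE`: on a key conflict the existing row now wins.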
@@ -237,19 +451,21 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
             o[file_column] = asset_map[o[url_column]] if o[url_column] else ""
         return tuple(o)
 
+    def find_table(self, table_name: str) -> SQLTable:
+        """Find a table in the catalog."""
+        # We will look across ml and domain schema to find a table whose name matches.
+        table = [t for t in self.metadata.tables if t == table_name or t.split(".")[1] == table_name][0]
+        return self.metadata.tables[table]
+
     def list_tables(self) -> list[str]:
         """List the names of the tables in the catalog
 
         Returns:
             A list of table names. These names are all qualified with the Deriva schema name.
         """
-        with self.dbase:
-            return [
-                t[0]
-                for t in self.dbase.execute(
-                    "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name;"
-                ).fetchall()
-            ]
+        tables = list(self.metadata.tables.keys())
+        tables.sort()
+        return tables
 
     def get_dataset(self, dataset_rid: Optional[RID] = None) -> DatasetBag:
         """Get a dataset, or nested dataset from the bag database
@@ -281,8 +497,8 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
 
         # Get a list of all the dataset_type values associated with this dataset_table.
         datasets = []
-        ds_types = list(self._get_table(atable))
-        for dataset in self._get_table("Dataset"):
+        ds_types = list(self._get_table_contents(atable))
+        for dataset in self._get_table_contents("Dataset"):
             my_types = [t for t in ds_types if t["Dataset"] == dataset["RID"]]
             datasets.append(dataset | {MLVocab.dataset_type: [ds[MLVocab.dataset_type] for ds in my_types]})
         return datasets
@@ -291,48 +507,206 @@ class DatabaseModel(DerivaModel, metaclass=DatabaseModelMeta):
         """Returns a list of all the dataset_table entries associated with a dataset."""
         return self.get_dataset(dataset_rid).list_dataset_members()
 
-    def _get_table(self, table: str) -> Generator[dict[str, Any], None, None]:
+    def _get_table_contents(self, table: str) -> Generator[dict[str, Any], None, None]:
         """Retrieve the contents of the specified table as a dictionary.
 
         Args:
-            table: Table to retrieve data from. f schema is not provided as part of the table name,
+            table: Table to retrieve data from. If schema is not provided as part of the table name,
                 the method will attempt to locate the schema for the table.
 
         Returns:
             A generator producing dictionaries containing the contents of the specified table as name/value pairs.
         """
-        table_name = self.normalize_table_name(table)
-        table = self.name_to_table(table)
 
-        with self.dbase as _dbase:
-            mapper = SQLMapper(self, table.name)
-            result = self.dbase.execute(f'SELECT * FROM "{table_name}"')
+        with self.engine.connect() as conn:
+            result = conn.execute(select(self.find_table(table)))
+            for row in result.mappings():
+                yield dict(row)
 
-            while (row := result.fetchone()) is not None:
-                yield mapper.transform_tuple(row)
-
-    def normalize_table_name(self, table: str) -> str:
-        """Attempt to insert the schema into a table name if it's not provided.
+    @staticmethod
+    def rid_lookup(dataset_rid: RID) -> list[tuple[DatasetVersion, "DatabaseModel"]]:
+        """Return a list of DatasetVersion/DatabaseModel instances corresponding to the given RID.
 
         Args:
-            table: str:
+            dataset_rid: RID to be looked up.
 
         Returns:
-            table name with schema included.
+            List of DatasetVersion/DatabaseModel instances corresponding to the given RID.
 
+        Raises:
+            Raise a DerivaMLException if the given RID is not found.
         """
         try:
-            [sname, tname] = table.split(":")
-        except ValueError:
-            tname = table
-            for sname in [self.domain_schema, self.ml_schema, "WWW"]:  # Be careful of File table.
-                if sname in self.model.schemas and table in self.model.schemas[sname].tables:
-                    break
-        try:
-            _ = self.model.schemas[sname].tables[tname]
-            return f"{sname}:{tname}"
+            return DatabaseModel._rid_map[dataset_rid]
         except KeyError:
-            raise DerivaMLException(f'Table name "{table}" does not exist.')
+            raise DerivaMLException(f"Dataset {dataset_rid} not found")
+
+    def get_orm_class_by_name(self, table: str) -> Any | None:
+        sql_table = self.find_table(table)
+        if sql_table is None:
+            raise DerivaMLException(f"Table {table} not found")
+        return self.get_orm_class_for_table(sql_table)
+
+    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+    def get_orm_class_for_table(self, table: SQLTable | DerivaTable | str) -> Any | None:
+        if isinstance(table, DerivaTable):
+            table = self.metadata.tables[f"{table.schema.name}.{table.name}"]
+        if isinstance(table, str):
+            table = self.find_table(table)
+        for mapper in self.Base.registry.mappers:
+            if mapper.persist_selectable is table or table in mapper.tables:
+                return mapper.class_
+        return None
+
+    @staticmethod
+    def _is_association(
+        table_class, min_arity=2, max_arity=2, unqualified=True, pure=True, no_overlap=True, return_fkeys=False
+    ):
+        """Return (truthy) integer arity if self is a matching association, else False.
+
+        min_arity: minimum number of associated fkeys (default 2)
+        max_arity: maximum number of associated fkeys (default 2) or None
+        unqualified: reject qualified associations when True (default True)
+        pure: reject impure associations when True (default True)
+        no_overlap: reject overlapping associations when True (default True)
+        return_fkeys: return the set of N associated ForeignKeys if True
+
+        The default behavior with no arguments is to test for pure,
+        unqualified, non-overlapping, binary associations.
+
+        An association is comprised of several foreign keys which are
+        covered by a non-nullable composite row key. This allows
+        specific combinations of foreign keys to appear at most once.
+
+        The arity of an association is the number of foreign keys
+        being associated. A typical binary association has arity=2.
+
+        An unqualified association contains *only* the foreign key
+        material in its row key. Conversely, a qualified association
+        mixes in other material which means that a specific
+        combination of foreign keys may repeat with different
+        qualifiers.
+
+        A pure association contains *only* row key
+        material. Conversely, an impure association includes
+        additional metadata columns not covered by the row key. Unlike
+        qualifiers, impure metadata merely decorates an association
+        without augmenting its identifying characteristics.
+
+        A non-overlapping association does not share any columns
+        between multiple foreign keys. This means that all
+        combinations of foreign keys are possible. Conversely, an
+        overlapping association shares some columns between multiple
+        foreign keys, potentially limiting the combinations which can
+        be represented in an association row.
+
+        These tests ignore the five ERMrest system columns and any
+        corresponding constraints.
+
+        """
+        if min_arity < 2:
+            raise ValueError("An association cannot have arity < 2")
+        if max_arity is not None and max_arity < min_arity:
+            raise ValueError("max_arity cannot be less than min_arity")
+
+        mapper = inspect(table_class).mapper
+
+        # TODO: revisit whether there are any other cases we might
+        # care about where system columns are involved?
+        non_sys_cols = {col.name for col in mapper.columns if col.name not in {"RID", "RCT", "RMT", "RCB", "RMB"}}
+        unique_columns = [
+            {c.name for c in constraint.columns}
+            for constraint in inspect(table_class).local_table.constraints
+            if isinstance(constraint, SQLUniqueConstraint)
+        ]
+
+        non_sys_key_colsets = {
+            frozenset(unique_column_set)
+            for unique_column_set in unique_columns
+            if unique_column_set.issubset(non_sys_cols) and len(unique_column_set) > 1
+        }
+
+        if not non_sys_key_colsets:
+            # reject: not association
+            return False
+
+        # choose longest compound key (arbitrary choice with ties!)
+        row_key = sorted(non_sys_key_colsets, key=lambda s: len(s), reverse=True)[0]
+        foreign_keys = [constraint for constraint in inspect(table_class).relationships.values()]
+
+        covered_fkeys = {fkey for fkey in foreign_keys if {c.name for c in fkey.local_columns}.issubset(row_key)}
+        covered_fkey_cols = set()
+
+        if len(covered_fkeys) < min_arity:
+            # reject: not enough fkeys in association
+            return False
+        elif max_arity is not None and len(covered_fkeys) > max_arity:
+            # reject: too many fkeys in association
+            return False
+
+        for fkey in covered_fkeys:
+            fkcols = {c.name for c in fkey.local_columns}
+            if no_overlap and fkcols.intersection(covered_fkey_cols):
+                # reject: overlapping fkeys in association
+                return False
+            covered_fkey_cols.update(fkcols)
+
+        if unqualified and row_key.difference(covered_fkey_cols):
+            # reject: qualified association
+            return False
+
+        if pure and non_sys_cols.difference(row_key):
+            # reject: impure association
+            return False
+
+        # return (truthy) arity or fkeys
+        if return_fkeys:
+            return covered_fkeys
+        else:
+            return len(covered_fkeys)
+
+    def get_orm_association_class(
+        self,
+        left_cls: Type[Any],
+        right_cls: Type[Any],
+        min_arity=2,
+        max_arity=2,
+        unqualified=True,
+        pure=True,
+        no_overlap=True,
+    ):
+        """
+        Find an association class C by: (1) walking rels on left_cls to a mid class C,
+        (2) verifying C also relates to right_cls. Returns (C, C->left, C->right) or None.
+
+        """
+        for _, left_rel in inspect(left_cls).relationships.items():
+            mid_cls = left_rel.mapper.class_
+            is_assoc = self._is_association(mid_cls, return_fkeys=True)
+            if not is_assoc:
+                continue
+            assoc_local_columns_left = list(is_assoc)[0].local_columns
+            assoc_local_columns_right = list(is_assoc)[1].local_columns
+
+            found_left = found_right = False
+            for r in inspect(left_cls).relationships.values():
+                remote_side = list(r.remote_side)[0]
+                if remote_side in assoc_local_columns_left:
+                    found_left = r
+                if remote_side in assoc_local_columns_right:
+                    found_left = r
+                    # We have left and right backwards from the association, so swap them.
+                    assoc_local_columns_left, assoc_local_columns_right = (
+                        assoc_local_columns_right,
+                        assoc_local_columns_left,
+                    )
+            for r in inspect(right_cls).relationships.values():
+                remote_side = list(r.remote_side)[0]
+                if remote_side in assoc_local_columns_right:
+                    found_right = r
+            if found_left != False and found_right != False:
+                return mid_cls, found_left.class_attribute, found_right.class_attribute
+        return None
 
     def delete_database(self):
         """