deriva-ml 1.17.14__py3-none-any.whl → 1.17.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. deriva_ml/__init__.py +2 -2
  2. deriva_ml/asset/asset.py +0 -4
  3. deriva_ml/catalog/__init__.py +6 -0
  4. deriva_ml/catalog/clone.py +1591 -38
  5. deriva_ml/catalog/localize.py +66 -29
  6. deriva_ml/core/base.py +12 -9
  7. deriva_ml/core/definitions.py +13 -12
  8. deriva_ml/core/ermrest.py +11 -12
  9. deriva_ml/core/mixins/annotation.py +2 -2
  10. deriva_ml/core/mixins/asset.py +3 -3
  11. deriva_ml/core/mixins/dataset.py +3 -3
  12. deriva_ml/core/mixins/execution.py +1 -0
  13. deriva_ml/core/mixins/feature.py +2 -2
  14. deriva_ml/core/mixins/file.py +2 -2
  15. deriva_ml/core/mixins/path_builder.py +2 -2
  16. deriva_ml/core/mixins/rid_resolution.py +2 -2
  17. deriva_ml/core/mixins/vocabulary.py +2 -2
  18. deriva_ml/core/mixins/workflow.py +3 -3
  19. deriva_ml/dataset/catalog_graph.py +3 -4
  20. deriva_ml/dataset/dataset.py +5 -3
  21. deriva_ml/dataset/dataset_bag.py +0 -2
  22. deriva_ml/dataset/upload.py +2 -2
  23. deriva_ml/demo_catalog.py +0 -1
  24. deriva_ml/execution/__init__.py +8 -8
  25. deriva_ml/execution/base_config.py +2 -2
  26. deriva_ml/execution/execution.py +5 -3
  27. deriva_ml/execution/execution_record.py +0 -1
  28. deriva_ml/execution/model_protocol.py +1 -1
  29. deriva_ml/execution/multirun_config.py +0 -1
  30. deriva_ml/execution/runner.py +3 -3
  31. deriva_ml/experiment/experiment.py +3 -3
  32. deriva_ml/feature.py +2 -2
  33. deriva_ml/interfaces.py +2 -2
  34. deriva_ml/model/__init__.py +45 -24
  35. deriva_ml/model/annotations.py +0 -1
  36. deriva_ml/model/catalog.py +3 -2
  37. deriva_ml/model/data_loader.py +330 -0
  38. deriva_ml/model/data_sources.py +439 -0
  39. deriva_ml/model/database.py +216 -32
  40. deriva_ml/model/fk_orderer.py +379 -0
  41. deriva_ml/model/handles.py +1 -1
  42. deriva_ml/model/schema_builder.py +816 -0
  43. deriva_ml/run_model.py +3 -3
  44. deriva_ml/schema/annotations.py +2 -1
  45. deriva_ml/schema/create_schema.py +1 -1
  46. deriva_ml/schema/validation.py +1 -1
  47. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/METADATA +1 -1
  48. deriva_ml-1.17.16.dist-info/RECORD +81 -0
  49. deriva_ml-1.17.14.dist-info/RECORD +0 -77
  50. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/WHEEL +0 -0
  51. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/entry_points.txt +0 -0
  52. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/licenses/LICENSE +0 -0
  53. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/top_level.txt +0 -0
deriva_ml/model/schema_builder.py (new file)
@@ -0,0 +1,816 @@
+ """Create SQLAlchemy ORM from Deriva catalog model.
+
+ This module provides the SchemaBuilder class which creates a SQLAlchemy ORM
+ from a Deriva Model object. This is Phase 1 of the two-phase pattern:
+
+ 1. Phase 1 (SchemaBuilder): Create ORM structure without data
+ 2. Phase 2 (DataLoader): Fill database from a data source
+
+ The Model object can come from either:
+ - A live catalog: catalog.getCatalogModel()
+ - A schema.json file: Model.fromfile("file-system", schema_file)
+
+ Example:
+     # From catalog
+     model = catalog.getCatalogModel()
+     builder = SchemaBuilder(model, schemas=['domain', 'deriva-ml'])
+     orm = builder.build()
+
+     # From file
+     model = Model.fromfile("file-system", "schema.json")
+     builder = SchemaBuilder(model, schemas=['domain', 'deriva-ml'])
+     orm = builder.build()
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from pathlib import Path
+ from typing import Any, Generator, Type
+
+ from dateutil import parser
+ from deriva.core.ermrest_model import Column as DerivaColumn
+ from deriva.core.ermrest_model import Model
+ from deriva.core.ermrest_model import Table as DerivaTable
+ from deriva.core.ermrest_model import Type as DerivaType
+ from sqlalchemy import (
+     JSON,
+     Boolean,
+     Date,
+     DateTime,
+     Float,
+     Integer,
+     MetaData,
+     String,
+     create_engine,
+     event,
+     inspect,
+     select,
+ )
+ from sqlalchemy import Column as SQLColumn
+ from sqlalchemy import ForeignKeyConstraint as SQLForeignKeyConstraint
+ from sqlalchemy import Table as SQLTable
+ from sqlalchemy import UniqueConstraint as SQLUniqueConstraint
+ from sqlalchemy.engine import Engine
+ from sqlalchemy.ext.automap import AutomapBase, automap_base
+ from sqlalchemy.orm import backref, foreign, relationship
+ from sqlalchemy.sql.type_api import TypeEngine
+ from sqlalchemy.types import TypeDecorator
+
+ logger = logging.getLogger(__name__)
+
+
+ # =============================================================================
+ # Type converters for loading CSV string data into SQLite with proper types
+ # =============================================================================
+
+ class ERMRestBoolean(TypeDecorator):
+     """Convert ERMrest boolean strings to Python bool."""
+     impl = Boolean
+     cache_ok = True
+
+     def process_bind_param(self, value: Any, dialect: Any) -> bool | None:
+         if value in ("Y", "y", 1, True, "t", "T"):
+             return True
+         elif value in ("N", "n", 0, False, "f", "F"):
+             return False
+         elif value is None or value == "":
+             return None
+         raise ValueError(f"Invalid boolean value: {value!r}")
+
+
+ class StringToFloat(TypeDecorator):
+     """Convert string to float, handling empty strings."""
+     impl = Float
+     cache_ok = True
+
+     def process_bind_param(self, value: Any, dialect: Any) -> float | None:
+         if value == "" or value is None:
+             return None
+         return float(value)
+
+
+ class StringToInteger(TypeDecorator):
+     """Convert string to integer, handling empty strings."""
+     impl = Integer
+     cache_ok = True
+
+     def process_bind_param(self, value: Any, dialect: Any) -> int | None:
+         if value == "" or value is None:
+             return None
+         return int(value)
+
+
+ class StringToDateTime(TypeDecorator):
+     """Convert string to datetime, handling empty strings."""
+     impl = DateTime
+     cache_ok = True
+
+     def process_bind_param(self, value: Any, dialect: Any) -> Any:
+         if value == "" or value is None:
+             return None
+         return parser.parse(value)
+
+
+ class StringToDate(TypeDecorator):
+     """Convert string to date, handling empty strings."""
+     impl = Date
+     cache_ok = True
+
+     def process_bind_param(self, value: Any, dialect: Any) -> Any:
+         if value == "" or value is None:
+             return None
+         return parser.parse(value).date()
+
+
+ # =============================================================================
+ # SchemaORM - Container for SQLAlchemy ORM components
+ # =============================================================================
+
+ class SchemaORM:
+     """Container for SQLAlchemy ORM components.
+
+     Provides access to the ORM structure and utility methods for
+     table/class lookup. This is the result of Phase 1 (SchemaBuilder).
+
+     Attributes:
+         engine: SQLAlchemy Engine for database connections.
+         metadata: SQLAlchemy MetaData with table definitions.
+         Base: SQLAlchemy automap base for ORM classes.
+         model: ERMrest Model the ORM was built from.
+         schemas: List of schema names included.
+         use_schemas: Whether schema prefixes are used (False for in-memory).
+     """
+
+     def __init__(
+         self,
+         engine: Engine,
+         metadata: MetaData,
+         Base: AutomapBase,
+         model: Model,
+         schemas: list[str],
+         class_prefix: str,
+         use_schemas: bool = True,
+     ):
+         """Initialize SchemaORM container.
+
+         Args:
+             engine: SQLAlchemy Engine.
+             metadata: SQLAlchemy MetaData with tables.
+             Base: Automap base with ORM classes.
+             model: Source ERMrest Model.
+             schemas: Schemas that were included.
+             class_prefix: Prefix used for ORM class names.
+             use_schemas: Whether schema prefixes are used (False for in-memory).
+         """
+         self.engine = engine
+         self.metadata = metadata
+         self.Base = Base
+         self.model = model
+         self.schemas = schemas
+         self._class_prefix = class_prefix
+         self._use_schemas = use_schemas
+         self._disposed = False
+
+     def list_tables(self) -> list[str]:
+         """List all tables in the database.
+
+         Returns:
+             List of fully-qualified table names (schema.table), sorted.
+         """
+         tables = list(self.metadata.tables.keys())
+         tables.sort()
+         return tables
+
+     def find_table(self, table_name: str) -> SQLTable:
+         """Find a table by name.
+
+         Handles both schema.table format and schema_table format (for in-memory databases).
+
+         Args:
+             table_name: Table name, with or without schema prefix.
+                 Can be "schema.table", "schema_table", or just "table".
+
+         Returns:
+             SQLAlchemy Table object.
+
+         Raises:
+             KeyError: If table not found.
+         """
+         # Try exact match first
+         if table_name in self.metadata.tables:
+             return self.metadata.tables[table_name]
+
+         # Try converting schema.table to schema_table format (for in-memory)
+         if "." in table_name and not self._use_schemas:
+             converted_name = table_name.replace(".", "_").replace("-", "_")
+             if converted_name in self.metadata.tables:
+                 return self.metadata.tables[converted_name]
+
+         # Try matching just the table name part
+         for full_name, table in self.metadata.tables.items():
+             # Handle . separator (file-based)
+             if "." in full_name and full_name.split(".")[-1] == table_name:
+                 return table
+             # Handle _ separator (in-memory) - match suffix after first _
+             if "_" in full_name and "." not in full_name:
+                 # Check if table_name matches the part after schema prefix
+                 parts = full_name.split("_", 1)
+                 if len(parts) > 1 and parts[1] == table_name:
+                     return table
+                 # Also check if it ends with the table name
+                 if full_name.endswith(f"_{table_name}"):
+                     return table
+
+         raise KeyError(f"Table {table_name} not found")
+
+     def get_orm_class(self, table_name: str) -> Any | None:
+         """Get the ORM class for a table by name.
+
+         Args:
+             table_name: Table name, with or without schema prefix.
+
+         Returns:
+             SQLAlchemy ORM class for the table.
+
+         Raises:
+             KeyError: If table not found.
+         """
+         sql_table = self.find_table(table_name)
+         return self.get_orm_class_for_table(sql_table)
+
+     def get_orm_class_for_table(self, table: SQLTable | DerivaTable | str) -> Any | None:
+         """Get the ORM class for a table.
+
+         Args:
+             table: SQLAlchemy Table, Deriva Table, or table name.
+
+         Returns:
+             SQLAlchemy ORM class, or None if not found.
+         """
+         if isinstance(table, DerivaTable):
+             # Try schema.table format first (file-based), then schema_table (in-memory)
+             table_key = f"{table.schema.name}.{table.name}"
+             table = self.metadata.tables.get(table_key)
+             if table is None and not self._use_schemas:
+                 # Try underscore format for in-memory databases
+                 table_key = table_key.replace(".", "_").replace("-", "_")
+                 table = self.metadata.tables.get(table_key)
+         if isinstance(table, str):
+             table = self.find_table(table)
+         if table is None:
+             return None
+
+         for mapper in self.Base.registry.mappers:
+             if mapper.persist_selectable is table or table in mapper.tables:
+                 return mapper.class_
+         return None
+
+     def get_table_contents(self, table: str) -> Generator[dict[str, Any], None, None]:
+         """Retrieve all rows from a table as dictionaries.
+
+         Args:
+             table: Table name (with or without schema prefix).
+
+         Yields:
+             Dictionary for each row with column names as keys.
+         """
+         sql_table = self.find_table(table)
+         with self.engine.connect() as conn:
+             result = conn.execute(select(sql_table))
+             for row in result.mappings():
+                 yield dict(row)
+
+     @staticmethod
+     def is_association_table(
+         table_class,
+         min_arity: int = 2,
+         max_arity: int = 2,
+         unqualified: bool = True,
+         pure: bool = True,
+         no_overlap: bool = True,
+         return_fkeys: bool = False,
+     ):
+         """Check if an ORM class represents an association table.
+
+         An association table links two or more tables through foreign keys,
+         with a composite unique key covering those foreign keys.
+
+         Args:
+             table_class: SQLAlchemy ORM class to check.
+             min_arity: Minimum number of foreign keys (default 2).
+             max_arity: Maximum number of foreign keys (default 2).
+             unqualified: If True, reject associations with extra key columns.
+             pure: If True, reject associations with extra non-key columns.
+             no_overlap: If True, reject associations with shared FK columns.
+             return_fkeys: If True, return the foreign keys instead of arity.
+
+         Returns:
+             If return_fkeys=False: Integer arity if association, False otherwise.
+             If return_fkeys=True: Set of foreign keys if association, False otherwise.
+         """
+         if min_arity < 2:
+             raise ValueError("An association cannot have arity < 2")
+         if max_arity is not None and max_arity < min_arity:
+             raise ValueError("max_arity cannot be less than min_arity")
+
+         mapper = inspect(table_class).mapper
+         system_cols = {"RID", "RCT", "RMT", "RCB", "RMB"}
+
+         non_sys_cols = {
+             col.name for col in mapper.columns if col.name not in system_cols
+         }
+
+         unique_columns = [
+             {c.name for c in constraint.columns}
+             for constraint in inspect(table_class).local_table.constraints
+             if isinstance(constraint, SQLUniqueConstraint)
+         ]
+
+         non_sys_key_colsets = {
+             frozenset(uc)
+             for uc in unique_columns
+             if uc.issubset(non_sys_cols) and len(uc) > 1
+         }
+
+         if not non_sys_key_colsets:
+             return False
+
+         # Choose longest compound key
+         row_key = sorted(non_sys_key_colsets, key=lambda s: len(s), reverse=True)[0]
+         foreign_keys = list(inspect(table_class).relationships.values())
+
+         covered_fkeys = {
+             fkey for fkey in foreign_keys
+             if {c.name for c in fkey.local_columns}.issubset(row_key)
+         }
+         covered_fkey_cols = set()
+
+         if len(covered_fkeys) < min_arity:
+             return False
+         if max_arity is not None and len(covered_fkeys) > max_arity:
+             return False
+
+         for fkey in covered_fkeys:
+             fkcols = {c.name for c in fkey.local_columns}
+             if no_overlap and fkcols.intersection(covered_fkey_cols):
+                 return False
+             covered_fkey_cols.update(fkcols)
+
+         if unqualified and row_key.difference(covered_fkey_cols):
+             return False
+
+         if pure and non_sys_cols.difference(row_key):
+             return False
+
+         return covered_fkeys if return_fkeys else len(covered_fkeys)
+
+     def get_association_class(
+         self,
+         left_cls: Type[Any],
+         right_cls: Type[Any],
+     ) -> tuple[Any, Any, Any] | None:
+         """Find an association class connecting two ORM classes.
+
+         Args:
+             left_cls: First ORM class.
+             right_cls: Second ORM class.
+
+         Returns:
+             Tuple of (association_class, left_relationship, right_relationship),
+             or None if no association found.
+         """
+         for _, left_rel in inspect(left_cls).relationships.items():
+             mid_cls = left_rel.mapper.class_
+             is_assoc = self.is_association_table(mid_cls, return_fkeys=True)
+
+             if not is_assoc:
+                 continue
+
+             assoc_local_columns_left = list(is_assoc)[0].local_columns
+             assoc_local_columns_right = list(is_assoc)[1].local_columns
+
+             found_left = found_right = False
+
+             for r in inspect(left_cls).relationships.values():
+                 remote_side = list(r.remote_side)[0]
+                 if remote_side in assoc_local_columns_left:
+                     found_left = r
+                 if remote_side in assoc_local_columns_right:
+                     found_left = r
+                     # Swap if backwards
+                     assoc_local_columns_left, assoc_local_columns_right = (
+                         assoc_local_columns_right,
+                         assoc_local_columns_left,
+                     )
+
+             for r in inspect(right_cls).relationships.values():
+                 remote_side = list(r.remote_side)[0]
+                 if remote_side in assoc_local_columns_right:
+                     found_right = r
+
+             if found_left and found_right:
+                 return mid_cls, found_left.class_attribute, found_right.class_attribute
+
+         return None
+
+     def dispose(self) -> None:
+         """Dispose of SQLAlchemy resources.
+
+         Call this when done with the database to properly clean up connections.
+         After calling dispose(), the instance should not be used further.
+         """
+         if self._disposed:
+             return
+
+         if hasattr(self, "Base") and self.Base is not None:
+             self.Base.registry.dispose()
+         if hasattr(self, "engine") and self.engine is not None:
+             self.engine.dispose()
+
+         self._disposed = True
+
+     def __del__(self) -> None:
+         """Cleanup resources when garbage collected."""
+         self.dispose()
+
+     def __enter__(self) -> "SchemaORM":
+         """Context manager entry."""
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
+         """Context manager exit - dispose resources."""
+         self.dispose()
+         return False
+
+
+ # =============================================================================
+ # SchemaBuilder - Creates ORM from Deriva Model
+ # =============================================================================
+
+ class SchemaBuilder:
+     """Creates SQLAlchemy ORM from a Deriva catalog model.
+
+     Phase 1 of the two-phase database creation pattern. This class handles
+     only schema/ORM creation - no data loading.
+
+     The Model can come from either a live catalog or a schema.json file:
+     - From catalog: model = catalog.getCatalogModel()
+     - From file: model = Model.fromfile("file-system", "path/to/schema.json")
+
+     Example:
+         # Create ORM from catalog model
+         model = catalog.getCatalogModel()
+         builder = SchemaBuilder(model, schemas=['domain', 'deriva-ml'])
+         orm = builder.build()
+
+         # Create ORM from schema file
+         model = Model.fromfile("file-system", "schema.json")
+         builder = SchemaBuilder(model, schemas=['domain'], database_path="local.db")
+         orm = builder.build()
+
+         # Use the ORM
+         ImageClass = orm.get_orm_class("Image")
+         with Session(orm.engine) as session:
+             images = session.query(ImageClass).all()
+
+         # Clean up
+         orm.dispose()
+     """
+
+     # Type mapping from ERMrest to SQLAlchemy
+     _TYPE_MAP = {
+         "boolean": ERMRestBoolean,
+         "date": StringToDate,
+         "float4": StringToFloat,
+         "float8": StringToFloat,
+         "int2": StringToInteger,
+         "int4": StringToInteger,
+         "int8": StringToInteger,
+         "json": JSON,
+         "jsonb": JSON,
+         "timestamptz": StringToDateTime,
+         "timestamp": StringToDateTime,
+     }
+
+     def __init__(
+         self,
+         model: Model,
+         schemas: list[str],
+         database_path: Path | str = ":memory:",
+     ):
+         """Initialize the schema builder.
+
+         Args:
+             model: ERMrest Model object (from catalog or schema.json file).
+             schemas: List of schema names to include in the ORM.
+             database_path: Path to SQLite database file. Use ":memory:" for
+                 in-memory database (default). If a Path or string is provided,
+                 separate .db files will be created for each schema.
+         """
+         self.model = model
+         self.schemas = schemas
+         self.database_path = Path(database_path) if database_path != ":memory:" else database_path
+
+         # Will be set during build()
+         self.engine: Engine | None = None
+         self.metadata: MetaData | None = None
+         self.Base: AutomapBase | None = None
+         self._class_prefix: str = ""
+
+     @staticmethod
+     def _sql_type(deriva_type: DerivaType) -> TypeEngine:
+         """Map ERMrest type to SQLAlchemy type with CSV string conversion.
+
+         Args:
+             deriva_type: ERMrest type object.
+
+         Returns:
+             SQLAlchemy type class.
+         """
+         return SchemaBuilder._TYPE_MAP.get(deriva_type.typename, String)
+
+     def _is_key_column(self, column: DerivaColumn, table: DerivaTable) -> bool:
+         """Check if column is the primary key (RID)."""
+         return column in [key.unique_columns[0] for key in table.keys] and column.name == "RID"
+
+     def build(self) -> SchemaORM:
+         """Build the SQLAlchemy ORM structure.
+
+         Creates SQLite tables from the ERMrest schema and generates
+         ORM classes via SQLAlchemy automap.
+
+         Returns:
+             SchemaORM object containing engine, metadata, Base, and utilities.
+
+         Note:
+             In-memory databases (database_path=":memory:") do not support
+             SQLite schema attachments, so all tables will be created in a
+             single database without schema prefixes in table names.
+         """
+         # Create unique prefix for ORM class names
+         self._class_prefix = f"_{id(self)}_"
+
+         # Determine if we're using in-memory or file-based database
+         self._use_schemas = self.database_path != ":memory:"
+
+         # Create engine
+         if self.database_path == ":memory:":
+             self.engine = create_engine("sqlite:///:memory:", future=True)
+         else:
+             # Ensure the database path exists
+             if isinstance(self.database_path, Path):
+                 if self.database_path.suffix == ".db":
+                     # Single file path
+                     self.database_path.parent.mkdir(parents=True, exist_ok=True)
+                     main_db = self.database_path
+                 else:
+                     # Directory path
+                     self.database_path.mkdir(parents=True, exist_ok=True)
+                     main_db = self.database_path / "main.db"
+             else:
+                 main_db = Path(self.database_path)
+                 main_db.parent.mkdir(parents=True, exist_ok=True)
+
+             self.engine = create_engine(f"sqlite:///{main_db.resolve()}", future=True)
+
+             # Attach schema-specific databases
+             event.listen(self.engine, "connect", self._attach_schemas)
+
+         self.metadata = MetaData()
+         self.Base = automap_base(metadata=self.metadata)
+
+         # Build the schema
+         self._create_tables()
+
+         logger.info(
+             "Built ORM for schemas %s with %d tables",
+             self.schemas,
+             len(self.metadata.tables),
+         )
+
+         return SchemaORM(
+             engine=self.engine,
+             metadata=self.metadata,
+             Base=self.Base,
+             model=self.model,
+             schemas=self.schemas,
+             class_prefix=self._class_prefix,
+             use_schemas=self._use_schemas,
+         )
+
+     def _attach_schemas(self, dbapi_conn, _conn_record):
+         """Attach schema-specific SQLite databases."""
+         cur = dbapi_conn.cursor()
+         db_dir = self.database_path if self.database_path.is_dir() else self.database_path.parent
+         for schema in self.schemas:
+             schema_file = (db_dir / f"{schema}.db").resolve()
+             cur.execute(f"ATTACH DATABASE '{schema_file}' AS '{schema}'")
+         cur.close()
+
+     def _create_tables(self) -> None:
+         """Create SQLite tables from the ERMrest schema."""
+
+         def col(model, name: str):
+             """Get column from ORM class, handling both attribute and table column access."""
+             try:
+                 return getattr(model, name).property.columns[0]
+             except AttributeError:
+                 return model.__table__.c[name]
+
+         def guess_attr_name(col_name: str) -> str:
+             """Generate relationship attribute name from column name."""
+             return col_name[:-3] if col_name.lower().endswith("_id") else col_name
+
+         def make_table_name(schema_name: str, table_name: str) -> str:
+             """Generate table name, including schema prefix if using schemas."""
+             if self._use_schemas:
+                 return f"{schema_name}.{table_name}"
+             else:
+                 # For in-memory, use underscore separator to avoid conflicts
+                 return f"{schema_name}_{table_name}"
+
+         database_tables: list[SQLTable] = []
+
+         for schema_name in self.schemas:
+             if schema_name not in self.model.schemas:
+                 logger.warning(f"Schema {schema_name} not found in model")
+                 continue
+
+             for table in self.model.schemas[schema_name].tables.values():
+                 database_columns: list[SQLColumn] = []
+
+                 for c in table.columns:
+                     database_column = SQLColumn(
+                         name=c.name,
+                         type_=self._sql_type(c.type),
+                         comment=c.comment,
+                         default=c.default,
+                         primary_key=self._is_key_column(c, table),
+                         nullable=c.nullok,
+                     )
+                     database_columns.append(database_column)
+
+                 # Use schema prefix only for file-based databases
+                 if self._use_schemas:
+                     database_table = SQLTable(
+                         table.name, self.metadata, *database_columns, schema=schema_name
+                     )
+                 else:
+                     # For in-memory, embed schema in table name
+                     full_name = f"{schema_name}_{table.name}".replace("-", "_")
+                     database_table = SQLTable(
+                         full_name, self.metadata, *database_columns
+                     )
+
+                 # Add unique constraints
+                 for key in table.keys:
+                     key_columns = [c.name for c in key.unique_columns]
+                     database_table.append_constraint(
+                         SQLUniqueConstraint(*key_columns, name=key.name[1])
+                     )
+
+                 # Add foreign key constraints (within same schema only for now)
+                 for fk in table.foreign_keys:
+                     if fk.pk_table.schema.name not in self.schemas:
+                         continue
+                     if fk.pk_table.schema.name != schema_name:
+                         continue
+
+                     # Build reference column names
+                     if self._use_schemas:
+                         refcols = [
+                             f"{schema_name}.{c.table.name}.{c.name}"
+                             for c in fk.referenced_columns
+                         ]
+                     else:
+                         # For in-memory, use the embedded schema name
+                         ref_table_name = f"{schema_name}_{fk.pk_table.name}".replace("-", "_")
+                         refcols = [
+                             f"{ref_table_name}.{c.name}"
+                             for c in fk.referenced_columns
+                         ]
+
+                     database_table.append_constraint(
+                         SQLForeignKeyConstraint(
+                             columns=[f"{c.name}" for c in fk.foreign_key_columns],
+                             refcolumns=refcols,
+                             name=fk.name[1],
+                             comment=fk.comment,
+                         )
+                     )
+
+                 database_tables.append(database_table)
+
+         # Create all tables
+         with self.engine.begin() as conn:
+             self.metadata.create_all(conn, tables=database_tables)
+
+         # Configure ORM class naming
+         def name_for_scalar_relationship(_base, local_cls, referred_cls, constraint):
+             cols = list(constraint.columns) if constraint is not None else []
+             if len(cols) == 1:
+                 name = cols[0].key
+                 if name in {c.key for c in local_cls.__table__.columns}:
+                     name += "_rel"
+                 return name
+             return constraint.name or referred_cls.__name__.lower()
+
+         def name_for_collection_relationship(_base, local_cls, referred_cls, constraint):
+             backref_name = constraint.name.replace("_fkey", "_collection")
+             return backref_name or (referred_cls.__name__.lower() + "_collection")
+
+         def classname_for_table(_base, tablename, table):
+             return self._class_prefix + tablename.replace(".", "_").replace("-", "_")
+
+         # Build ORM mappings
+         self.Base.prepare(
+             self.engine,
+             name_for_scalar_relationship=name_for_scalar_relationship,
+             name_for_collection_relationship=name_for_collection_relationship,
+             classname_for_table=classname_for_table,
+             reflect=True,
+         )
+
+         # Add cross-schema relationships
+         for schema_name in self.schemas:
+             if schema_name not in self.model.schemas:
+                 continue
+
+             for table in self.model.schemas[schema_name].tables.values():
+                 for fk in table.foreign_keys:
+                     if fk.pk_table.schema.name not in self.schemas:
+                         continue
+                     if fk.pk_table.schema.name == schema_name:
+                         continue
+
+                     table_name = make_table_name(schema_name, table.name)
+                     table_class = self._get_orm_class_by_name(table_name)
+                     foreign_key_column_name = fk.foreign_key_columns[0].name
+                     foreign_key_column = col(table_class, foreign_key_column_name)
+
+                     referenced_table_name = make_table_name(fk.pk_table.schema.name, fk.pk_table.name)
+                     referenced_class = self._get_orm_class_by_name(referenced_table_name)
+                     referenced_column = col(referenced_class, fk.referenced_columns[0].name)
+
+                     relationship_attr = guess_attr_name(foreign_key_column_name)
+                     backref_attr = fk.name[1].replace("_fkey", "_collection")
+
+                     # Check if relationship already exists
+                     existing_attr = getattr(table_class, relationship_attr, None)
+                     from sqlalchemy.orm import RelationshipProperty
+                     from sqlalchemy.orm.attributes import InstrumentedAttribute
+
+                     is_relationship = isinstance(existing_attr, InstrumentedAttribute) and isinstance(
+                         existing_attr.property, RelationshipProperty
+                     )
+                     if not is_relationship:
+                         setattr(
+                             table_class,
+                             relationship_attr,
+                             relationship(
+                                 referenced_class,
+                                 foreign_keys=[foreign_key_column],
+                                 primaryjoin=foreign(foreign_key_column) == referenced_column,
+                                 backref=backref(backref_attr, viewonly=True),
+                                 viewonly=True,
+                             ),
+                         )
+
+         # Configure mappers
+         self.Base.registry.configure()
+
+     def _get_orm_class_by_name(self, table_name: str) -> Any | None:
+         """Get ORM class by table name (internal use during build).
+
+         Handles both schema.table format (file-based) and schema_table format (in-memory).
+         """
+         # Try exact match first
+         if table_name in self.metadata.tables:
+             sql_table = self.metadata.tables[table_name]
+         else:
+             # For in-memory databases, table names use underscore separator
+             # Try converting schema.table to schema_table format
+             if "." in table_name and not self._use_schemas:
+                 converted_name = table_name.replace(".", "_").replace("-", "_")
+                 if converted_name in self.metadata.tables:
+                     sql_table = self.metadata.tables[converted_name]
+                 else:
+                     sql_table = None
+             else:
+                 # Try matching just the table name part
+                 sql_table = None
+                 for full_name, table in self.metadata.tables.items():
+                     # Handle both . and _ separators
+                     table_part = full_name.split(".")[-1] if "." in full_name else full_name.split("_", 1)[-1] if "_" in full_name else full_name
+                     if table_part == table_name or full_name.endswith(f"_{table_name}"):
+                         sql_table = table
+                         break
+
+         if sql_table is None:
+             raise KeyError(f"Table {table_name} not found")
+
+         for mapper in self.Base.registry.mappers:
+             if mapper.persist_selectable is sql_table or sql_table in mapper.tables:
+                 return mapper.class_
+         return None
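For orientation, the new Phase 1 API can be exercised roughly as follows. This is a minimal sketch assembled from the docstring examples in schema_builder.py above; the schema names, the "Image" table, the "local.db" path, and the `from deriva_ml.model.schema_builder import SchemaBuilder` import path are illustrative assumptions, and the Phase 2 DataLoader added in deriva_ml/model/data_loader.py is not shown.

    from deriva.core.ermrest_model import Model
    from sqlalchemy.orm import Session

    from deriva_ml.model.schema_builder import SchemaBuilder  # import path assumed from the file layout

    # Phase 1: build the ORM structure from a schema.json file; no data is loaded yet.
    model = Model.fromfile("file-system", "schema.json")
    builder = SchemaBuilder(model, schemas=["domain", "deriva-ml"], database_path="local.db")

    with builder.build() as orm:  # SchemaORM is a context manager; dispose() runs on exit
        print(orm.list_tables())                 # fully qualified names, e.g. "domain.Image"
        ImageClass = orm.get_orm_class("Image")  # "Image" is a hypothetical domain table
        with Session(orm.engine) as session:
            rows = session.query(ImageClass).all()  # empty until Phase 2 (DataLoader) fills the tables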