deriva-ml 1.17.14__py3-none-any.whl → 1.17.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deriva_ml/__init__.py +2 -2
- deriva_ml/asset/asset.py +0 -4
- deriva_ml/catalog/__init__.py +6 -0
- deriva_ml/catalog/clone.py +1591 -38
- deriva_ml/catalog/localize.py +66 -29
- deriva_ml/core/base.py +12 -9
- deriva_ml/core/definitions.py +13 -12
- deriva_ml/core/ermrest.py +11 -12
- deriva_ml/core/mixins/annotation.py +2 -2
- deriva_ml/core/mixins/asset.py +3 -3
- deriva_ml/core/mixins/dataset.py +3 -3
- deriva_ml/core/mixins/execution.py +1 -0
- deriva_ml/core/mixins/feature.py +2 -2
- deriva_ml/core/mixins/file.py +2 -2
- deriva_ml/core/mixins/path_builder.py +2 -2
- deriva_ml/core/mixins/rid_resolution.py +2 -2
- deriva_ml/core/mixins/vocabulary.py +2 -2
- deriva_ml/core/mixins/workflow.py +3 -3
- deriva_ml/dataset/catalog_graph.py +3 -4
- deriva_ml/dataset/dataset.py +5 -3
- deriva_ml/dataset/dataset_bag.py +0 -2
- deriva_ml/dataset/upload.py +2 -2
- deriva_ml/demo_catalog.py +0 -1
- deriva_ml/execution/__init__.py +8 -8
- deriva_ml/execution/base_config.py +2 -2
- deriva_ml/execution/execution.py +5 -3
- deriva_ml/execution/execution_record.py +0 -1
- deriva_ml/execution/model_protocol.py +1 -1
- deriva_ml/execution/multirun_config.py +0 -1
- deriva_ml/execution/runner.py +3 -3
- deriva_ml/experiment/experiment.py +3 -3
- deriva_ml/feature.py +2 -2
- deriva_ml/interfaces.py +2 -2
- deriva_ml/model/__init__.py +45 -24
- deriva_ml/model/annotations.py +0 -1
- deriva_ml/model/catalog.py +3 -2
- deriva_ml/model/data_loader.py +330 -0
- deriva_ml/model/data_sources.py +439 -0
- deriva_ml/model/database.py +216 -32
- deriva_ml/model/fk_orderer.py +379 -0
- deriva_ml/model/handles.py +1 -1
- deriva_ml/model/schema_builder.py +816 -0
- deriva_ml/run_model.py +3 -3
- deriva_ml/schema/annotations.py +2 -1
- deriva_ml/schema/create_schema.py +1 -1
- deriva_ml/schema/validation.py +1 -1
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/METADATA +1 -1
- deriva_ml-1.17.16.dist-info/RECORD +81 -0
- deriva_ml-1.17.14.dist-info/RECORD +0 -77
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/WHEEL +0 -0
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/entry_points.txt +0 -0
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/top_level.txt +0 -0
deriva_ml/execution/execution.py
CHANGED
|
@@ -39,12 +39,16 @@ import time
|
|
|
39
39
|
from collections import defaultdict
|
|
40
40
|
from datetime import datetime
|
|
41
41
|
from pathlib import Path
|
|
42
|
-
from typing import Any, Callable, Iterable, List
|
|
42
|
+
from typing import TYPE_CHECKING, Any, Callable, Iterable, List
|
|
43
43
|
|
|
44
44
|
from deriva.core import format_exception
|
|
45
|
+
|
|
46
|
+
if TYPE_CHECKING:
|
|
47
|
+
from deriva_ml.asset.asset import Asset
|
|
45
48
|
from deriva.core.hatrac_store import HatracStore
|
|
46
49
|
from pydantic import ConfigDict, validate_call
|
|
47
50
|
|
|
51
|
+
from deriva_ml.asset.aux_classes import AssetFilePath
|
|
48
52
|
from deriva_ml.core.base import DerivaML
|
|
49
53
|
from deriva_ml.core.definitions import (
|
|
50
54
|
DRY_RUN_RID,
|
|
@@ -58,7 +62,6 @@ from deriva_ml.core.definitions import (
|
|
|
58
62
|
UploadProgress,
|
|
59
63
|
)
|
|
60
64
|
from deriva_ml.core.exceptions import DerivaMLException
|
|
61
|
-
from deriva_ml.asset.aux_classes import AssetFilePath
|
|
62
65
|
from deriva_ml.dataset.aux_classes import DatasetSpec, DatasetVersion
|
|
63
66
|
from deriva_ml.dataset.dataset import Dataset
|
|
64
67
|
from deriva_ml.dataset.dataset_bag import DatasetBag
|
|
@@ -1170,7 +1173,6 @@ class Execution:
|
|
|
1170
1173
|
return self._execution_record.list_assets(asset_role=asset_role)
|
|
1171
1174
|
|
|
1172
1175
|
# Fallback for dry_run mode
|
|
1173
|
-
from deriva_ml.asset.asset import Asset
|
|
1174
1176
|
|
|
1175
1177
|
pb = self._ml_object.pathBuilder()
|
|
1176
1178
|
asset_exec = pb.schemas[self._ml_object.ml_schema].Execution_Asset_Execution
|
|
@@ -533,7 +533,6 @@ class ExecutionRecord(BaseModel):
|
|
|
533
533
|
>>> for asset in record.list_assets(asset_role="Output"):
|
|
534
534
|
... print(f"Output Asset: {asset.asset_rid}")
|
|
535
535
|
"""
|
|
536
|
-
from deriva_ml.asset.asset import Asset
|
|
537
536
|
|
|
538
537
|
if self._ml_instance is None:
|
|
539
538
|
raise DerivaMLException("ExecutionRecord is not bound to a catalog")
|
|
@@ -82,7 +82,7 @@ The protocol uses @runtime_checkable, so isinstance() checks work at runtime.
|
|
|
82
82
|
|
|
83
83
|
from __future__ import annotations
|
|
84
84
|
|
|
85
|
-
from typing import
|
|
85
|
+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
|
86
86
|
|
|
87
87
|
if TYPE_CHECKING:
|
|
88
88
|
from deriva_ml import DerivaML
|
deriva_ml/execution/runner.py
CHANGED
|
@@ -145,7 +145,7 @@ from __future__ import annotations
|
|
|
145
145
|
import atexit
|
|
146
146
|
import logging
|
|
147
147
|
from pathlib import Path
|
|
148
|
-
from typing import Any, TypeVar
|
|
148
|
+
from typing import TYPE_CHECKING, Any, TypeVar
|
|
149
149
|
|
|
150
150
|
from hydra.core.hydra_config import HydraConfig
|
|
151
151
|
from hydra_zen import builds
|
|
@@ -153,9 +153,9 @@ from hydra_zen import builds
|
|
|
153
153
|
if TYPE_CHECKING:
|
|
154
154
|
from deriva_ml import DerivaML
|
|
155
155
|
from deriva_ml.core.config import DerivaMLConfig
|
|
156
|
-
from deriva_ml.dataset import DatasetSpec
|
|
157
|
-
from deriva_ml.execution import ExecutionConfiguration, Workflow
|
|
158
156
|
from deriva_ml.core.definitions import RID
|
|
157
|
+
from deriva_ml.dataset import DatasetSpec
|
|
158
|
+
from deriva_ml.execution import Workflow
|
|
159
159
|
|
|
160
160
|
|
|
161
161
|
# Type variable for DerivaML and its subclasses
|
|
@@ -26,10 +26,10 @@ import yaml
|
|
|
26
26
|
from deriva.core.hatrac_store import HatracStore
|
|
27
27
|
|
|
28
28
|
if TYPE_CHECKING:
|
|
29
|
-
from deriva_ml.core.base import DerivaML
|
|
30
|
-
from deriva_ml.execution.execution_record import ExecutionRecord
|
|
31
29
|
from deriva_ml.asset.asset import Asset
|
|
30
|
+
from deriva_ml.core.base import DerivaML
|
|
32
31
|
from deriva_ml.dataset.dataset import Dataset
|
|
32
|
+
from deriva_ml.execution.execution_record import ExecutionRecord
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
@dataclass
|
|
@@ -403,7 +403,7 @@ class Experiment:
|
|
|
403
403
|
>>> exp = ml.lookup_experiment("47BE")
|
|
404
404
|
>>> exp.display_markdown()
|
|
405
405
|
"""
|
|
406
|
-
from IPython.display import
|
|
406
|
+
from IPython.display import Markdown, display
|
|
407
407
|
|
|
408
408
|
display(Markdown(self.to_markdown(show_datasets, show_assets)))
|
|
409
409
|
|
deriva_ml/feature.py
CHANGED
|
@@ -12,12 +12,12 @@ Typical usage example:
|
|
|
12
12
|
>>> record = FeatureClass(value="high", confidence=0.95)
|
|
13
13
|
"""
|
|
14
14
|
|
|
15
|
+
# Deriva imports - use importlib to avoid shadowing by local 'deriva.py' files
|
|
16
|
+
import importlib
|
|
15
17
|
from pathlib import Path
|
|
16
18
|
from types import UnionType
|
|
17
19
|
from typing import TYPE_CHECKING, ClassVar, Optional, Type
|
|
18
20
|
|
|
19
|
-
# Deriva imports - use importlib to avoid shadowing by local 'deriva.py' files
|
|
20
|
-
import importlib
|
|
21
21
|
_ermrest_model = importlib.import_module("deriva.core.ermrest_model")
|
|
22
22
|
Column = _ermrest_model.Column
|
|
23
23
|
FindAssociationResult = _ermrest_model.FindAssociationResult
|
deriva_ml/interfaces.py
CHANGED
|
@@ -59,13 +59,13 @@ Implementation Notes
|
|
|
59
59
|
|
|
60
60
|
from __future__ import annotations
|
|
61
61
|
|
|
62
|
+
# Deriva imports - use importlib to avoid shadowing by local 'deriva.py' files
|
|
63
|
+
import importlib
|
|
62
64
|
from pathlib import Path
|
|
63
65
|
from typing import TYPE_CHECKING, Any, Generator, Iterable, Protocol, Self, runtime_checkable
|
|
64
66
|
|
|
65
67
|
import pandas as pd
|
|
66
68
|
|
|
67
|
-
# Deriva imports - use importlib to avoid shadowing by local 'deriva.py' files
|
|
68
|
-
import importlib
|
|
69
69
|
_deriva_core = importlib.import_module("deriva.core")
|
|
70
70
|
_datapath = importlib.import_module("deriva.core.datapath")
|
|
71
71
|
_ermrest_catalog = importlib.import_module("deriva.core.ermrest_catalog")
|
deriva_ml/model/__init__.py
CHANGED
|
@@ -3,47 +3,60 @@
|
|
|
3
3
|
This module provides catalog and database model classes, as well as
|
|
4
4
|
handle wrappers for ERMrest model objects and annotation builders.
|
|
5
5
|
|
|
6
|
+
Key components:
|
|
7
|
+
- DerivaModel: Schema analysis utilities
|
|
8
|
+
- DatabaseModel: SQLite database from BDBag
|
|
9
|
+
- SchemaBuilder/SchemaORM: Create ORM from Deriva Model (Phase 1)
|
|
10
|
+
- DataLoader: Fill database from data source (Phase 2)
|
|
11
|
+
- DataSource: Protocol for data sources (BagDataSource, CatalogDataSource)
|
|
12
|
+
- ForeignKeyOrderer: Compute FK-safe insertion order
|
|
13
|
+
|
|
6
14
|
Lazy imports are used for DatabaseModel and DerivaMLDatabase to avoid
|
|
7
15
|
circular imports with the dataset module.
|
|
8
16
|
"""
|
|
9
17
|
|
|
10
|
-
from deriva_ml.model.catalog import DerivaModel
|
|
11
|
-
from deriva_ml.model.handles import ColumnHandle, TableHandle
|
|
12
|
-
|
|
13
18
|
# Annotation builders - import the most common ones for convenience
|
|
14
19
|
from deriva_ml.model.annotations import (
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
20
|
+
CONTEXT_COMPACT,
|
|
21
|
+
# Context constants
|
|
22
|
+
CONTEXT_DEFAULT,
|
|
23
|
+
CONTEXT_DETAILED,
|
|
24
|
+
CONTEXT_ENTRY,
|
|
25
|
+
CONTEXT_FILTER,
|
|
26
|
+
Aggregate,
|
|
27
|
+
ArrayUxMode,
|
|
21
28
|
ColumnDisplay,
|
|
22
29
|
ColumnDisplayOptions,
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
PseudoColumnDisplay,
|
|
30
|
+
# Builders
|
|
31
|
+
Display,
|
|
26
32
|
Facet,
|
|
27
33
|
FacetList,
|
|
28
34
|
FacetRange,
|
|
29
|
-
|
|
30
|
-
NameStyle,
|
|
35
|
+
FacetUxMode,
|
|
31
36
|
# FK helpers
|
|
32
37
|
InboundFK,
|
|
38
|
+
NameStyle,
|
|
33
39
|
OutboundFK,
|
|
34
|
-
|
|
40
|
+
PreFormat,
|
|
41
|
+
PseudoColumn,
|
|
42
|
+
PseudoColumnDisplay,
|
|
43
|
+
SortKey,
|
|
44
|
+
TableDisplay,
|
|
45
|
+
TableDisplayOptions,
|
|
35
46
|
# Enums
|
|
36
47
|
TemplateEngine,
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
# Context constants
|
|
41
|
-
CONTEXT_DEFAULT,
|
|
42
|
-
CONTEXT_COMPACT,
|
|
43
|
-
CONTEXT_DETAILED,
|
|
44
|
-
CONTEXT_ENTRY,
|
|
45
|
-
CONTEXT_FILTER,
|
|
48
|
+
VisibleColumns,
|
|
49
|
+
VisibleForeignKeys,
|
|
50
|
+
fk_constraint,
|
|
46
51
|
)
|
|
52
|
+
from deriva_ml.model.catalog import DerivaModel
|
|
53
|
+
from deriva_ml.model.data_loader import DataLoader
|
|
54
|
+
from deriva_ml.model.data_sources import BagDataSource, CatalogDataSource, DataSource
|
|
55
|
+
from deriva_ml.model.fk_orderer import ForeignKeyOrderer
|
|
56
|
+
from deriva_ml.model.handles import ColumnHandle, TableHandle
|
|
57
|
+
|
|
58
|
+
# Two-phase ORM creation components
|
|
59
|
+
from deriva_ml.model.schema_builder import SchemaBuilder, SchemaORM
|
|
47
60
|
|
|
48
61
|
__all__ = [
|
|
49
62
|
# Core classes
|
|
@@ -52,6 +65,14 @@ __all__ = [
|
|
|
52
65
|
"DerivaMLDatabase",
|
|
53
66
|
"TableHandle",
|
|
54
67
|
"ColumnHandle",
|
|
68
|
+
# Two-phase ORM creation
|
|
69
|
+
"SchemaBuilder",
|
|
70
|
+
"SchemaORM",
|
|
71
|
+
"DataSource",
|
|
72
|
+
"BagDataSource",
|
|
73
|
+
"CatalogDataSource",
|
|
74
|
+
"DataLoader",
|
|
75
|
+
"ForeignKeyOrderer",
|
|
55
76
|
# Annotation builders
|
|
56
77
|
"Display",
|
|
57
78
|
"VisibleColumns",
|
deriva_ml/model/annotations.py
CHANGED
|
@@ -131,7 +131,6 @@ from dataclasses import dataclass, field
|
|
|
131
131
|
from enum import Enum
|
|
132
132
|
from typing import Any, Literal
|
|
133
133
|
|
|
134
|
-
|
|
135
134
|
# =============================================================================
|
|
136
135
|
# Enums for constrained values
|
|
137
136
|
# =============================================================================
|
deriva_ml/model/catalog.py
CHANGED
|
@@ -7,13 +7,14 @@ ML-specific functionality. It handles schema management, feature definitions, an
|
|
|
7
7
|
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
|
+
# Deriva imports - use importlib to avoid shadowing by local 'deriva.py' files
|
|
11
|
+
import importlib
|
|
12
|
+
|
|
10
13
|
# Standard library imports
|
|
11
14
|
from collections import Counter, defaultdict
|
|
12
15
|
from graphlib import CycleError, TopologicalSorter
|
|
13
16
|
from typing import Any, Callable, Final, Iterable, NewType, TypeAlias
|
|
14
17
|
|
|
15
|
-
# Deriva imports - use importlib to avoid shadowing by local 'deriva.py' files
|
|
16
|
-
import importlib
|
|
17
18
|
_ermrest_catalog = importlib.import_module("deriva.core.ermrest_catalog")
|
|
18
19
|
_ermrest_model = importlib.import_module("deriva.core.ermrest_model")
|
|
19
20
|
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
"""Load data into SQLite database with FK ordering.
|
|
2
|
+
|
|
3
|
+
This module provides the DataLoader class which loads data from a
|
|
4
|
+
DataSource into a SchemaORM database. It handles:
|
|
5
|
+
|
|
6
|
+
- Automatic FK dependency ordering
|
|
7
|
+
- Batch inserts with conflict handling
|
|
8
|
+
- Progress tracking
|
|
9
|
+
|
|
10
|
+
This is Phase 2 of the two-phase pattern:
|
|
11
|
+
1. Phase 1 (SchemaBuilder): Create ORM structure without data
|
|
12
|
+
2. Phase 2 (DataLoader): Fill database from a data source
|
|
13
|
+
|
|
14
|
+
Example:
|
|
15
|
+
# Phase 1: Create ORM
|
|
16
|
+
orm = SchemaBuilder(model, schemas).build()
|
|
17
|
+
|
|
18
|
+
# Phase 2: Fill with data
|
|
19
|
+
source = BagDataSource(bag_path)
|
|
20
|
+
loader = DataLoader(orm, source)
|
|
21
|
+
counts = loader.load_tables(['Subject', 'Image', 'Diagnosis'])
|
|
22
|
+
print(f"Loaded {sum(counts.values())} rows")
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
import logging
|
|
28
|
+
from typing import Any, Callable
|
|
29
|
+
|
|
30
|
+
from deriva.core.ermrest_model import Table as DerivaTable
|
|
31
|
+
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
32
|
+
|
|
33
|
+
from .data_sources import DataSource
|
|
34
|
+
from .fk_orderer import ForeignKeyOrderer
|
|
35
|
+
from .schema_builder import SchemaORM
|
|
36
|
+
|
|
37
|
+
# Module-level logger used by DataLoader for ordering warnings, per-table load
# counts, and batch insert errors.
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class DataLoader:
    """Loads data into a database with FK ordering.

    Phase 2 of the two-phase database creation pattern. Takes a
    SchemaORM (from Phase 1) and populates it from a DataSource.

    Automatically orders tables by FK dependencies to ensure
    referential integrity during loading.

    Example:
        # Phase 1: Create ORM
        orm = SchemaBuilder(model, schemas).build()

        # Phase 2: Fill with data from bag
        source = BagDataSource(bag_path)
        loader = DataLoader(orm, source)
        counts = loader.load_tables()  # All tables
        print(f"Loaded {sum(counts.values())} total rows")

        # Or load specific tables
        counts = loader.load_tables(['Subject', 'Image'])

        # With progress callback
        def on_progress(table, count, total):
            print(f"Loaded {table}: {count} rows")
        loader.load_tables(progress_callback=on_progress)
    """

    # Conflict strategies accepted by load_tables() / load_table().
    _CONFLICT_MODES: tuple[str, ...] = ("ignore", "replace", "error")

    def __init__(
        self,
        schema_orm: SchemaORM,
        data_source: DataSource,
    ):
        """Initialize the loader.

        Args:
            schema_orm: ORM structure from SchemaBuilder.
            data_source: Source of data to load (BagDataSource, CatalogDataSource, etc.).
        """
        self.orm = schema_orm
        self.source = data_source
        self.orderer = ForeignKeyOrderer(
            schema_orm.model,
            schema_orm.schemas,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _table_key(table: DerivaTable) -> str:
        """Return the ``schema.table`` key used for ORM lookups and result dicts."""
        return f"{table.schema.name}.{table.name}"

    @classmethod
    def _check_on_conflict(cls, on_conflict: str) -> None:
        """Raise ValueError for an unknown conflict strategy.

        Without this check a misspelled strategy would silently fall through
        to a plain INSERT.
        """
        if on_conflict not in cls._CONFLICT_MODES:
            raise ValueError(
                f"on_conflict must be one of {cls._CONFLICT_MODES}, got {on_conflict!r}"
            )

    def _available_orm_tables(self) -> list[str]:
        """Return qualified ORM table names for which the source has data.

        A source table matches an ORM table by either its qualified
        ``schema.table`` name or its bare table name. The result is sorted so
        the load order does not depend on set iteration order.
        """
        available = set(self.source.list_available_tables())
        return [
            orm_table
            for orm_table in sorted(set(self.orm.list_tables()))
            if orm_table in available or orm_table.split(".")[-1] in available
        ]

    def _resolve_unordered(self, tables: list[str]) -> list[DerivaTable]:
        """Resolve table names to model tables, preserving the given order.

        Used only when FK ordering could not be computed. Tables that are
        missing from the ORM are dropped. Tables the orderer cannot resolve
        are skipped with a warning instead of aborting the whole load.
        """
        resolved: list[DerivaTable] = []
        for name in tables:
            if not self._table_exists(name):
                continue
            try:
                resolved.append(self.orderer._to_table(name))
            except (KeyError, ValueError) as e:
                logger.warning("Skipping %s: cannot resolve table (%s)", name, e)
        return resolved

    def _table_exists(self, table: str | DerivaTable) -> bool:
        """Check if table exists in ORM."""
        key = table if isinstance(table, str) else self._table_key(table)
        try:
            self.orm.find_table(key)
        except KeyError:
            return False
        return True

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load_tables(
        self,
        tables: list[str | DerivaTable] | None = None,
        on_conflict: str = "ignore",
        batch_size: int = 1000,
        progress_callback: Callable[[str, int, int], None] | None = None,
    ) -> dict[str, int]:
        """Load data into specified tables with FK ordering.

        Tables are automatically ordered by FK dependencies to ensure
        referenced tables are populated first.

        Args:
            tables: Tables to load. If None, loads all ORM tables that have
                data in the source.
            on_conflict: How to handle duplicate keys:
                - "ignore": Skip rows with duplicate keys (default)
                - "replace": Replace existing rows
                - "error": Raise error on duplicates
            batch_size: Number of rows per insert batch.
            progress_callback: Optional callback(table_name, rows_loaded, total_tables)
                called after each table is loaded.

        Returns:
            Dict mapping qualified table names to the number of rows submitted
            for insertion.

        Raises:
            ValueError: If ``on_conflict`` is not a recognised strategy.
        """
        self._check_on_conflict(on_conflict)

        # Determine tables to load, always as qualified-or-given string names.
        if tables is None:
            tables_to_load = self._available_orm_tables()
        else:
            tables_to_load = [
                t if isinstance(t, str) else self._table_key(t) for t in tables
            ]

        # Compute insertion order.
        try:
            ordered_tables = self.orderer.get_insertion_order(tables_to_load)
        except ValueError as e:
            # Some tables might not be in the model; fall back to given order.
            logger.warning("Could not compute FK ordering: %s", e)
            ordered_tables = self._resolve_unordered(tables_to_load)

        counts: dict[str, int] = {}
        total_tables = len(ordered_tables)

        for table in ordered_tables:
            table_key = self._table_key(table)
            count = self._load_table(table, on_conflict, batch_size)
            counts[table_key] = count

            if progress_callback:
                progress_callback(table_key, count, total_tables)

            if count > 0:
                logger.info("Loaded %d rows into %s", count, table_key)

        return counts

    def _load_table(
        self,
        table: DerivaTable,
        on_conflict: str,
        batch_size: int,
    ) -> int:
        """Load a single table.

        All batches for the table run inside one ``engine.begin()`` block, so
        a re-raised error in "error" mode aborts this table's load.

        Args:
            table: Table to load.
            on_conflict: Conflict handling strategy.
            batch_size: Rows per batch.

        Returns:
            Number of rows submitted successfully.
        """
        table_key = self._table_key(table)

        try:
            sql_table = self.orm.find_table(table_key)
        except KeyError:
            logger.warning("Table %s not found in ORM", table_key)
            return 0

        if not self.source.has_table(table):
            logger.debug("No data for %s in source", table_key)
            return 0

        rows_loaded = 0
        batch: list[dict[str, Any]] = []

        with self.orm.engine.begin() as conn:
            for row in self.source.get_table_data(table):
                batch.append(row)
                if len(batch) >= batch_size:
                    rows_loaded += self._insert_batch(conn, sql_table, batch, on_conflict)
                    batch = []

            # Flush the final partial batch.
            if batch:
                rows_loaded += self._insert_batch(conn, sql_table, batch, on_conflict)

        return rows_loaded

    @staticmethod
    def _build_insert(sql_table: Any, on_conflict: str) -> Any:
        """Build the INSERT statement for the given conflict strategy."""
        if on_conflict == "ignore":
            return sqlite_insert(sql_table).on_conflict_do_nothing()
        if on_conflict == "replace":
            stmt = sqlite_insert(sql_table)
            # Overwrite every non-key column from the incoming ("excluded") row.
            update_cols = {c.name: c for c in stmt.excluded if c.name != "RID"}
            if not update_cols:
                # Only the key column exists: there is nothing to replace, and an
                # empty SET dictionary is rejected, so skip duplicates instead.
                return stmt.on_conflict_do_nothing()
            return stmt.on_conflict_do_update(
                index_elements=["RID"],
                set_=update_cols,
            )
        # "error": plain INSERT so duplicate keys surface as an exception.
        return sql_table.insert()

    def _insert_batch(
        self,
        conn: Any,
        sql_table: Any,
        rows: list[dict[str, Any]],
        on_conflict: str,
    ) -> int:
        """Insert a batch of rows.

        Args:
            conn: Database connection.
            sql_table: SQLAlchemy table.
            rows: List of row dictionaries.
            on_conflict: Conflict handling strategy.

        Returns:
            Number of rows submitted. Under "ignore", duplicates are counted
            even though they are skipped. Returns 0 if the batch failed in a
            best-effort mode.
        """
        if not rows:
            return 0

        try:
            conn.execute(self._build_insert(sql_table, on_conflict), rows)
        except Exception as e:
            logger.error(
                "Error inserting into %s (%d rows skipped): %s",
                sql_table.name,
                len(rows),
                e,
            )
            if on_conflict == "error":
                raise
            # Best-effort modes drop the failing batch and keep loading.
            return 0

        return len(rows)

    def load_table(
        self,
        table: str | DerivaTable,
        on_conflict: str = "ignore",
        batch_size: int = 1000,
    ) -> int:
        """Load a single table (without FK ordering).

        Use this when you know the dependencies are already satisfied
        or for loading a single table.

        Args:
            table: Table to load.
            on_conflict: Conflict handling strategy.
            batch_size: Rows per batch.

        Returns:
            Number of rows loaded.

        Raises:
            ValueError: If ``on_conflict`` is not a recognised strategy.
        """
        self._check_on_conflict(on_conflict)
        if isinstance(table, str):
            table = self.orderer._to_table(table)
        return self._load_table(table, on_conflict, batch_size)

    def get_load_order(
        self,
        tables: list[str | DerivaTable] | None = None,
    ) -> list[str]:
        """Get the FK-safe load order for tables without loading.

        Useful for previewing or manually controlling load order. With
        ``tables=None`` the same table set that ``load_tables(None)`` would
        load is ordered.

        Args:
            tables: Tables to order. If None, orders all available.

        Returns:
            List of table names in safe insertion order.
        """
        if tables is None:
            tables = self._available_orm_tables()

        ordered = self.orderer.get_insertion_order(tables)
        return [self._table_key(t) for t in ordered]

    def validate_load_order(
        self,
        tables: list[str | DerivaTable],
    ) -> list[tuple[str, str, str]]:
        """Validate that tables can be loaded in the given order.

        Args:
            tables: Ordered list of tables.

        Returns:
            List of FK violations as (table, missing_dep, fk_name) tuples.
            Empty if order is valid.
        """
        return self.orderer.validate_insertion_order(tables)
|