deriva-ml 1.17.14__py3-none-any.whl → 1.17.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deriva_ml/__init__.py +2 -2
- deriva_ml/asset/asset.py +0 -4
- deriva_ml/catalog/__init__.py +6 -0
- deriva_ml/catalog/clone.py +1591 -38
- deriva_ml/catalog/localize.py +66 -29
- deriva_ml/core/base.py +12 -9
- deriva_ml/core/definitions.py +13 -12
- deriva_ml/core/ermrest.py +11 -12
- deriva_ml/core/mixins/annotation.py +2 -2
- deriva_ml/core/mixins/asset.py +3 -3
- deriva_ml/core/mixins/dataset.py +3 -3
- deriva_ml/core/mixins/execution.py +1 -0
- deriva_ml/core/mixins/feature.py +2 -2
- deriva_ml/core/mixins/file.py +2 -2
- deriva_ml/core/mixins/path_builder.py +2 -2
- deriva_ml/core/mixins/rid_resolution.py +2 -2
- deriva_ml/core/mixins/vocabulary.py +2 -2
- deriva_ml/core/mixins/workflow.py +3 -3
- deriva_ml/dataset/catalog_graph.py +3 -4
- deriva_ml/dataset/dataset.py +5 -3
- deriva_ml/dataset/dataset_bag.py +0 -2
- deriva_ml/dataset/upload.py +2 -2
- deriva_ml/demo_catalog.py +0 -1
- deriva_ml/execution/__init__.py +8 -8
- deriva_ml/execution/base_config.py +2 -2
- deriva_ml/execution/execution.py +5 -3
- deriva_ml/execution/execution_record.py +0 -1
- deriva_ml/execution/model_protocol.py +1 -1
- deriva_ml/execution/multirun_config.py +0 -1
- deriva_ml/execution/runner.py +3 -3
- deriva_ml/experiment/experiment.py +3 -3
- deriva_ml/feature.py +2 -2
- deriva_ml/interfaces.py +2 -2
- deriva_ml/model/__init__.py +45 -24
- deriva_ml/model/annotations.py +0 -1
- deriva_ml/model/catalog.py +3 -2
- deriva_ml/model/data_loader.py +330 -0
- deriva_ml/model/data_sources.py +439 -0
- deriva_ml/model/database.py +216 -32
- deriva_ml/model/fk_orderer.py +379 -0
- deriva_ml/model/handles.py +1 -1
- deriva_ml/model/schema_builder.py +816 -0
- deriva_ml/run_model.py +3 -3
- deriva_ml/schema/annotations.py +2 -1
- deriva_ml/schema/create_schema.py +1 -1
- deriva_ml/schema/validation.py +1 -1
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/METADATA +1 -1
- deriva_ml-1.17.16.dist-info/RECORD +81 -0
- deriva_ml-1.17.14.dist-info/RECORD +0 -77
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/WHEEL +0 -0
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/entry_points.txt +0 -0
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,439 @@
|
|
|
1
|
+
"""Data sources for populating SQLite databases.
|
|
2
|
+
|
|
3
|
+
This module provides the DataSource protocol and implementations for
|
|
4
|
+
reading data from various sources:
|
|
5
|
+
|
|
6
|
+
- BagDataSource: Reads from BDBag CSV files
|
|
7
|
+
- CatalogDataSource: Fetches from remote Deriva catalog via ERMrest API
|
|
8
|
+
|
|
9
|
+
These are used with DataLoader in Phase 2 of the two-phase pattern.
|
|
10
|
+
|
|
11
|
+
Example:
|
|
12
|
+
# From bag
|
|
13
|
+
source = BagDataSource(bag_path)
|
|
14
|
+
for row in source.get_table_data(table):
|
|
15
|
+
print(row)
|
|
16
|
+
|
|
17
|
+
# From catalog
|
|
18
|
+
source = CatalogDataSource(catalog, schemas=['domain', 'deriva-ml'])
|
|
19
|
+
for row in source.get_table_data(table):
|
|
20
|
+
print(row)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import csv
|
|
26
|
+
import logging
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
from typing import Any, Iterator, Protocol, runtime_checkable
|
|
29
|
+
from urllib.parse import urlparse
|
|
30
|
+
|
|
31
|
+
from deriva.core import ErmrestCatalog
|
|
32
|
+
from deriva.core.ermrest_model import Model
|
|
33
|
+
from deriva.core.ermrest_model import Table as DerivaTable
|
|
34
|
+
|
|
35
|
+
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)


# Standard asset table columns: a table whose columns include all of these
# is treated as an asset table (see BagDataSource._is_asset_table).
ASSET_COLUMNS = {"Filename", "URL", "Length", "MD5", "Description"}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@runtime_checkable
class DataSource(Protocol):
    """Structural interface for anything that can supply rows to a database.

    Concrete implementations read from different backends (BDBags, remote
    Deriva catalogs, ...) and feed DataLoader during Phase 2 of the
    two-phase population pattern.
    """

    def get_table_data(self, table: DerivaTable | str) -> Iterator[dict[str, Any]]:
        """Yield one dictionary per row of *table*, keyed by column name.

        Args:
            table: Table object or name to get data for.

        Yields:
            Dictionary per row with column names as keys.
        """
        ...

    def has_table(self, table: DerivaTable | str) -> bool:
        """Report whether this source can supply rows for *table*.

        Args:
            table: Table object or name to check.

        Returns:
            True if data is available for this table.
        """
        ...

    def list_available_tables(self) -> list[str]:
        """Name every table this source has data for.

        Returns:
            List of table names (may include schema prefix).
        """
        ...
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class BagDataSource:
    """DataSource implementation for BDBag directories.

    Reads table data from CSV files under the bag's ``data/`` directory.
    When a ``fetch.txt`` manifest is present, asset rows have their
    ``Filename`` column rewritten to point at the locally fetched file.

    Example:
        source = BagDataSource(Path("/path/to/bag"))

        # List available tables
        print(source.list_available_tables())

        # Get data for a table
        for row in source.get_table_data("Image"):
            print(row["Filename"])
    """

    def __init__(
        self,
        bag_path: Path,
        model: Model | None = None,
        asset_localization: bool = True,
    ):
        """Initialize from a bag path.

        Args:
            bag_path: Path to BDBag directory.
            model: Optional ERMrest Model for schema info. If not provided,
                will try to load from bag's schema.json.
            asset_localization: Whether to localize asset URLs to local paths
                using fetch.txt mapping.
        """
        self.bag_path = Path(bag_path)
        self.data_path = self.bag_path / "data"

        # Load model if not provided; without one, asset-table detection
        # (and hence localization) is effectively disabled.
        if model is None:
            schema_file = self.data_path / "schema.json"
            if schema_file.exists():
                self.model = Model.fromfile("file-system", schema_file)
            else:
                self.model = None
                logger.warning(f"No schema.json found in {self.bag_path}")
        else:
            self.model = model

        # Map of URL path -> local file path for asset localization.
        self._asset_map = self._build_asset_map() if asset_localization else {}

        # Cache of table name -> csv file path, built once up front.
        self._csv_cache: dict[str, Path] = {}
        self._build_csv_cache()

    def _build_csv_cache(self) -> None:
        """Build cache mapping table names to CSV file paths."""
        # NOTE(review): tables with the same name under different schema
        # directories collide here — the last file found wins.
        for csv_file in self.data_path.rglob("*.csv"):
            self._csv_cache[csv_file.stem] = csv_file

    def _build_asset_map(self) -> dict[str, str]:
        """Build a map from remote URL paths to local file paths using fetch.txt.

        Returns:
            Dictionary mapping URL paths to local file paths.
        """
        fetch_map: dict[str, str] = {}
        fetch_file = self.bag_path / "fetch.txt"

        if not fetch_file.exists():
            logger.debug(f"No fetch.txt in bag {self.bag_path.name}")
            return fetch_map

        try:
            with fetch_file.open(newline="\n") as f:
                for row in f:
                    # Rows in fetch.txt are tab-separated: URL, size, local_path
                    fields = row.rstrip("\r\n").split("\t")
                    if len(fields) >= 3:
                        # Key by the URL's path component so lookups succeed
                        # whether tables store full URLs or bare paths.
                        fetch_map[urlparse(fields[0]).path] = str(self.bag_path / fields[2])
        except (OSError, ValueError) as e:
            # Best-effort: an unreadable manifest just disables localization.
            logger.warning(f"Error reading fetch.txt: {e}")

        return fetch_map

    def _get_table_name(self, table: DerivaTable | str) -> str:
        """Extract the bare table name from a table object or string."""
        if not isinstance(table, str):
            # Table-like object (e.g. ermrest_model.Table).
            return table.name
        # Handle schema.table format: keep only the final component.
        return table.rsplit(".", 1)[-1]

    def _is_asset_table(self, table_name: str) -> bool:
        """Check if a table is an asset table (has Filename, URL, etc. columns)."""
        if self.model is None:
            return False

        for schema in self.model.schemas.values():
            if table_name in schema.tables:
                table = schema.tables[table_name]
                return ASSET_COLUMNS.issubset({c.name for c in table.columns})
        return False

    def _localize_asset_row(self, row: dict[str, Any]) -> dict[str, Any]:
        """Replace the Filename in an asset row with its local file path.

        The asset map is keyed by URL *path*, so the row's URL is tried
        verbatim first and then by its path component; this localizes rows
        that store either a full URL or a bare path.

        Args:
            row: Dictionary of column values.

        Returns:
            Updated dictionary with localized file path (a copy when changed).
        """
        if "URL" in row and "Filename" in row:
            url = row.get("URL")
            if url:
                local = self._asset_map.get(url)
                if local is None:
                    local = self._asset_map.get(urlparse(url).path)
                if local is not None:
                    row = dict(row)  # Copy to avoid mutating original
                    row["Filename"] = local
        return row

    def get_table_data(
        self,
        table: DerivaTable | str,
    ) -> Iterator[dict[str, Any]]:
        """Read table data from CSV file.

        Args:
            table: Table object or name.

        Yields:
            Dictionary per row with column names as keys; asset rows are
            localized when a fetch.txt mapping is available.
        """
        table_name = self._get_table_name(table)
        csv_file = self._csv_cache.get(table_name)

        if csv_file is None or not csv_file.exists():
            logger.debug(f"No CSV file found for table {table_name}")
            return

        localize = self._is_asset_table(table_name) and bool(self._asset_map)

        with csv_file.open(newline="") as f:
            for row in csv.DictReader(f):
                yield self._localize_asset_row(row) if localize else row

    def has_table(self, table: DerivaTable | str) -> bool:
        """Check if CSV exists for table.

        Args:
            table: Table object or name.

        Returns:
            True if CSV file exists for this table.
        """
        return self._get_table_name(table) in self._csv_cache

    def list_available_tables(self) -> list[str]:
        """List all CSV files in data directory.

        Returns:
            Sorted list of table names (without .csv extension).
        """
        return sorted(self._csv_cache.keys())

    def get_row_count(self, table: DerivaTable | str) -> int:
        """Get the number of rows in a table's CSV file.

        Args:
            table: Table object or name.

        Returns:
            Number of data rows (excluding header); 0 for missing or
            empty files.
        """
        table_name = self._get_table_name(table)
        csv_file = self._csv_cache.get(table_name)

        if csv_file is None or not csv_file.exists():
            return 0

        with csv_file.open(newline="") as f:
            # Count lines minus header; clamp so an empty file is 0, not -1.
            return max(0, sum(1 for _ in f) - 1)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
class CatalogDataSource:
    """DataSource implementation for remote Deriva catalog.

    Fetches data via ERMrest API / datapath with RID-based pagination.

    Example:
        catalog = server.connect_ermrest(catalog_id)
        source = CatalogDataSource(catalog, schemas=['domain', 'deriva-ml'])

        # List available tables
        print(source.list_available_tables())

        # Get data for a table
        for row in source.get_table_data("Image"):
            print(row["Filename"])
    """

    def __init__(
        self,
        catalog: ErmrestCatalog,
        schemas: list[str],
        batch_size: int = 1000,
    ):
        """Initialize from catalog connection.

        Args:
            catalog: ERMrest catalog connection.
            schemas: Schemas to fetch data from.
            batch_size: Number of rows per API request.
        """
        self.catalog = catalog
        self.schemas = schemas
        self.batch_size = batch_size
        self._pb = catalog.getPathBuilder()
        self._model = catalog.getCatalogModel()

    def _table_exists(self, schema_name: str, table_name: str) -> bool:
        """Return True when schema_name.table_name exists in the catalog model."""
        if schema_name in self._model.schemas:
            return table_name in self._model.schemas[schema_name].tables
        return False

    def _get_table_info(self, table: DerivaTable | str) -> tuple[str, str] | None:
        """Get schema and table name for a table.

        Args:
            table: Table object, "schema.table" string, or bare table name.

        Returns:
            Tuple of (schema_name, table_name) or None if not found in the
            configured schemas.
        """
        if not isinstance(table, str):
            # Table-like object (e.g. ermrest_model.Table) carries its schema.
            return table.schema.name, table.name

        # Handle schema.table format. Split only once so table names that
        # themselves contain dots stay intact, and verify the table really
        # exists so has_table() keeps its contract.
        if "." in table:
            schema_name, table_name = table.split(".", 1)
            if schema_name in self.schemas and self._table_exists(schema_name, table_name):
                return schema_name, table_name
            return None

        # Bare table name: search the configured schemas in order.
        for schema_name in self.schemas:
            if self._table_exists(schema_name, table):
                return schema_name, table

        return None

    def get_table_data(
        self,
        table: DerivaTable | str,
    ) -> Iterator[dict[str, Any]]:
        """Fetch table data via ERMrest API.

        Uses RID-ordered pagination so arbitrarily large tables can be
        streamed in batches of ``batch_size`` rows.

        Args:
            table: Table object or name.

        Yields:
            Dictionary per row with column names as keys.
        """
        table_info = self._get_table_info(table)
        if table_info is None:
            logger.warning(f"Table {table} not found in schemas {self.schemas}")
            return

        schema_name, table_name = table_info

        # Datapath handle for the table.
        path = self._pb.schemas[schema_name].tables[table_name]

        # Paginated fetch: each batch is sorted by RID and filtered to RIDs
        # strictly greater than the last one already yielded.
        last_rid = None
        while True:
            query = path.entities()
            if last_rid is not None:
                query = query.filter(path.RID > last_rid)

            try:
                entities = list(query.sort(path.RID).fetch(limit=self.batch_size))
            except Exception as e:
                # Best-effort boundary: log and stop streaming this table.
                logger.error(f"Error fetching from {schema_name}.{table_name}: {e}")
                break

            if not entities:
                break

            for entity in entities:
                yield dict(entity)

            # Track last RID for pagination.
            last_rid = entities[-1]["RID"]

            # A short batch means the table is drained.
            if len(entities) < self.batch_size:
                break

    def has_table(self, table: DerivaTable | str) -> bool:
        """Check if table exists in catalog.

        Args:
            table: Table object or name.

        Returns:
            True if table exists in configured schemas.
        """
        return self._get_table_info(table) is not None

    def list_available_tables(self) -> list[str]:
        """List all tables in configured schemas.

        Returns:
            Sorted list of fully-qualified table names (schema.table).
        """
        return sorted(
            f"{schema_name}.{table_name}"
            for schema_name in self.schemas
            if schema_name in self._model.schemas
            for table_name in self._model.schemas[schema_name].tables
        )

    def get_row_count(self, table: DerivaTable | str) -> int:
        """Get the number of rows in a table.

        Args:
            table: Table object or name.

        Returns:
            Number of rows, or 0 when the table is missing or counting fails.
        """
        table_info = self._get_table_info(table)
        if table_info is None:
            return 0

        schema_name, table_name = table_info
        path = self._pb.schemas[schema_name].tables[table_name]

        try:
            # Server-side count aggregate avoids transferring row data.
            result = path.aggregates(path.RID.cnt.alias("count")).fetch()
            return result[0]["count"] if result else 0
        except Exception as e:
            # Best-effort boundary: log and report zero.
            logger.error(f"Error counting {schema_name}.{table_name}: {e}")
            return 0
|