deriva-ml 1.17.14__py3-none-any.whl → 1.17.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. deriva_ml/__init__.py +2 -2
  2. deriva_ml/asset/asset.py +0 -4
  3. deriva_ml/catalog/__init__.py +6 -0
  4. deriva_ml/catalog/clone.py +1591 -38
  5. deriva_ml/catalog/localize.py +66 -29
  6. deriva_ml/core/base.py +12 -9
  7. deriva_ml/core/definitions.py +13 -12
  8. deriva_ml/core/ermrest.py +11 -12
  9. deriva_ml/core/mixins/annotation.py +2 -2
  10. deriva_ml/core/mixins/asset.py +3 -3
  11. deriva_ml/core/mixins/dataset.py +3 -3
  12. deriva_ml/core/mixins/execution.py +1 -0
  13. deriva_ml/core/mixins/feature.py +2 -2
  14. deriva_ml/core/mixins/file.py +2 -2
  15. deriva_ml/core/mixins/path_builder.py +2 -2
  16. deriva_ml/core/mixins/rid_resolution.py +2 -2
  17. deriva_ml/core/mixins/vocabulary.py +2 -2
  18. deriva_ml/core/mixins/workflow.py +3 -3
  19. deriva_ml/dataset/catalog_graph.py +3 -4
  20. deriva_ml/dataset/dataset.py +5 -3
  21. deriva_ml/dataset/dataset_bag.py +0 -2
  22. deriva_ml/dataset/upload.py +2 -2
  23. deriva_ml/demo_catalog.py +0 -1
  24. deriva_ml/execution/__init__.py +8 -8
  25. deriva_ml/execution/base_config.py +2 -2
  26. deriva_ml/execution/execution.py +5 -3
  27. deriva_ml/execution/execution_record.py +0 -1
  28. deriva_ml/execution/model_protocol.py +1 -1
  29. deriva_ml/execution/multirun_config.py +0 -1
  30. deriva_ml/execution/runner.py +3 -3
  31. deriva_ml/experiment/experiment.py +3 -3
  32. deriva_ml/feature.py +2 -2
  33. deriva_ml/interfaces.py +2 -2
  34. deriva_ml/model/__init__.py +45 -24
  35. deriva_ml/model/annotations.py +0 -1
  36. deriva_ml/model/catalog.py +3 -2
  37. deriva_ml/model/data_loader.py +330 -0
  38. deriva_ml/model/data_sources.py +439 -0
  39. deriva_ml/model/database.py +216 -32
  40. deriva_ml/model/fk_orderer.py +379 -0
  41. deriva_ml/model/handles.py +1 -1
  42. deriva_ml/model/schema_builder.py +816 -0
  43. deriva_ml/run_model.py +3 -3
  44. deriva_ml/schema/annotations.py +2 -1
  45. deriva_ml/schema/create_schema.py +1 -1
  46. deriva_ml/schema/validation.py +1 -1
  47. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/METADATA +1 -1
  48. deriva_ml-1.17.16.dist-info/RECORD +81 -0
  49. deriva_ml-1.17.14.dist-info/RECORD +0 -77
  50. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/WHEEL +0 -0
  51. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/entry_points.txt +0 -0
  52. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/licenses/LICENSE +0 -0
  53. {deriva_ml-1.17.14.dist-info → deriva_ml-1.17.16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,439 @@
1
+ """Data sources for populating SQLite databases.
2
+
3
+ This module provides the DataSource protocol and implementations for
4
+ reading data from various sources:
5
+
6
+ - BagDataSource: Reads from BDBag CSV files
7
+ - CatalogDataSource: Fetches from remote Deriva catalog via ERMrest API
8
+
9
+ These are used with DataLoader in Phase 2 of the two-phase pattern.
10
+
11
+ Example:
12
+ # From bag
13
+ source = BagDataSource(bag_path)
14
+ for row in source.get_table_data(table):
15
+ print(row)
16
+
17
+ # From catalog
18
+ source = CatalogDataSource(catalog, schemas=['domain', 'deriva-ml'])
19
+ for row in source.get_table_data(table):
20
+ print(row)
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import csv
26
+ import logging
27
+ from pathlib import Path
28
+ from typing import Any, Iterator, Protocol, runtime_checkable
29
+ from urllib.parse import urlparse
30
+
31
+ from deriva.core import ErmrestCatalog
32
+ from deriva.core.ermrest_model import Model
33
+ from deriva.core.ermrest_model import Table as DerivaTable
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ # Standard asset table columns
39
+ ASSET_COLUMNS = {"Filename", "URL", "Length", "MD5", "Description"}
40
+
41
+
42
@runtime_checkable
class DataSource(Protocol):
    """Protocol for data sources that can fill a database.

    Implementations provide data for populating SQLite tables from
    different sources (bags, remote catalogs, etc.).

    This is used with DataLoader in Phase 2 of the two-phase pattern.
    Known implementations in this module: BagDataSource (CSV files in a
    BDBag) and CatalogDataSource (remote ERMrest catalog).
    """

    def get_table_data(
        self,
        table: DerivaTable | str,
    ) -> Iterator[dict[str, Any]]:
        """Yield rows for a table as dictionaries.

        Args:
            table: Table object or name to get data for.

        Yields:
            Dictionary per row with column names as keys.
        """
        ...

    def has_table(self, table: DerivaTable | str) -> bool:
        """Check if this source has data for the table.

        Args:
            table: Table object or name to check.

        Returns:
            True if data is available for this table.
        """
        ...

    def list_available_tables(self) -> list[str]:
        """List tables with available data.

        Returns:
            List of table names (may include schema prefix).
        """
        ...
84
+
85
+
86
class BagDataSource:
    """DataSource implementation for BDBag directories.

    Reads data from CSV files in a bag's data/ directory and, when enabled,
    rewrites asset-table ``Filename`` values to the locally materialized
    files listed in the bag's fetch.txt.

    Example:
        source = BagDataSource(Path("/path/to/bag"))

        # List available tables
        print(source.list_available_tables())

        # Get data for a table
        for row in source.get_table_data("Image"):
            print(row["Filename"])
    """

    def __init__(
        self,
        bag_path: Path,
        model: Model | None = None,
        asset_localization: bool = True,
    ):
        """Initialize from a bag path.

        Args:
            bag_path: Path to BDBag directory.
            model: Optional ERMrest Model for schema info. If not provided,
                will try to load from bag's data/schema.json.
            asset_localization: Whether to localize asset URLs to local paths
                using fetch.txt mapping.
        """
        self.bag_path = Path(bag_path)
        self.data_path = self.bag_path / "data"

        # Load the model from the bag if the caller didn't supply one.
        # A missing model only disables asset-table detection; plain CSV
        # reading still works.
        if model is not None:
            self.model = model
        else:
            schema_file = self.data_path / "schema.json"
            if schema_file.exists():
                self.model = Model.fromfile("file-system", schema_file)
            else:
                self.model = None
                logger.warning(f"No schema.json found in {self.bag_path}")

        # URL-path -> local-file-path map used for asset localization.
        self._asset_map = self._build_asset_map() if asset_localization else {}

        # Cache of table name -> csv file path.
        self._csv_cache: dict[str, Path] = {}
        self._build_csv_cache()

    def _build_csv_cache(self) -> None:
        """Build cache mapping table names (CSV stems) to CSV file paths."""
        for csv_file in self.data_path.rglob("*.csv"):
            self._csv_cache[csv_file.stem] = csv_file

    def _build_asset_map(self) -> dict[str, str]:
        """Build a map from remote URLs to local file paths using fetch.txt.

        fetch.txt rows are tab-separated: URL, size, local path (relative to
        the bag root). Keys are the URL *path* component so lookups succeed
        regardless of scheme/host.

        Returns:
            Dictionary mapping URL paths to local file paths.
        """
        fetch_map: dict[str, str] = {}
        fetch_file = self.bag_path / "fetch.txt"

        if not fetch_file.exists():
            logger.debug(f"No fetch.txt in bag {self.bag_path.name}")
            return fetch_map

        try:
            with fetch_file.open(newline="\n") as f:
                for row in f:
                    # Rows in fetch.txt are tab-separated: URL, size, local_path
                    fields = row.split("\t")
                    if len(fields) >= 3:
                        local_file = fields[2].replace("\n", "")
                        local_path = f"{self.bag_path}/{local_file}"
                        fetch_map[urlparse(fields[0]).path] = local_path
        except Exception as e:
            # Best-effort: a malformed fetch.txt disables localization
            # rather than failing the entire load.
            logger.warning(f"Error reading fetch.txt: {e}")

        return fetch_map

    def _get_table_name(self, table: DerivaTable | str) -> str:
        """Extract a bare table name from a table object or string.

        Strings may use "schema.table" form; the schema prefix is dropped.
        """
        if isinstance(table, str):
            # Handle schema.table format
            return table.split(".")[-1] if "." in table else table
        return table.name

    def _is_asset_table(self, table_name: str) -> bool:
        """Check if a table is an asset table (has Filename, URL, etc. columns).

        Returns False when no model is available, since detection is then
        impossible.
        """
        if self.model is None:
            return False

        for schema in self.model.schemas.values():
            if table_name in schema.tables:
                columns = {c.name for c in schema.tables[table_name].columns}
                return ASSET_COLUMNS.issubset(columns)
        return False

    def _localize_asset_row(self, row: dict[str, Any]) -> dict[str, Any]:
        """Replace Filename with a local path in an asset table row.

        The asset map is keyed by URL *path* (see _build_asset_map), so both
        bare hatrac paths and fully qualified URLs in the CSV are matched.

        Args:
            row: Dictionary of column values.

        Returns:
            A copy with a localized Filename when the URL is known;
            otherwise the row unchanged.
        """
        url = row.get("URL")
        if url and "Filename" in row:
            # Try the raw value first, then the parsed URL path so absolute
            # URLs also match the path-keyed fetch.txt entries.
            key = url if url in self._asset_map else urlparse(url).path
            if key in self._asset_map:
                row = dict(row)  # Copy to avoid mutating original
                row["Filename"] = self._asset_map[key]
        return row

    def get_table_data(
        self,
        table: DerivaTable | str,
    ) -> Iterator[dict[str, Any]]:
        """Read table data from CSV file.

        Args:
            table: Table object or name.

        Yields:
            Dictionary per row with column names as keys; asset rows have
            Filename localized when a fetch.txt mapping is available.
        """
        table_name = self._get_table_name(table)
        csv_file = self._csv_cache.get(table_name)

        if csv_file is None or not csv_file.exists():
            logger.debug(f"No CSV file found for table {table_name}")
            return

        localize = self._is_asset_table(table_name) and bool(self._asset_map)

        with csv_file.open(newline="") as f:
            for row in csv.DictReader(f):
                yield self._localize_asset_row(row) if localize else row

    def has_table(self, table: DerivaTable | str) -> bool:
        """Check if CSV exists for table.

        Args:
            table: Table object or name.

        Returns:
            True if CSV file exists for this table.
        """
        return self._get_table_name(table) in self._csv_cache

    def list_available_tables(self) -> list[str]:
        """List all CSV files in data directory.

        Returns:
            Sorted list of table names (without .csv extension).
        """
        return sorted(self._csv_cache.keys())

    def get_row_count(self, table: DerivaTable | str) -> int:
        """Get the number of data rows in a table's CSV file.

        Uses the csv module so quoted fields containing newlines count as a
        single record, and returns 0 (not -1) for a missing or empty file.

        Args:
            table: Table object or name.

        Returns:
            Number of data rows (excluding header); 0 if no CSV exists.
        """
        csv_file = self._csv_cache.get(self._get_table_name(table))

        if csv_file is None or not csv_file.exists():
            return 0

        with csv_file.open(newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip the header row if present
            return sum(1 for _ in reader)
274
+
275
+
276
class CatalogDataSource:
    """DataSource implementation for remote Deriva catalog.

    Fetches data via ERMrest API / datapath, paginating by RID so large
    tables are streamed in fixed-size batches.

    Example:
        catalog = server.connect_ermrest(catalog_id)
        source = CatalogDataSource(catalog, schemas=['domain', 'deriva-ml'])

        # List available tables
        print(source.list_available_tables())

        # Get data for a table
        for row in source.get_table_data("Image"):
            print(row["Filename"])
    """

    def __init__(
        self,
        catalog: ErmrestCatalog,
        schemas: list[str],
        batch_size: int = 1000,
    ):
        """Initialize from catalog connection.

        Args:
            catalog: ERMrest catalog connection.
            schemas: Schemas to fetch data from.
            batch_size: Number of rows per API request.
        """
        self.catalog = catalog
        self.schemas = schemas
        self.batch_size = batch_size
        self._pb = catalog.getPathBuilder()
        self._model = catalog.getCatalogModel()

    def _get_table_info(self, table: DerivaTable | str) -> tuple[str, str] | None:
        """Resolve a table reference to (schema_name, table_name).

        Args:
            table: Table object, "schema.table" string, or bare table name.

        Returns:
            Tuple of (schema_name, table_name), or None if the table is not
            in one of the configured schemas.
        """
        if isinstance(table, str):
            # Handle schema.table format
            if "." in table:
                parts = table.split(".")
                schema_name, table_name = parts[0], parts[1]
                if schema_name in self.schemas:
                    return schema_name, table_name
                return None

            # Bare name: search configured schemas in order
            for schema_name in self.schemas:
                if schema_name in self._model.schemas:
                    if table in self._model.schemas[schema_name].tables:
                        return schema_name, table
            return None

        return table.schema.name, table.name

    def get_table_data(
        self,
        table: DerivaTable | str,
    ) -> Iterator[dict[str, Any]]:
        """Fetch table data via ERMrest API.

        Uses RID-keyset pagination to handle large tables efficiently.

        Args:
            table: Table object or name.

        Yields:
            Dictionary per row with column names as keys.
        """
        table_info = self._get_table_info(table)
        if table_info is None:
            logger.warning(f"Table {table} not found in schemas {self.schemas}")
            return

        schema_name, table_name = table_info

        # Build path
        path = self._pb.schemas[schema_name].tables[table_name]

        # Paginated fetch using RID ordering.
        last_rid = None
        while True:
            # The RID filter must be applied to the datapath *before*
            # calling .entities(): result sets returned by entities() do
            # not support .filter(), which previously raised inside the
            # except below and silently truncated pagination to one batch.
            query = path if last_rid is None else path.filter(path.RID > last_rid)

            # Fetch one batch ordered by RID.
            try:
                entities = list(query.entities().sort(path.RID).fetch(limit=self.batch_size))
            except Exception as e:
                # Best-effort boundary: log and stop yielding.
                logger.error(f"Error fetching from {schema_name}.{table_name}: {e}")
                break

            if not entities:
                break

            for entity in entities:
                yield dict(entity)

            # Track last RID for the next page's filter.
            last_rid = entities[-1]["RID"]

            if len(entities) < self.batch_size:
                break

    def has_table(self, table: DerivaTable | str) -> bool:
        """Check if table exists in catalog.

        Args:
            table: Table object or name.

        Returns:
            True if table exists in configured schemas.
        """
        return self._get_table_info(table) is not None

    def list_available_tables(self) -> list[str]:
        """List all tables in configured schemas.

        Returns:
            Sorted list of fully-qualified table names (schema.table).
        """
        tables = []
        for schema_name in self.schemas:
            if schema_name in self._model.schemas:
                for table_name in self._model.schemas[schema_name].tables.keys():
                    tables.append(f"{schema_name}.{table_name}")
        return sorted(tables)

    def get_row_count(self, table: DerivaTable | str) -> int:
        """Get the number of rows in a table.

        Args:
            table: Table object or name.

        Returns:
            Number of rows in the table; 0 if the table is unknown or the
            request fails.
        """
        table_info = self._get_table_info(table)
        if table_info is None:
            return 0

        schema_name, table_name = table_info
        path = self._pb.schemas[schema_name].tables[table_name]

        try:
            # Use count aggregate.
            # NOTE(review): path.RID.cnt assumes datapath columns expose
            # aggregate shortcuts; confirm against the deriva-py version in
            # use (the documented form is datapath.Cnt(path.RID)).
            result = path.aggregates(path.RID.cnt.alias("count")).fetch()
            return result[0]["count"] if result else 0
        except Exception as e:
            logger.error(f"Error counting {schema_name}.{table_name}: {e}")
            return 0