deriva-ml 1.17.10__py3-none-any.whl → 1.17.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. deriva_ml/__init__.py +69 -1
  2. deriva_ml/asset/__init__.py +17 -0
  3. deriva_ml/asset/asset.py +357 -0
  4. deriva_ml/asset/aux_classes.py +100 -0
  5. deriva_ml/bump_version.py +254 -11
  6. deriva_ml/catalog/__init__.py +31 -0
  7. deriva_ml/catalog/clone.py +1939 -0
  8. deriva_ml/catalog/localize.py +426 -0
  9. deriva_ml/core/__init__.py +29 -0
  10. deriva_ml/core/base.py +845 -1067
  11. deriva_ml/core/config.py +169 -21
  12. deriva_ml/core/constants.py +120 -19
  13. deriva_ml/core/definitions.py +123 -13
  14. deriva_ml/core/enums.py +47 -73
  15. deriva_ml/core/ermrest.py +226 -193
  16. deriva_ml/core/exceptions.py +297 -14
  17. deriva_ml/core/filespec.py +99 -28
  18. deriva_ml/core/logging_config.py +225 -0
  19. deriva_ml/core/mixins/__init__.py +42 -0
  20. deriva_ml/core/mixins/annotation.py +915 -0
  21. deriva_ml/core/mixins/asset.py +384 -0
  22. deriva_ml/core/mixins/dataset.py +237 -0
  23. deriva_ml/core/mixins/execution.py +408 -0
  24. deriva_ml/core/mixins/feature.py +365 -0
  25. deriva_ml/core/mixins/file.py +263 -0
  26. deriva_ml/core/mixins/path_builder.py +145 -0
  27. deriva_ml/core/mixins/rid_resolution.py +204 -0
  28. deriva_ml/core/mixins/vocabulary.py +400 -0
  29. deriva_ml/core/mixins/workflow.py +322 -0
  30. deriva_ml/core/validation.py +389 -0
  31. deriva_ml/dataset/__init__.py +2 -1
  32. deriva_ml/dataset/aux_classes.py +20 -4
  33. deriva_ml/dataset/catalog_graph.py +575 -0
  34. deriva_ml/dataset/dataset.py +1242 -1008
  35. deriva_ml/dataset/dataset_bag.py +1311 -182
  36. deriva_ml/dataset/history.py +27 -14
  37. deriva_ml/dataset/upload.py +225 -38
  38. deriva_ml/demo_catalog.py +126 -110
  39. deriva_ml/execution/__init__.py +46 -2
  40. deriva_ml/execution/base_config.py +639 -0
  41. deriva_ml/execution/execution.py +543 -242
  42. deriva_ml/execution/execution_configuration.py +26 -11
  43. deriva_ml/execution/execution_record.py +592 -0
  44. deriva_ml/execution/find_caller.py +298 -0
  45. deriva_ml/execution/model_protocol.py +175 -0
  46. deriva_ml/execution/multirun_config.py +153 -0
  47. deriva_ml/execution/runner.py +595 -0
  48. deriva_ml/execution/workflow.py +223 -34
  49. deriva_ml/experiment/__init__.py +8 -0
  50. deriva_ml/experiment/experiment.py +411 -0
  51. deriva_ml/feature.py +6 -1
  52. deriva_ml/install_kernel.py +143 -6
  53. deriva_ml/interfaces.py +862 -0
  54. deriva_ml/model/__init__.py +99 -0
  55. deriva_ml/model/annotations.py +1278 -0
  56. deriva_ml/model/catalog.py +286 -60
  57. deriva_ml/model/database.py +144 -649
  58. deriva_ml/model/deriva_ml_database.py +308 -0
  59. deriva_ml/model/handles.py +14 -0
  60. deriva_ml/run_model.py +319 -0
  61. deriva_ml/run_notebook.py +507 -38
  62. deriva_ml/schema/__init__.py +18 -2
  63. deriva_ml/schema/annotations.py +62 -33
  64. deriva_ml/schema/create_schema.py +169 -69
  65. deriva_ml/schema/validation.py +601 -0
  66. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/METADATA +4 -4
  67. deriva_ml-1.17.12.dist-info/RECORD +77 -0
  68. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/WHEEL +1 -1
  69. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/entry_points.txt +1 -0
  70. deriva_ml/protocols/dataset.py +0 -19
  71. deriva_ml/test.py +0 -94
  72. deriva_ml-1.17.10.dist-info/RECORD +0 -45
  73. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/licenses/LICENSE +0 -0
  74. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,426 @@
1
+ """Localize remote hatrac assets to a local catalog server."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import tempfile
7
+ from dataclasses import dataclass, field
8
+ from pathlib import Path
9
+ from typing import TYPE_CHECKING
10
+ from urllib.parse import urlparse, quote as urlquote
11
+
12
+ from deriva.core import ErmrestCatalog, HatracStore, get_credential
13
+
14
+ if TYPE_CHECKING:
15
+ from deriva_ml import DerivaML
16
+
17
+ logger = logging.getLogger("deriva_ml")
18
+
19
+
20
@dataclass
class LocalizeResult:
    """Summary of a single asset localization run.

    Attributes:
        assets_processed: Count of assets successfully copied to the local server.
        assets_skipped: Count of assets left untouched (missing record, no URL,
            already local, or unparseable URL).
        assets_failed: Count of assets whose download or upload raised an error.
        errors: Human-readable messages describing each failure.
        localized_assets: (RID, old_url, new_url) triples, one per success.
    """

    # Counters default to zero; the mutable containers use default_factory so
    # separate instances never share state.
    assets_processed: int = 0
    assets_skipped: int = 0
    assets_failed: int = 0
    errors: list[str] = field(default_factory=list)
    localized_assets: list[tuple[str, str, str]] = field(default_factory=list)
37
+
38
+
39
def localize_assets(
    catalog: DerivaML | ErmrestCatalog,
    asset_table: str,
    asset_rids: list[str],
    schema_name: str | None = None,
    hatrac_namespace: str | None = None,
    chunk_size: int | None = None,
    dry_run: bool = False,
) -> LocalizeResult:
    """Localize remote hatrac assets to the local catalog server.

    Downloads assets from remote hatrac servers (determined from the URL in each
    asset record) and uploads them to the local hatrac server, updating the asset
    table URLs to point to the local copies.

    This is useful after cloning a catalog with asset_mode="refs" where the
    asset URLs still point to the source server. Use this function to make
    the assets fully local.

    The source hatrac server for each asset is determined automatically from
    the URL stored in the asset record.

    This function is optimized for bulk operations:
    - Fetches all asset records in a single query
    - Caches connections to remote hatrac servers
    - Batches catalog updates for efficiency
    - Supports chunked uploads for large files

    Args:
        catalog: A DerivaML instance or ErmrestCatalog connected to the catalog.
        asset_table: Name of the asset table containing the assets to localize.
        asset_rids: List of asset RIDs to localize. Each RID should be a record
            in the asset table.
        schema_name: Schema containing the asset table. If None, searches all schemas.
        hatrac_namespace: Optional hatrac namespace for uploaded files. If None,
            uses "/hatrac/{asset_table}"; uploaded objects are named from the
            asset's MD5 plus the original filename's suffix.
        chunk_size: Optional chunk size in bytes for large file uploads. If None,
            uses default chunking behavior.
        dry_run: If True, only report what would be done without making changes.

    Returns:
        LocalizeResult with counts and details of the operation.

    Raises:
        ValueError: If asset_table is not found.

    Examples:
        Localize specific assets using DerivaML:
        >>> from deriva_ml import DerivaML
        >>> ml = DerivaML("localhost", "42")
        >>> result = localize_assets(
        ...     ml,
        ...     asset_table="Image",
        ...     asset_rids=["1-ABC", "2-DEF", "3-GHI"],
        ... )
        >>> print(f"Localized {result.assets_processed} assets")
    """
    result = LocalizeResult()

    # Resolve the catalog handle, its host name, and the credential to use for
    # the local hatrac server.
    ermrest_catalog, hostname, credential = _get_catalog_info(catalog)

    # Create pathbuilder for datapath queries.
    pb = ermrest_catalog.getPathBuilder()

    # Find the asset table (searches all schemas when schema_name is None).
    table_path, found_schema = _find_asset_table_path(pb, asset_table, schema_name)
    if table_path is None:
        raise ValueError(f"Asset table '{asset_table}' not found in catalog")

    # Set up local hatrac.
    local_hatrac = HatracStore("https", hostname, credentials=credential)

    # Determine hatrac namespace.
    if hatrac_namespace is None:
        hatrac_namespace = f"/hatrac/{asset_table}"

    # One bulk query rather than one query per RID.
    logger.info(f"Fetching {len(asset_rids)} asset records...")
    all_records = _fetch_asset_records(table_path, asset_rids)
    records_by_rid = {r["RID"]: r for r in all_records}

    # Classify each requested RID: skip missing records, records without a URL,
    # assets that are already local, and URLs we cannot parse.
    assets_to_localize = []
    for rid in asset_rids:
        record = records_by_rid.get(rid)
        if record is None:
            logger.warning(f"Asset {rid} not found")
            result.assets_skipped += 1
            continue

        current_url = record.get("URL")
        if not current_url:
            logger.warning(f"Asset {rid} has no URL, skipping")
            result.assets_skipped += 1
            continue

        # The asset's source server is whatever host its URL names.
        parsed_url = urlparse(current_url)
        source_hostname = parsed_url.netloc

        if not source_hostname:
            # A relative URL has no host component, so it already refers to
            # the local server.
            logger.info(f"Asset {rid} has relative URL, already local")
            result.assets_skipped += 1
            continue

        if source_hostname == hostname:
            logger.info(f"Asset {rid} is already local, skipping")
            result.assets_skipped += 1
            continue

        source_path = _extract_hatrac_path(current_url)
        if not source_path:
            logger.warning(f"Could not extract hatrac path from URL: {current_url}")
            result.assets_skipped += 1
            continue

        assets_to_localize.append({
            "rid": rid,
            "record": record,
            "source_hostname": source_hostname,
            "source_path": source_path,
            "current_url": current_url,
        })

    if not assets_to_localize:
        logger.info("No assets need to be localized")
        return result

    logger.info(f"Localizing {len(assets_to_localize)} assets...")

    if dry_run:
        for asset_info in assets_to_localize:
            logger.info(
                f"[DRY RUN] Would download {asset_info['source_path']} from "
                f"{asset_info['source_hostname']} and upload to {hatrac_namespace}"
            )
            result.assets_processed += 1
        return result

    # Reuse one HatracStore per remote host across all assets.
    remote_hatrac_cache: dict[str, HatracStore] = {}

    # Ensure local namespace exists before the first upload.
    _ensure_hatrac_namespace(local_hatrac, hatrac_namespace)

    # URL updates are accumulated and written to the catalog in one batch.
    catalog_updates: list[dict] = []

    with tempfile.TemporaryDirectory() as tmpdir:
        scratch_dir = Path(tmpdir)
        total = len(assets_to_localize)

        for seq, asset_info in enumerate(assets_to_localize, start=1):
            rid = asset_info["rid"]
            record = asset_info["record"]
            source_hostname = asset_info["source_hostname"]
            source_path = asset_info["source_path"]
            current_url = asset_info["current_url"]
            filename = record.get("Filename")
            md5 = record.get("MD5")

            logger.info(
                f"[{seq}/{total}] Localizing {rid}: {filename or '<unnamed>'} from {source_hostname}"
            )

            try:
                # Get or create a connection to the remote hatrac server.
                if source_hostname not in remote_hatrac_cache:
                    source_cred = get_credential(source_hostname)
                    remote_hatrac_cache[source_hostname] = HatracStore(
                        "https", source_hostname, credentials=source_cred
                    )
                source_hatrac = remote_hatrac_cache[source_hostname]

                # Download into a per-asset scratch directory keyed by MD5 (or
                # RID) so identical filenames from different assets never collide.
                local_file = scratch_dir / (md5 or rid) / (filename or "asset")
                local_file.parent.mkdir(parents=True, exist_ok=True)

                source_hatrac.get_obj(path=source_path, destfilename=str(local_file))

                # Destination object name.  NOTE(review): the published diff was
                # redacted at this expression; "{md5}{suffix}" is reconstructed
                # from the original filename's suffix -- confirm against the
                # canonical source repository.
                if md5 and filename:
                    dest_path = f"{hatrac_namespace}/{md5}{Path(filename).suffix}"
                else:
                    dest_path = f"{hatrac_namespace}/{rid}"

                new_url = local_hatrac.put_loc(
                    dest_path,
                    str(local_file),
                    headers={"Content-Disposition": f"filename*=UTF-8''{urlquote(filename or 'asset')}"},
                    chunked=chunk_size is not None,
                    chunk_size=chunk_size or 0,
                )

                # Queue the catalog update for the batch write below.
                catalog_updates.append({"RID": rid, "URL": new_url})

                logger.info(f"Localized asset {rid}: {current_url} -> {new_url}")
                result.assets_processed += 1
                result.localized_assets.append((rid, current_url, new_url))

                # Free scratch space as we go; a large batch could otherwise
                # exhaust the temp filesystem.
                if local_file.exists():
                    local_file.unlink()

            except Exception as e:
                error_msg = f"Failed to localize asset {rid}: {e}"
                logger.error(error_msg)
                result.errors.append(error_msg)
                result.assets_failed += 1

    # Batch update the catalog records.
    if catalog_updates:
        logger.info(f"Updating {len(catalog_updates)} catalog records...")
        try:
            table_path.path.update(catalog_updates)
            logger.info("Catalog records updated successfully")
        except Exception as e:
            # A single bad row can sink the whole batch; retry row by row so
            # the remaining updates still land.
            logger.warning(f"Batch update failed ({e}), falling back to individual updates...")
            for update in catalog_updates:
                try:
                    table_path.path.filter(table_path.RID == update["RID"]).update([update])
                except Exception as e2:
                    error_msg = f"Failed to update catalog record {update['RID']}: {e2}"
                    logger.error(error_msg)
                    result.errors.append(error_msg)

    return result
278
+
279
+
280
+ def _get_catalog_info(
281
+ catalog: DerivaML | ErmrestCatalog,
282
+ ) -> tuple[ErmrestCatalog, str, dict | None]:
283
+ """Extract catalog, hostname, and credential from a DerivaML or ErmrestCatalog.
284
+
285
+ Args:
286
+ catalog: DerivaML instance or ErmrestCatalog.
287
+
288
+ Returns:
289
+ Tuple of (ErmrestCatalog, hostname, credential).
290
+ """
291
+ # Check if it's a DerivaML instance
292
+ if hasattr(catalog, "catalog") and hasattr(catalog, "host_name"):
293
+ # It's a DerivaML instance
294
+ hostname = catalog.host_name
295
+ ermrest_catalog = catalog.catalog
296
+ credential = getattr(catalog, "credential", None) or get_credential(hostname)
297
+ return (ermrest_catalog, hostname, credential)
298
+
299
+ # It's an ErmrestCatalog
300
+ ermrest_catalog = catalog
301
+ # Extract hostname from the catalog's server_uri
302
+ server_uri = ermrest_catalog.get_server_uri()
303
+ parsed = urlparse(server_uri)
304
+ hostname = parsed.netloc
305
+
306
+ credential = get_credential(hostname) if hostname else None
307
+ return (ermrest_catalog, hostname, credential)
308
+
309
+
310
+ def _find_asset_table_path(
311
+ pb,
312
+ table_name: str,
313
+ schema_name: str | None,
314
+ ) -> tuple | None:
315
+ """Find an asset table using pathbuilder.
316
+
317
+ Args:
318
+ pb: PathBuilder instance.
319
+ table_name: Name of the table to find.
320
+ schema_name: Optional schema name. If None, searches all schemas.
321
+
322
+ Returns:
323
+ Tuple of (table_path, schema_name) if found, (None, None) otherwise.
324
+ """
325
+ if schema_name:
326
+ try:
327
+ table_path = pb.schemas[schema_name].tables[table_name]
328
+ return (table_path, schema_name)
329
+ except KeyError:
330
+ return (None, None)
331
+
332
+ # Search all schemas
333
+ for sname in pb.schemas:
334
+ try:
335
+ table_path = pb.schemas[sname].tables[table_name]
336
+ return (table_path, sname)
337
+ except KeyError:
338
+ continue
339
+
340
+ return (None, None)
341
+
342
+
343
def _fetch_asset_records(table_path, rids: list[str]) -> list[dict]:
    """Fetch the asset records for *rids* in as few queries as possible.

    Args:
        table_path: Datapath table object for the asset table.
        rids: RIDs to fetch.

    Returns:
        List of record dictionaries; RIDs that do not exist are simply absent.
    """
    if not rids:
        return []

    # Datapath has no "in" operator, so OR together one equality test per RID.
    from deriva.core.datapath import DataPathException

    try:
        combined = None
        for rid in rids:
            clause = table_path.RID == rid
            combined = clause if combined is None else combined | clause

        query = table_path.path
        if combined is not None:
            query = query.filter(combined)
        return list(query.entities().fetch())
    except DataPathException as e:
        # Some servers reject very large OR filters; retry one RID at a time
        # and silently drop any RID that still cannot be fetched.
        logger.warning(f"Bulk fetch failed: {e}, falling back to individual fetches")
        found: list[dict] = []
        for rid in rids:
            try:
                found.extend(
                    list(table_path.path.filter(table_path.RID == rid).entities().fetch())
                )
            except Exception:
                pass
        return found
389
+
390
+
391
+ def _extract_hatrac_path(url: str) -> str | None:
392
+ """Extract the hatrac path from a full URL.
393
+
394
+ Args:
395
+ url: Full URL like "https://host/hatrac/namespace/file"
396
+
397
+ Returns:
398
+ Hatrac path like "/hatrac/namespace/file" or None if not a hatrac URL.
399
+ """
400
+ parsed = urlparse(url)
401
+ path = parsed.path
402
+
403
+ if "/hatrac/" in path:
404
+ # Find the /hatrac/ part and return from there
405
+ idx = path.find("/hatrac/")
406
+ return path[idx:]
407
+
408
+ if path.startswith("/hatrac/"):
409
+ return path
410
+
411
+ return None
412
+
413
+
414
+ def _ensure_hatrac_namespace(hatrac: HatracStore, namespace: str) -> None:
415
+ """Ensure a hatrac namespace exists, creating it if necessary.
416
+
417
+ Args:
418
+ hatrac: HatracStore instance.
419
+ namespace: Namespace path like "/hatrac/MyTable".
420
+ """
421
+ try:
422
+ # Try to create the namespace (will fail if exists, which is fine)
423
+ hatrac.create_namespace(namespace, parents=True)
424
+ except Exception:
425
+ # Namespace likely already exists
426
+ pass
@@ -1,3 +1,21 @@
1
+ """Core module for DerivaML.
2
+
3
+ This module provides the primary public interface to DerivaML functionality. It exports
4
+ the main DerivaML class along with configuration, definitions, and exceptions needed
5
+ for interacting with Deriva-based ML catalogs.
6
+
7
+ Key exports:
8
+ - DerivaML: Main class for catalog operations and ML workflow management.
9
+ - DerivaMLConfig: Configuration class for DerivaML instances.
10
+ - Exceptions: DerivaMLException and specialized exception types.
11
+ - Definitions: Type definitions, enums, and constants used throughout the package.
12
+
13
+ Example:
14
+ >>> from deriva_ml.core import DerivaML, DerivaMLConfig
15
+ >>> ml = DerivaML('deriva.example.org', 'my_catalog')
16
+ >>> datasets = ml.find_datasets()
17
+ """
18
+
1
19
  from deriva_ml.core.base import DerivaML
2
20
  from deriva_ml.core.config import DerivaMLConfig
3
21
  from deriva_ml.core.definitions import (
@@ -15,6 +33,8 @@ from deriva_ml.core.definitions import (
15
33
  UploadState,
16
34
  )
17
35
  from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm, DerivaMLTableTypeError
36
+ from deriva_ml.core.logging_config import LoggerMixin, configure_logging, get_logger, is_hydra_initialized
37
+ from deriva_ml.core.validation import DERIVA_ML_CONFIG, STRICT_VALIDATION_CONFIG, VALIDATION_CONFIG
18
38
 
19
39
  __all__ = [
20
40
  "DerivaML",
@@ -36,4 +56,13 @@ __all__ = [
36
56
  "MLVocab",
37
57
  "TableDefinition",
38
58
  "UploadState",
59
+ # Validation
60
+ "VALIDATION_CONFIG",
61
+ "DERIVA_ML_CONFIG",
62
+ "STRICT_VALIDATION_CONFIG",
63
+ # Logging
64
+ "get_logger",
65
+ "configure_logging",
66
+ "is_hydra_initialized",
67
+ "LoggerMixin",
39
68
  ]