deriva-ml 1.17.10__py3-none-any.whl → 1.17.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deriva_ml/__init__.py +69 -1
- deriva_ml/asset/__init__.py +17 -0
- deriva_ml/asset/asset.py +357 -0
- deriva_ml/asset/aux_classes.py +100 -0
- deriva_ml/bump_version.py +254 -11
- deriva_ml/catalog/__init__.py +31 -0
- deriva_ml/catalog/clone.py +1939 -0
- deriva_ml/catalog/localize.py +426 -0
- deriva_ml/core/__init__.py +29 -0
- deriva_ml/core/base.py +845 -1067
- deriva_ml/core/config.py +169 -21
- deriva_ml/core/constants.py +120 -19
- deriva_ml/core/definitions.py +123 -13
- deriva_ml/core/enums.py +47 -73
- deriva_ml/core/ermrest.py +226 -193
- deriva_ml/core/exceptions.py +297 -14
- deriva_ml/core/filespec.py +99 -28
- deriva_ml/core/logging_config.py +225 -0
- deriva_ml/core/mixins/__init__.py +42 -0
- deriva_ml/core/mixins/annotation.py +915 -0
- deriva_ml/core/mixins/asset.py +384 -0
- deriva_ml/core/mixins/dataset.py +237 -0
- deriva_ml/core/mixins/execution.py +408 -0
- deriva_ml/core/mixins/feature.py +365 -0
- deriva_ml/core/mixins/file.py +263 -0
- deriva_ml/core/mixins/path_builder.py +145 -0
- deriva_ml/core/mixins/rid_resolution.py +204 -0
- deriva_ml/core/mixins/vocabulary.py +400 -0
- deriva_ml/core/mixins/workflow.py +322 -0
- deriva_ml/core/validation.py +389 -0
- deriva_ml/dataset/__init__.py +2 -1
- deriva_ml/dataset/aux_classes.py +20 -4
- deriva_ml/dataset/catalog_graph.py +575 -0
- deriva_ml/dataset/dataset.py +1242 -1008
- deriva_ml/dataset/dataset_bag.py +1311 -182
- deriva_ml/dataset/history.py +27 -14
- deriva_ml/dataset/upload.py +225 -38
- deriva_ml/demo_catalog.py +126 -110
- deriva_ml/execution/__init__.py +46 -2
- deriva_ml/execution/base_config.py +639 -0
- deriva_ml/execution/execution.py +543 -242
- deriva_ml/execution/execution_configuration.py +26 -11
- deriva_ml/execution/execution_record.py +592 -0
- deriva_ml/execution/find_caller.py +298 -0
- deriva_ml/execution/model_protocol.py +175 -0
- deriva_ml/execution/multirun_config.py +153 -0
- deriva_ml/execution/runner.py +595 -0
- deriva_ml/execution/workflow.py +223 -34
- deriva_ml/experiment/__init__.py +8 -0
- deriva_ml/experiment/experiment.py +411 -0
- deriva_ml/feature.py +6 -1
- deriva_ml/install_kernel.py +143 -6
- deriva_ml/interfaces.py +862 -0
- deriva_ml/model/__init__.py +99 -0
- deriva_ml/model/annotations.py +1278 -0
- deriva_ml/model/catalog.py +286 -60
- deriva_ml/model/database.py +144 -649
- deriva_ml/model/deriva_ml_database.py +308 -0
- deriva_ml/model/handles.py +14 -0
- deriva_ml/run_model.py +319 -0
- deriva_ml/run_notebook.py +507 -38
- deriva_ml/schema/__init__.py +18 -2
- deriva_ml/schema/annotations.py +62 -33
- deriva_ml/schema/create_schema.py +169 -69
- deriva_ml/schema/validation.py +601 -0
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/METADATA +4 -4
- deriva_ml-1.17.12.dist-info/RECORD +77 -0
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/WHEEL +1 -1
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/entry_points.txt +1 -0
- deriva_ml/protocols/dataset.py +0 -19
- deriva_ml/test.py +0 -94
- deriva_ml-1.17.10.dist-info/RECORD +0 -45
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.12.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,426 @@
|
|
|
1
|
+
"""Localize remote hatrac assets to a local catalog server."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import tempfile
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
from urllib.parse import urlparse, quote as urlquote
|
|
11
|
+
|
|
12
|
+
from deriva.core import ErmrestCatalog, HatracStore, get_credential
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from deriva_ml import DerivaML
|
|
16
|
+
|
|
17
|
+
# Module logger. Uses the shared package-level "deriva_ml" logger name rather than
# __name__, so this module's messages follow the package's logging configuration.
logger = logging.getLogger("deriva_ml")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class LocalizeResult:
    """Result of an asset localization operation.

    Attributes:
        assets_processed: Number of assets successfully localized (in a dry run,
            the number of assets that would be localized).
        assets_skipped: Number of requested assets that were not attempted
            because their record was not found, had no URL, had a relative URL,
            already pointed at the local host, or had no recognizable hatrac path.
        assets_failed: Number of assets that failed to localize. Each failure is
            also described in ``errors``.
        errors: Error messages for failed assets and failed catalog record updates.
        localized_assets: List of (RID, old_url, new_url) tuples for successfully
            localized assets.
    """

    assets_processed: int = 0
    assets_skipped: int = 0
    assets_failed: int = 0
    # default_factory gives every result its own list instead of a shared one.
    errors: list[str] = field(default_factory=list)
    localized_assets: list[tuple[str, str, str]] = field(default_factory=list)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def localize_assets(
    catalog: DerivaML | ErmrestCatalog,
    asset_table: str,
    asset_rids: list[str],
    schema_name: str | None = None,
    hatrac_namespace: str | None = None,
    chunk_size: int | None = None,
    dry_run: bool = False,
) -> LocalizeResult:
    """Localize remote hatrac assets to the local catalog server.

    Downloads assets from remote hatrac servers (determined from the URL in each
    asset record) and uploads them to the local hatrac server, updating the asset
    table URLs to point to the local copies.

    This is useful after cloning a catalog with asset_mode="refs" where the
    asset URLs still point to the source server. Use this function to make
    the assets fully local.

    The source hatrac server for each asset is determined automatically from
    the URL stored in the asset record. Assets whose URL is relative or already
    names the local host are skipped. Duplicate RIDs in ``asset_rids`` are
    processed only once.

    This function is optimized for bulk operations:
    - Fetches asset records in a small number of batched queries
    - Caches connections to remote hatrac servers
    - Batches catalog updates for efficiency
    - Supports chunked uploads for large files

    Args:
        catalog: A DerivaML instance or ErmrestCatalog connected to the catalog.
        asset_table: Name of the asset table containing the assets to localize.
        asset_rids: List of asset RIDs to localize. Each RID should be a record
            in the asset table.
        schema_name: Schema containing the asset table. If None, searches all schemas.
        hatrac_namespace: Optional hatrac namespace for uploaded files. If None,
            uses "/hatrac/{asset_table}" (URL-quoted). Objects are stored as
            "{namespace}/{md5}.{filename}" (filename URL-quoted) when the record
            has both MD5 and Filename, otherwise as "{namespace}/{rid}".
        chunk_size: Optional chunk size in bytes for large file uploads. If None
            (or 0), each file is uploaded in a single non-chunked request.
        dry_run: If True, only report what would be done without making changes.

    Returns:
        LocalizeResult with counts and details of the operation. An asset whose
        file was copied but whose catalog record could not be updated is counted
        in ``assets_failed`` rather than ``assets_processed``.

    Raises:
        ValueError: If asset_table is not found.

    Examples:
        Localize specific assets using DerivaML:
        >>> from deriva_ml import DerivaML
        >>> ml = DerivaML("localhost", "42")
        >>> result = localize_assets(
        ...     ml,
        ...     asset_table="Image",
        ...     asset_rids=["1-ABC", "2-DEF", "3-GHI"],
        ... )
        >>> print(f"Localized {result.assets_processed} assets")

        Localize using ErmrestCatalog:
        >>> from deriva.core import DerivaServer
        >>> server = DerivaServer("https", "localhost")
        >>> catalog = server.connect_ermrest("42")
        >>> result = localize_assets(
        ...     catalog,
        ...     asset_table="Model_Weights",
        ...     asset_rids=["4-JKL"],
        ...     dry_run=True,
        ... )
    """
    result = LocalizeResult()

    # Resolve the catalog connection, local host name, and credential.
    ermrest_catalog, hostname, credential = _get_catalog_info(catalog)
    pb = ermrest_catalog.getPathBuilder()

    table_path, _ = _find_asset_table_path(pb, asset_table, schema_name)
    if table_path is None:
        raise ValueError(f"Asset table '{asset_table}' not found in catalog")

    local_hatrac = HatracStore("https", hostname, credentials=credential)

    if hatrac_namespace is None:
        # Quote so unusual table names still form a valid hatrac URL path.
        hatrac_namespace = f"/hatrac/{urlquote(asset_table)}"

    # Preserve the caller's order but never transfer the same asset twice.
    unique_rids = list(dict.fromkeys(asset_rids))

    logger.info(f"Fetching {len(unique_rids)} asset records...")
    records_by_rid = {r["RID"]: r for r in _fetch_asset_records(table_path, unique_rids)}

    assets_to_localize = _select_assets_to_localize(unique_rids, records_by_rid, hostname, result)
    if not assets_to_localize:
        logger.info("No assets need to be localized")
        return result

    logger.info(f"Localizing {len(assets_to_localize)} assets...")

    if dry_run:
        for asset_info in assets_to_localize:
            logger.info(
                f"[DRY RUN] Would download {asset_info['source_path']} from "
                f"{asset_info['source_hostname']} and upload to {hatrac_namespace}"
            )
            result.assets_processed += 1
        return result

    # Remote hatrac connections keyed by host name, reused across assets.
    remote_hatrac_cache: dict[str, HatracStore] = {}
    _ensure_hatrac_namespace(local_hatrac, hatrac_namespace)

    # New URLs are collected and written to the catalog in one batch at the end.
    catalog_updates: list[dict] = []
    total = len(assets_to_localize)
    with tempfile.TemporaryDirectory() as tmpdir:
        scratch_dir = Path(tmpdir)
        for i, asset_info in enumerate(assets_to_localize, start=1):
            rid = asset_info["rid"]
            current_url = asset_info["current_url"]
            filename = asset_info["record"].get("Filename")
            logger.info(f"[{i}/{total}] Localizing {rid}: {filename} from {asset_info['source_hostname']}")

            try:
                new_url = _transfer_asset(
                    asset_info, local_hatrac, remote_hatrac_cache, hatrac_namespace, scratch_dir, chunk_size
                )
            except Exception as e:
                # One bad asset must not abort the rest of the batch.
                error_msg = f"Failed to localize asset {rid}: {e}"
                logger.error(error_msg)
                result.errors.append(error_msg)
                result.assets_failed += 1
                continue

            catalog_updates.append({"RID": rid, "URL": new_url})
            logger.info(f"Localized asset {rid}: {current_url} -> {new_url}")
            result.assets_processed += 1
            result.localized_assets.append((rid, current_url, new_url))

    _apply_catalog_updates(table_path, catalog_updates, result)
    return result


def _select_assets_to_localize(
    rids: list[str],
    records_by_rid: dict[str, dict],
    local_hostname: str,
    result: LocalizeResult,
) -> list[dict]:
    """Return descriptors for assets whose URL names a remote hatrac host, counting skips in ``result``."""
    selected: list[dict] = []
    for rid in rids:
        record = records_by_rid.get(rid)
        if record is None:
            logger.warning(f"Asset {rid} not found")
            result.assets_skipped += 1
            continue

        current_url = record.get("URL")
        if not current_url:
            logger.warning(f"Asset {rid} has no URL, skipping")
            result.assets_skipped += 1
            continue

        source_hostname = urlparse(current_url).netloc
        if not source_hostname:
            logger.info(f"Asset {rid} has relative URL, already local")
            result.assets_skipped += 1
            continue

        if source_hostname == local_hostname:
            logger.info(f"Asset {rid} is already local, skipping")
            result.assets_skipped += 1
            continue

        source_path = _extract_hatrac_path(current_url)
        if not source_path:
            logger.warning(f"Could not extract hatrac path from URL: {current_url}")
            result.assets_skipped += 1
            continue

        selected.append({
            "rid": rid,
            "record": record,
            "source_hostname": source_hostname,
            "source_path": source_path,
            "current_url": current_url,
        })
    return selected


def _transfer_asset(
    asset_info: dict,
    local_hatrac: HatracStore,
    remote_hatrac_cache: dict[str, HatracStore],
    hatrac_namespace: str,
    scratch_dir: Path,
    chunk_size: int | None,
) -> str:
    """Copy one asset from its source hatrac into ``local_hatrac`` and return the value of ``put_loc``.

    The downloaded scratch file is always removed, even if the upload fails.
    """
    rid = asset_info["rid"]
    source_hostname = asset_info["source_hostname"]
    record = asset_info["record"]
    filename = record.get("Filename")
    md5 = record.get("MD5")

    source_hatrac = remote_hatrac_cache.get(source_hostname)
    if source_hatrac is None:
        source_hatrac = HatracStore("https", source_hostname, credentials=get_credential(source_hostname))
        remote_hatrac_cache[source_hostname] = source_hatrac

    # Filename comes from catalog data. Keep only its final component so values
    # such as "/etc/x" or "../x" cannot place the scratch file outside scratch_dir.
    scratch_name = Path(filename).name if filename else ""
    if scratch_name in ("", ".", ".."):
        scratch_name = "asset"
    local_file = scratch_dir / rid / scratch_name
    local_file.parent.mkdir(parents=True, exist_ok=True)

    # The filename becomes part of a URL path, so characters like '#', '?' or
    # spaces must be percent-encoded.
    if md5 and filename:
        dest_path = f"{hatrac_namespace}/{md5}.{urlquote(filename)}"
    else:
        dest_path = f"{hatrac_namespace}/{rid}"

    # chunk_size of None or 0 means a single non-chunked upload.
    upload_options = {"chunked": True, "chunk_size": chunk_size} if chunk_size else {"chunked": False}

    try:
        source_hatrac.get_obj(path=asset_info["source_path"], destfilename=str(local_file))
        return local_hatrac.put_loc(
            dest_path,
            str(local_file),
            headers={"Content-Disposition": f"filename*=UTF-8''{urlquote(filename or 'asset')}"},
            **upload_options,
        )
    finally:
        local_file.unlink(missing_ok=True)


def _apply_catalog_updates(table_path, catalog_updates: list[dict], result: LocalizeResult) -> None:
    """Write new asset URLs to the catalog, moving assets whose update fails from processed to failed."""
    if not catalog_updates:
        return

    logger.info(f"Updating {len(catalog_updates)} catalog records...")
    try:
        # Entity updates live on the table wrapper (correlated on RID by default),
        # not on the DataPath returned by ``table_path.path``.
        table_path.update(catalog_updates)
    except Exception as e:
        logger.warning(f"Batch update failed ({e}), falling back to individual updates...")
    else:
        logger.info("Catalog records updated successfully")
        return

    failed_rids: set[str] = set()
    for update in catalog_updates:
        try:
            table_path.update([update])
        except Exception as e2:
            error_msg = f"Failed to update catalog record {update['RID']}: {e2}"
            logger.error(error_msg)
            result.errors.append(error_msg)
            failed_rids.add(update["RID"])

    if failed_rids:
        # The file was copied but the record still points at the remote URL,
        # so these assets are not actually localized.
        result.assets_processed -= len(failed_rids)
        result.assets_failed += len(failed_rids)
        result.localized_assets = [entry for entry in result.localized_assets if entry[0] not in failed_rids]
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _get_catalog_info(
|
|
281
|
+
catalog: DerivaML | ErmrestCatalog,
|
|
282
|
+
) -> tuple[ErmrestCatalog, str, dict | None]:
|
|
283
|
+
"""Extract catalog, hostname, and credential from a DerivaML or ErmrestCatalog.
|
|
284
|
+
|
|
285
|
+
Args:
|
|
286
|
+
catalog: DerivaML instance or ErmrestCatalog.
|
|
287
|
+
|
|
288
|
+
Returns:
|
|
289
|
+
Tuple of (ErmrestCatalog, hostname, credential).
|
|
290
|
+
"""
|
|
291
|
+
# Check if it's a DerivaML instance
|
|
292
|
+
if hasattr(catalog, "catalog") and hasattr(catalog, "host_name"):
|
|
293
|
+
# It's a DerivaML instance
|
|
294
|
+
hostname = catalog.host_name
|
|
295
|
+
ermrest_catalog = catalog.catalog
|
|
296
|
+
credential = getattr(catalog, "credential", None) or get_credential(hostname)
|
|
297
|
+
return (ermrest_catalog, hostname, credential)
|
|
298
|
+
|
|
299
|
+
# It's an ErmrestCatalog
|
|
300
|
+
ermrest_catalog = catalog
|
|
301
|
+
# Extract hostname from the catalog's server_uri
|
|
302
|
+
server_uri = ermrest_catalog.get_server_uri()
|
|
303
|
+
parsed = urlparse(server_uri)
|
|
304
|
+
hostname = parsed.netloc
|
|
305
|
+
|
|
306
|
+
credential = get_credential(hostname) if hostname else None
|
|
307
|
+
return (ermrest_catalog, hostname, credential)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _find_asset_table_path(
|
|
311
|
+
pb,
|
|
312
|
+
table_name: str,
|
|
313
|
+
schema_name: str | None,
|
|
314
|
+
) -> tuple | None:
|
|
315
|
+
"""Find an asset table using pathbuilder.
|
|
316
|
+
|
|
317
|
+
Args:
|
|
318
|
+
pb: PathBuilder instance.
|
|
319
|
+
table_name: Name of the table to find.
|
|
320
|
+
schema_name: Optional schema name. If None, searches all schemas.
|
|
321
|
+
|
|
322
|
+
Returns:
|
|
323
|
+
Tuple of (table_path, schema_name) if found, (None, None) otherwise.
|
|
324
|
+
"""
|
|
325
|
+
if schema_name:
|
|
326
|
+
try:
|
|
327
|
+
table_path = pb.schemas[schema_name].tables[table_name]
|
|
328
|
+
return (table_path, schema_name)
|
|
329
|
+
except KeyError:
|
|
330
|
+
return (None, None)
|
|
331
|
+
|
|
332
|
+
# Search all schemas
|
|
333
|
+
for sname in pb.schemas:
|
|
334
|
+
try:
|
|
335
|
+
table_path = pb.schemas[sname].tables[table_name]
|
|
336
|
+
return (table_path, sname)
|
|
337
|
+
except KeyError:
|
|
338
|
+
continue
|
|
339
|
+
|
|
340
|
+
return (None, None)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _fetch_asset_records(table_path, rids: list[str]) -> list[dict]:
|
|
344
|
+
"""Fetch multiple asset records in a single query.
|
|
345
|
+
|
|
346
|
+
Args:
|
|
347
|
+
table_path: Datapath table object.
|
|
348
|
+
rids: List of RIDs to fetch.
|
|
349
|
+
|
|
350
|
+
Returns:
|
|
351
|
+
List of record dictionaries.
|
|
352
|
+
"""
|
|
353
|
+
if not rids:
|
|
354
|
+
return []
|
|
355
|
+
|
|
356
|
+
# Use datapath to fetch all records with RIDs in the list
|
|
357
|
+
# Build a filter for RID in (rid1, rid2, ...)
|
|
358
|
+
from deriva.core.datapath import DataPathException
|
|
359
|
+
|
|
360
|
+
try:
|
|
361
|
+
# Fetch records where RID is in our list
|
|
362
|
+
# Use multiple OR conditions since datapath doesn't have an "in" operator
|
|
363
|
+
path = table_path.path
|
|
364
|
+
|
|
365
|
+
# Build filter: RID == rid1 OR RID == rid2 OR ...
|
|
366
|
+
filter_expr = None
|
|
367
|
+
for rid in rids:
|
|
368
|
+
condition = table_path.RID == rid
|
|
369
|
+
if filter_expr is None:
|
|
370
|
+
filter_expr = condition
|
|
371
|
+
else:
|
|
372
|
+
filter_expr = filter_expr | condition
|
|
373
|
+
|
|
374
|
+
if filter_expr is not None:
|
|
375
|
+
path = path.filter(filter_expr)
|
|
376
|
+
|
|
377
|
+
return list(path.entities().fetch())
|
|
378
|
+
except DataPathException as e:
|
|
379
|
+
logger.warning(f"Bulk fetch failed: {e}, falling back to individual fetches")
|
|
380
|
+
# Fallback: fetch records individually
|
|
381
|
+
records = []
|
|
382
|
+
for rid in rids:
|
|
383
|
+
try:
|
|
384
|
+
result = list(table_path.path.filter(table_path.RID == rid).entities().fetch())
|
|
385
|
+
records.extend(result)
|
|
386
|
+
except Exception:
|
|
387
|
+
pass
|
|
388
|
+
return records
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def _extract_hatrac_path(url: str) -> str | None:
|
|
392
|
+
"""Extract the hatrac path from a full URL.
|
|
393
|
+
|
|
394
|
+
Args:
|
|
395
|
+
url: Full URL like "https://host/hatrac/namespace/file"
|
|
396
|
+
|
|
397
|
+
Returns:
|
|
398
|
+
Hatrac path like "/hatrac/namespace/file" or None if not a hatrac URL.
|
|
399
|
+
"""
|
|
400
|
+
parsed = urlparse(url)
|
|
401
|
+
path = parsed.path
|
|
402
|
+
|
|
403
|
+
if "/hatrac/" in path:
|
|
404
|
+
# Find the /hatrac/ part and return from there
|
|
405
|
+
idx = path.find("/hatrac/")
|
|
406
|
+
return path[idx:]
|
|
407
|
+
|
|
408
|
+
if path.startswith("/hatrac/"):
|
|
409
|
+
return path
|
|
410
|
+
|
|
411
|
+
return None
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def _ensure_hatrac_namespace(hatrac: HatracStore, namespace: str) -> None:
|
|
415
|
+
"""Ensure a hatrac namespace exists, creating it if necessary.
|
|
416
|
+
|
|
417
|
+
Args:
|
|
418
|
+
hatrac: HatracStore instance.
|
|
419
|
+
namespace: Namespace path like "/hatrac/MyTable".
|
|
420
|
+
"""
|
|
421
|
+
try:
|
|
422
|
+
# Try to create the namespace (will fail if exists, which is fine)
|
|
423
|
+
hatrac.create_namespace(namespace, parents=True)
|
|
424
|
+
except Exception:
|
|
425
|
+
# Namespace likely already exists
|
|
426
|
+
pass
|
deriva_ml/core/__init__.py
CHANGED
|
@@ -1,3 +1,21 @@
|
|
|
1
|
+
"""Core module for DerivaML.
|
|
2
|
+
|
|
3
|
+
This module provides the primary public interface to DerivaML functionality. It exports
|
|
4
|
+
the main DerivaML class along with configuration, definitions, and exceptions needed
|
|
5
|
+
for interacting with Deriva-based ML catalogs.
|
|
6
|
+
|
|
7
|
+
Key exports:
|
|
8
|
+
- DerivaML: Main class for catalog operations and ML workflow management.
|
|
9
|
+
- DerivaMLConfig: Configuration class for DerivaML instances.
|
|
10
|
+
- Exceptions: DerivaMLException and specialized exception types.
|
|
11
|
+
- Definitions: Type definitions, enums, and constants used throughout the package.
|
|
12
|
+
|
|
13
|
+
Example:
|
|
14
|
+
>>> from deriva_ml.core import DerivaML, DerivaMLConfig
|
|
15
|
+
>>> ml = DerivaML('deriva.example.org', 'my_catalog')
|
|
16
|
+
>>> datasets = ml.find_datasets()
|
|
17
|
+
"""
|
|
18
|
+
|
|
1
19
|
from deriva_ml.core.base import DerivaML
|
|
2
20
|
from deriva_ml.core.config import DerivaMLConfig
|
|
3
21
|
from deriva_ml.core.definitions import (
|
|
@@ -15,6 +33,8 @@ from deriva_ml.core.definitions import (
|
|
|
15
33
|
UploadState,
|
|
16
34
|
)
|
|
17
35
|
from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm, DerivaMLTableTypeError
|
|
36
|
+
from deriva_ml.core.logging_config import LoggerMixin, configure_logging, get_logger, is_hydra_initialized
|
|
37
|
+
from deriva_ml.core.validation import DERIVA_ML_CONFIG, STRICT_VALIDATION_CONFIG, VALIDATION_CONFIG
|
|
18
38
|
|
|
19
39
|
__all__ = [
|
|
20
40
|
"DerivaML",
|
|
@@ -36,4 +56,13 @@ __all__ = [
|
|
|
36
56
|
"MLVocab",
|
|
37
57
|
"TableDefinition",
|
|
38
58
|
"UploadState",
|
|
59
|
+
# Validation
|
|
60
|
+
"VALIDATION_CONFIG",
|
|
61
|
+
"DERIVA_ML_CONFIG",
|
|
62
|
+
"STRICT_VALIDATION_CONFIG",
|
|
63
|
+
# Logging
|
|
64
|
+
"get_logger",
|
|
65
|
+
"configure_logging",
|
|
66
|
+
"is_hydra_initialized",
|
|
67
|
+
"LoggerMixin",
|
|
39
68
|
]
|