deriva-ml 1.17.10__py3-none-any.whl → 1.17.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deriva_ml/__init__.py +43 -1
- deriva_ml/asset/__init__.py +17 -0
- deriva_ml/asset/asset.py +357 -0
- deriva_ml/asset/aux_classes.py +100 -0
- deriva_ml/bump_version.py +254 -11
- deriva_ml/catalog/__init__.py +21 -0
- deriva_ml/catalog/clone.py +1199 -0
- deriva_ml/catalog/localize.py +426 -0
- deriva_ml/core/__init__.py +29 -0
- deriva_ml/core/base.py +817 -1067
- deriva_ml/core/config.py +169 -21
- deriva_ml/core/constants.py +120 -19
- deriva_ml/core/definitions.py +123 -13
- deriva_ml/core/enums.py +47 -73
- deriva_ml/core/ermrest.py +226 -193
- deriva_ml/core/exceptions.py +297 -14
- deriva_ml/core/filespec.py +99 -28
- deriva_ml/core/logging_config.py +225 -0
- deriva_ml/core/mixins/__init__.py +42 -0
- deriva_ml/core/mixins/annotation.py +915 -0
- deriva_ml/core/mixins/asset.py +384 -0
- deriva_ml/core/mixins/dataset.py +237 -0
- deriva_ml/core/mixins/execution.py +408 -0
- deriva_ml/core/mixins/feature.py +365 -0
- deriva_ml/core/mixins/file.py +263 -0
- deriva_ml/core/mixins/path_builder.py +145 -0
- deriva_ml/core/mixins/rid_resolution.py +204 -0
- deriva_ml/core/mixins/vocabulary.py +400 -0
- deriva_ml/core/mixins/workflow.py +322 -0
- deriva_ml/core/validation.py +389 -0
- deriva_ml/dataset/__init__.py +2 -1
- deriva_ml/dataset/aux_classes.py +20 -4
- deriva_ml/dataset/catalog_graph.py +575 -0
- deriva_ml/dataset/dataset.py +1242 -1008
- deriva_ml/dataset/dataset_bag.py +1311 -182
- deriva_ml/dataset/history.py +27 -14
- deriva_ml/dataset/upload.py +225 -38
- deriva_ml/demo_catalog.py +126 -110
- deriva_ml/execution/__init__.py +46 -2
- deriva_ml/execution/base_config.py +639 -0
- deriva_ml/execution/execution.py +543 -242
- deriva_ml/execution/execution_configuration.py +26 -11
- deriva_ml/execution/execution_record.py +592 -0
- deriva_ml/execution/find_caller.py +298 -0
- deriva_ml/execution/model_protocol.py +175 -0
- deriva_ml/execution/multirun_config.py +153 -0
- deriva_ml/execution/runner.py +595 -0
- deriva_ml/execution/workflow.py +223 -34
- deriva_ml/experiment/__init__.py +8 -0
- deriva_ml/experiment/experiment.py +411 -0
- deriva_ml/feature.py +6 -1
- deriva_ml/install_kernel.py +143 -6
- deriva_ml/interfaces.py +862 -0
- deriva_ml/model/__init__.py +99 -0
- deriva_ml/model/annotations.py +1278 -0
- deriva_ml/model/catalog.py +286 -60
- deriva_ml/model/database.py +144 -649
- deriva_ml/model/deriva_ml_database.py +308 -0
- deriva_ml/model/handles.py +14 -0
- deriva_ml/run_model.py +319 -0
- deriva_ml/run_notebook.py +507 -38
- deriva_ml/schema/__init__.py +18 -2
- deriva_ml/schema/annotations.py +62 -33
- deriva_ml/schema/create_schema.py +169 -69
- deriva_ml/schema/validation.py +601 -0
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/METADATA +4 -4
- deriva_ml-1.17.11.dist-info/RECORD +77 -0
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/WHEEL +1 -1
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/entry_points.txt +1 -0
- deriva_ml/protocols/dataset.py +0 -19
- deriva_ml/test.py +0 -94
- deriva_ml-1.17.10.dist-info/RECORD +0 -45
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/top_level.txt +0 -0
deriva_ml/model/deriva_ml_database.py
ADDED
@@ -0,0 +1,308 @@
"""Database-backed implementation of DerivaMLCatalog protocol.

This module provides a DerivaMLDatabase class that implements the DerivaMLCatalog
protocol using a DatabaseModel (SQLite) instead of a live catalog connection.
This allows code written against the DerivaMLCatalog protocol to work identically
with both live catalogs and downloaded bags.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator, Iterable

import pandas as pd
from deriva.core.ermrest_model import Table

from deriva_ml.core.definitions import RID, MLVocab, VocabularyTerm
from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm
from deriva_ml.dataset.aux_classes import DatasetSpec, DatasetVersion
from deriva_ml.dataset.dataset_bag import DatasetBag
from deriva_ml.feature import Feature

if TYPE_CHECKING:
    from deriva_ml.model.database import DatabaseModel


class DerivaMLDatabase:
    """Database-backed implementation of DerivaMLCatalog protocol.

    Provides the same interface as DerivaML but operates on downloaded
    bags via SQLite. Read-only operations work; write operations raise
    DerivaMLException since bags are immutable snapshots.

    This class allows code written against the DerivaMLCatalog protocol
    to work identically with both live catalogs (DerivaML) and downloaded
    bags (DerivaMLDatabase).

    Attributes:
        ml_schema: Name of the ML schema.
        domain_schemas: Frozenset of domain schema names.
        default_schema: Default schema for table creation.
        model: The underlying DatabaseModel.
        working_dir: Working directory path.
        cache_dir: Cache directory path.
    """

    def __init__(self, database_model: "DatabaseModel"):
        """Create a new DerivaMLDatabase.

        Args:
            database_model: The DatabaseModel containing the SQLite database.
        """
        self._database_model = database_model

    # ==================== Protocol Properties ====================

    @property
    def ml_schema(self) -> str:
        """Get the ML schema name."""
        return self._database_model.ml_schema

    @property
    def domain_schemas(self) -> frozenset[str]:
        """Get the domain schema names."""
        return self._database_model.domain_schemas

    @property
    def default_schema(self) -> str | None:
        """Get the default schema name."""
        return self._database_model.default_schema

    @property
    def model(self) -> "DatabaseModel":
        """Get the underlying database model."""
        return self._database_model

    @property
    def working_dir(self) -> Path:
        """Get the working directory path."""
        return self._database_model.dbase_path

    @property
    def cache_dir(self) -> Path:
        """Get the cache directory path (same as working_dir for bags)."""
        return self._database_model.dbase_path

    @property
    def catalog_id(self) -> str:
        """Get the catalog ID (derived from bag path)."""
        return str(self._database_model.bag_path)

    @property
    def _dataset_table(self) -> Table:
        """Get the Dataset table from the model."""
        return self._database_model.dataset_table

    # ==================== Read Operations (Supported) ====================

    def lookup_dataset(
        self, dataset: RID | DatasetSpec, deleted: bool = False
    ) -> DatasetBag:
        """Look up a dataset by RID or spec.

        Args:
            dataset: Dataset RID or DatasetSpec to look up.
            deleted: Whether to include deleted datasets (ignored for bags).

        Returns:
            DatasetBag for the specified dataset.

        Raises:
            DerivaMLException: If dataset not found in bag.
        """
        if isinstance(dataset, DatasetSpec):
            rid = dataset.rid
        else:
            rid = dataset

        # Validate the dataset exists
        self._database_model.rid_lookup(rid)

        # Get dataset metadata
        dataset_record = next(
            (d for d in self._database_model._get_table_contents("Dataset") if d["RID"] == rid),
            None
        )
        if not dataset_record:
            raise DerivaMLException(f"Dataset {rid} not found in bag")

        # Get dataset types from association table
        atable = f"Dataset_{MLVocab.dataset_type.value}"
        ds_types = [
            t[MLVocab.dataset_type.value]
            for t in self._database_model._get_table_contents(atable)
            if t["Dataset"] == rid
        ]

        return DatasetBag(
            catalog=self,
            dataset_rid=rid,
            description=dataset_record.get("Description", ""),
            execution_rid=(self._database_model._get_dataset_execution(rid) or {}).get("Execution"),
            dataset_types=ds_types,
        )

    def find_datasets(self, deleted: bool = False) -> Iterable[DatasetBag]:
        """List all datasets in the bag.

        Args:
            deleted: Whether to include deleted datasets (ignored for bags).

        Returns:
            Iterable of DatasetBag objects.
        """
        # Get dataset types for all datasets from association table
        atable = f"Dataset_{MLVocab.dataset_type.value}"
        ds_types = list(self._database_model._get_table_contents(atable))

        datasets = []
        for dataset in self._database_model._get_table_contents("Dataset"):
            my_types = [t[MLVocab.dataset_type.value] for t in ds_types if t["Dataset"] == dataset["RID"]]
            datasets.append(
                DatasetBag(
                    catalog=self,
                    dataset_rid=dataset["RID"],
                    description=dataset.get("Description", ""),
                    execution_rid=(self._database_model._get_dataset_execution(dataset["RID"]) or {}).get("Execution"),
                    dataset_types=my_types,
                )
            )
        return datasets

    def lookup_term(self, table: str | Table, term_name: str) -> VocabularyTerm:
        """Look up a vocabulary term by name.

        Args:
            table: Vocabulary table to search.
            term_name: Name or synonym of the term.

        Returns:
            The matching VocabularyTerm.

        Raises:
            DerivaMLException: If table is not a vocabulary or term not found.
        """
        # Get table object if string provided
        if isinstance(table, str):
            table_obj = self._database_model.name_to_table(table)
        else:
            table_obj = table

        # Validate it's a vocabulary table
        if not self._database_model.is_vocabulary(table_obj):
            raise DerivaMLException(f"The table {table} is not a controlled vocabulary")

        # Search for term in SQLite
        for term in self.get_table_as_dict(table_obj.name):
            if term_name == term.get("Name") or (
                term.get("Synonyms") and term_name in term.get("Synonyms", [])
            ):
                # Convert synonyms to list if needed
                synonyms = term.get("Synonyms")
                if synonyms and not isinstance(synonyms, list):
                    synonyms = list(synonyms)
                term["Synonyms"] = synonyms or []
                return VocabularyTerm.model_validate(term)

        raise DerivaMLInvalidTerm(table, term_name)

    def get_table_as_dataframe(self, table: str) -> pd.DataFrame:
        """Get table contents as a pandas DataFrame.

        Args:
            table: Name of the table to retrieve.

        Returns:
            DataFrame containing all table contents.
        """
        return pd.DataFrame(list(self.get_table_as_dict(table)))

    def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
        """Get table contents as dictionaries.

        Args:
            table: Name of the table to retrieve.

        Returns:
            Generator yielding dictionaries for each row.
        """
        yield from self._database_model._get_table_contents(table)

    def list_dataset_element_types(self) -> list[Table]:
        """List the types of elements that can be in datasets.

        Returns:
            List of Table objects representing element types.
        """
        return self._database_model.list_dataset_element_types()

    def find_features(self, table: str | Table) -> Iterable[Feature]:
        """Find features associated with a table.

        Args:
            table: Table to find features for.

        Returns:
            Iterable of Feature objects.
        """
        return self._database_model.find_features(table)

    # ==================== Write Operations (Not Supported) ====================

    def create_dataset(
        self,
        execution_rid: RID | None = None,
        version: DatasetVersion | str | None = None,
        description: str = "",
        dataset_types: list[str] | None = None,
    ) -> DatasetBag:
        """Create a new dataset.

        Raises:
            DerivaMLException: Always, since bags are read-only.
        """
        raise DerivaMLException(
            "Cannot create datasets in a downloaded bag. "
            "Bags are immutable snapshots of catalog data."
        )

    def pathBuilder(self):
        """Get the catalog path builder.

        Raises:
            DerivaMLException: Always, since SQLite doesn't use pathBuilder.
        """
        raise DerivaMLException(
            "pathBuilder is not available for database-backed catalogs. "
            "Use get_table_as_dict() or get_table_as_dataframe() instead."
        )

    def catalog_snapshot(self, version_snapshot: str):
        """Create a catalog snapshot.

        Raises:
            DerivaMLException: Always, since bags are already snapshots.
        """
        raise DerivaMLException(
            "catalog_snapshot is not available for database-backed catalogs. "
            "Bags are already immutable snapshots."
        )

    def resolve_rid(self, rid: RID) -> dict[str, Any]:
        """Resolve a RID to its location.

        For database-backed catalogs, this validates that the RID exists
        in the bag and returns basic information about it.

        Args:
            rid: RID to resolve.

        Returns:
            Dictionary with RID and version information.

        Raises:
            DerivaMLException: If RID not found in bag.
        """
        version = self._database_model.rid_lookup(rid)
        return {"RID": rid, "version": version}
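To illustrate how the new class is intended to be used, here is a brief sketch (not part of the diff) that assumes a DatabaseModel has already been constructed from a downloaded bag; the loading code lives in deriva_ml.model.database and is not shown in this hunk, and the helper name summarize_bag is invented for illustration:

# Illustrative only: `database_model` is assumed to be a DatabaseModel built
# from a downloaded bag; only APIs shown in the hunk above are used.
from deriva_ml.core.exceptions import DerivaMLException
from deriva_ml.model.deriva_ml_database import DerivaMLDatabase

def summarize_bag(database_model) -> None:
    catalog = DerivaMLDatabase(database_model)

    # Read-only operations mirror the live-catalog interface.
    for dataset_bag in catalog.find_datasets():
        print(dataset_bag)

    # Any table in the bag can be pulled into pandas for analysis.
    print(catalog.get_table_as_dataframe("Dataset").head())

    # Write operations always fail, since a bag is an immutable snapshot.
    try:
        catalog.create_dataset(description="will not work")
    except DerivaMLException as exc:
        print(f"expected: {exc}")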
deriva_ml/model/handles.py
ADDED
@@ -0,0 +1,14 @@
"""Handle wrappers for ERMrest model objects.

This module re-exports TableHandle and ColumnHandle from deriva-py
for backwards compatibility. New code should import directly from
deriva.core.model_handles.

Classes:
    ColumnHandle: Wrapper for ERMrest Column with simplified property access.
    TableHandle: Wrapper for ERMrest Table with simplified operations.
"""

from deriva.core.model_handles import TableHandle, ColumnHandle

__all__ = ["TableHandle", "ColumnHandle"]
deriva_ml/run_model.py
ADDED
@@ -0,0 +1,319 @@
"""Command-line interface for executing ML models with DerivaML tracking.

This module provides a CLI tool for running ML models using hydra-zen configuration
while automatically tracking the execution in a Deriva catalog. It handles:

- Configuration loading from a user's configs module
- Hydra-zen configuration composition with command-line overrides
- Execution tracking with workflow provenance
- Multirun/sweep support with parent-child execution nesting

Usage:
    deriva-ml-run --host localhost --catalog 45 model_config=my_model
    deriva-ml-run +experiment=my_experiment
    deriva-ml-run --multirun model_config=m1,m2
    deriva-ml-run --info  # Show available Hydra config options

This parallels `deriva-ml-run-notebook` but for Python model functions instead
of Jupyter notebooks.

See Also:
    - run_notebook: CLI for running Jupyter notebooks
    - runner.run_model: The underlying function that executes models
"""

import sys
from pathlib import Path

from deriva.core import BaseCLI
from hydra_zen import store, zen

from deriva_ml.execution import (
    run_model,
    load_configs,
    get_multirun_config,
    get_all_multirun_configs,
)


class DerivaMLRunCLI(BaseCLI):
    """Command-line interface for running ML models with DerivaML execution tracking.

    This CLI extends Deriva's BaseCLI to provide model execution capabilities using
    hydra-zen. It automatically loads configuration modules from the project's
    configs directory.

    The CLI supports:
    - Host and catalog arguments (optional, can use Hydra config defaults)
    - Hydra configuration overrides as positional arguments
    - --info flag to display available configuration options
    - --multirun flag for parameter sweeps
    - --config-dir to specify custom config location

    Attributes:
        parser: ArgumentParser instance with configured arguments.

    Example:
        >>> cli = DerivaMLRunCLI(
        ...     description="Run ML model",
        ...     epilog="See documentation for more details"
        ... )
        >>> cli.main()  # Parses args and runs model
    """

    def __init__(self, description: str, epilog: str, **kwargs) -> None:
        """Initialize the model runner CLI with command-line arguments.

        Sets up argument parsing for model execution, including host/catalog,
        config directory, and Hydra overrides.

        Args:
            description: Description text shown in --help output.
            epilog: Additional text shown after argument help.
            **kwargs: Additional keyword arguments passed to BaseCLI.
        """
        BaseCLI.__init__(self, description, epilog, **kwargs)

        self.parser.add_argument(
            "--catalog",
            type=str,
            default=None,
            help="Catalog number or identifier (optional if defined in Hydra config)"
        )

        self.parser.add_argument(
            "--config-dir",
            "-c",
            type=Path,
            default=Path("src/configs"),
            help="Path to the configs directory (default: src/configs)",
        )

        self.parser.add_argument(
            "--config-name",
            type=str,
            default="deriva_model",
            help="Name of the main hydra-zen config (default: deriva_model)",
        )

        self.parser.add_argument(
            "--info",
            action="store_true",
            help="Display available Hydra configuration groups and options.",
        )

        self.parser.add_argument(
            "--multirun", "-m",
            action="store_true",
            help="Run multiple configurations (Hydra multirun mode).",
        )

        self.parser.add_argument(
            "hydra_overrides",
            nargs="*",
            help="Hydra-zen configuration overrides (e.g., model_config=cifar10_quick)",
        )

    def main(self) -> int:
        """Parse command-line arguments and execute the model.

        This is the main entry point that orchestrates:
        1. Parsing command-line arguments
        2. Loading configuration modules
        3. Either showing config info or executing the model

        Returns:
            Exit code (0 for success, 1 for failure).
        """
        args = self.parse_cli()

        # Resolve config directory
        config_dir = args.config_dir.resolve()
        if not config_dir.exists():
            print(f"Error: Config directory not found: {config_dir}")
            return 1

        # Add the parent of the config directory to sys.path
        src_dir = config_dir.parent
        if src_dir.exists() and str(src_dir) not in sys.path:
            sys.path.insert(0, str(src_dir))

        # Also add project root
        project_root = src_dir.parent
        if project_root.exists() and str(project_root) not in sys.path:
            sys.path.insert(0, str(project_root))

        # Load configurations from the configs module
        config_module_name = config_dir.name
        loaded = load_configs(config_module_name)
        if not loaded:
            # Try the old way
            try:
                exec(f"from {config_module_name} import load_all_configs; load_all_configs()")
            except ImportError:
                print(f"Error: Could not load configs from '{config_module_name}'")
                print("Make sure the config directory contains an __init__.py with load_all_configs()")
                return 1

        if args.info:
            self._show_hydra_info()
            return 0

        # Build Hydra overrides list
        hydra_overrides = list(args.hydra_overrides) if args.hydra_overrides else []

        # Check for +multirun=<name> and expand it
        multirun_description = None
        use_multirun = args.multirun
        expanded_overrides = []

        for override in hydra_overrides:
            if override.startswith("+multirun="):
                # Extract the multirun config name
                multirun_name = override.split("=", 1)[1]
                multirun_spec = get_multirun_config(multirun_name)

                if multirun_spec is None:
                    available = get_all_multirun_configs()
                    print(f"Error: Unknown multirun config '{multirun_name}'")
                    if available:
                        print("Available multirun configs:")
                        for name in sorted(available.keys()):
                            print(f"  - {name}")
                    else:
                        print("No multirun configs registered. Define them in configs/multiruns.py")
                    return 1

                # Expand the multirun config's overrides
                expanded_overrides.extend(multirun_spec.overrides)
                multirun_description = multirun_spec.description
                use_multirun = True  # Automatically enable multirun mode
            else:
                # Keep non-multirun overrides (they can override multirun config values)
                expanded_overrides.append(override)

        hydra_overrides = expanded_overrides

        # Add host/catalog overrides if provided on command line
        if args.host:
            hydra_overrides.append(f"deriva_ml.hostname={args.host}")
        if args.catalog:
            hydra_overrides.append(f"deriva_ml.catalog_id={args.catalog}")

        # If we have a multirun description, add it as an override
        # This gets passed to run_model which uses it for the parent execution
        if multirun_description:
            # Escape the description for Hydra command line
            # Use single quotes and escape any internal single quotes
            escaped_desc = multirun_description.replace("'", "\\'")
            hydra_overrides.append(f"description='{escaped_desc}'")

        # Finalize the hydra-zen store
        store.add_to_hydra_store()

        # Build argv for Hydra
        hydra_argv = [sys.argv[0]] + hydra_overrides
        if use_multirun:
            hydra_argv.insert(1, "--multirun")

        # Save and replace sys.argv for Hydra
        original_argv = sys.argv
        sys.argv = hydra_argv

        try:
            zen(run_model).hydra_main(
                config_name=args.config_name,
                version_base="1.3",
                config_path=None,
            )
        finally:
            sys.argv = original_argv

        return 0

    @staticmethod
    def _show_hydra_info() -> None:
        """Display available Hydra configuration groups and options.

        Inspects the hydra-zen store and prints all registered configuration
        groups and their available options.
        """
        print("Available Hydra Configuration Groups:")
        print("=" * 50)

        try:
            groups: dict[str, list[str]] = {}

            for group, name in store._queue:
                if group:
                    if group not in groups:
                        groups[group] = []
                    if name not in groups[group]:
                        groups[group].append(name)
                else:
                    if "__root__" not in groups:
                        groups["__root__"] = []
                    if name not in groups["__root__"]:
                        groups["__root__"].append(name)

            for group in sorted(groups.keys()):
                if group == "__root__":
                    print("\nTop-level configs:")
                else:
                    print(f"\n{group}:")
                for name in sorted(groups[group]):
                    print(f"  - {name}")

            # Show multirun configs if any are registered
            multirun_configs = get_all_multirun_configs()
            if multirun_configs:
                print("\nmultirun:")
                for name in sorted(multirun_configs.keys()):
                    spec = multirun_configs[name]
                    # Show first line of description or overrides summary
                    if spec.description:
                        first_line = spec.description.strip().split('\n')[0]
                        # Remove markdown formatting for display
                        first_line = first_line.lstrip('#').strip()
                        if len(first_line) > 50:
                            first_line = first_line[:47] + "..."
                        print(f"  - {name}: {first_line}")
                    else:
                        print(f"  - {name}: {', '.join(spec.overrides[:2])}")

            print("\n" + "=" * 50)
            print("Usage: deriva-ml-run [options] <group>=<option> ...")
            print("Example: deriva-ml-run --host localhost --catalog 45 model_config=cifar10_quick")
            print("Example: deriva-ml-run +experiment=cifar10_quick")
            print("Example: deriva-ml-run +multirun=quick_vs_extended")
            print("Example: deriva-ml-run --multirun +experiment=cifar10_quick,cifar10_extended")

        except Exception as e:
            print(f"Error inspecting Hydra store: {e}")


def main() -> int:
    """Main entry point for the model runner CLI.

    Creates and runs the DerivaMLRunCLI instance.

    Returns:
        Exit code (0 for success, 1 for failure).
    """
    cli = DerivaMLRunCLI(
        description="Run ML models with DerivaML execution tracking",
        epilog=(
            "Examples:\n"
            "  deriva-ml-run model_config=my_model\n"
            "  deriva-ml-run --host localhost --catalog 45 +experiment=cifar10_quick\n"
            "  deriva-ml-run +multirun=quick_vs_extended\n"
            "  deriva-ml-run +multirun=lr_sweep model_config.epochs=5\n"
            "  deriva-ml-run --multirun +experiment=cifar10_quick,cifar10_extended\n"
            "  deriva-ml-run --info\n"
        ),
    )
    return cli.main()


if __name__ == "__main__":
    sys.exit(main())