deriva-ml 1.17.9__py3-none-any.whl → 1.17.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (74)
  1. deriva_ml/__init__.py +43 -1
  2. deriva_ml/asset/__init__.py +17 -0
  3. deriva_ml/asset/asset.py +357 -0
  4. deriva_ml/asset/aux_classes.py +100 -0
  5. deriva_ml/bump_version.py +254 -11
  6. deriva_ml/catalog/__init__.py +21 -0
  7. deriva_ml/catalog/clone.py +1199 -0
  8. deriva_ml/catalog/localize.py +426 -0
  9. deriva_ml/core/__init__.py +29 -0
  10. deriva_ml/core/base.py +817 -1067
  11. deriva_ml/core/config.py +169 -21
  12. deriva_ml/core/constants.py +120 -19
  13. deriva_ml/core/definitions.py +123 -13
  14. deriva_ml/core/enums.py +47 -73
  15. deriva_ml/core/ermrest.py +226 -193
  16. deriva_ml/core/exceptions.py +297 -14
  17. deriva_ml/core/filespec.py +99 -28
  18. deriva_ml/core/logging_config.py +225 -0
  19. deriva_ml/core/mixins/__init__.py +42 -0
  20. deriva_ml/core/mixins/annotation.py +915 -0
  21. deriva_ml/core/mixins/asset.py +384 -0
  22. deriva_ml/core/mixins/dataset.py +237 -0
  23. deriva_ml/core/mixins/execution.py +408 -0
  24. deriva_ml/core/mixins/feature.py +365 -0
  25. deriva_ml/core/mixins/file.py +263 -0
  26. deriva_ml/core/mixins/path_builder.py +145 -0
  27. deriva_ml/core/mixins/rid_resolution.py +204 -0
  28. deriva_ml/core/mixins/vocabulary.py +400 -0
  29. deriva_ml/core/mixins/workflow.py +322 -0
  30. deriva_ml/core/validation.py +389 -0
  31. deriva_ml/dataset/__init__.py +2 -1
  32. deriva_ml/dataset/aux_classes.py +20 -4
  33. deriva_ml/dataset/catalog_graph.py +575 -0
  34. deriva_ml/dataset/dataset.py +1242 -1008
  35. deriva_ml/dataset/dataset_bag.py +1311 -182
  36. deriva_ml/dataset/history.py +27 -14
  37. deriva_ml/dataset/upload.py +225 -38
  38. deriva_ml/demo_catalog.py +186 -105
  39. deriva_ml/execution/__init__.py +46 -2
  40. deriva_ml/execution/base_config.py +639 -0
  41. deriva_ml/execution/execution.py +545 -244
  42. deriva_ml/execution/execution_configuration.py +26 -11
  43. deriva_ml/execution/execution_record.py +592 -0
  44. deriva_ml/execution/find_caller.py +298 -0
  45. deriva_ml/execution/model_protocol.py +175 -0
  46. deriva_ml/execution/multirun_config.py +153 -0
  47. deriva_ml/execution/runner.py +595 -0
  48. deriva_ml/execution/workflow.py +224 -35
  49. deriva_ml/experiment/__init__.py +8 -0
  50. deriva_ml/experiment/experiment.py +411 -0
  51. deriva_ml/feature.py +6 -1
  52. deriva_ml/install_kernel.py +143 -6
  53. deriva_ml/interfaces.py +862 -0
  54. deriva_ml/model/__init__.py +99 -0
  55. deriva_ml/model/annotations.py +1278 -0
  56. deriva_ml/model/catalog.py +286 -60
  57. deriva_ml/model/database.py +144 -649
  58. deriva_ml/model/deriva_ml_database.py +308 -0
  59. deriva_ml/model/handles.py +14 -0
  60. deriva_ml/run_model.py +319 -0
  61. deriva_ml/run_notebook.py +507 -38
  62. deriva_ml/schema/__init__.py +18 -2
  63. deriva_ml/schema/annotations.py +62 -33
  64. deriva_ml/schema/create_schema.py +169 -69
  65. deriva_ml/schema/validation.py +601 -0
  66. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/METADATA +4 -5
  67. deriva_ml-1.17.11.dist-info/RECORD +77 -0
  68. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/WHEEL +1 -1
  69. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/entry_points.txt +2 -0
  70. deriva_ml/protocols/dataset.py +0 -19
  71. deriva_ml/test.py +0 -94
  72. deriva_ml-1.17.9.dist-info/RECORD +0 -45
  73. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/licenses/LICENSE +0 -0
  74. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,308 @@
1
+ """Database-backed implementation of DerivaMLCatalog protocol.
2
+
3
+ This module provides a DerivaMLDatabase class that implements the DerivaMLCatalog
4
+ protocol using a DatabaseModel (SQLite) instead of a live catalog connection.
5
+ This allows code written against the DerivaMLCatalog protocol to work identically
6
+ with both live catalogs and downloaded bags.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any, Generator, Iterable
13
+
14
+ import pandas as pd
15
+ from deriva.core.ermrest_model import Table
16
+
17
+ from deriva_ml.core.definitions import RID, MLVocab, VocabularyTerm
18
+ from deriva_ml.core.exceptions import DerivaMLException, DerivaMLInvalidTerm
19
+ from deriva_ml.dataset.aux_classes import DatasetSpec, DatasetVersion
20
+ from deriva_ml.dataset.dataset_bag import DatasetBag
21
+ from deriva_ml.feature import Feature
22
+
23
+ if TYPE_CHECKING:
24
+ from deriva_ml.model.database import DatabaseModel
25
+
26
+
27
class DerivaMLDatabase:
    """Database-backed implementation of DerivaMLCatalog protocol.

    Provides the same interface as DerivaML but operates on downloaded
    bags via SQLite. Read-only operations work; write operations raise
    DerivaMLException since bags are immutable snapshots.

    This class allows code written against the DerivaMLCatalog protocol
    to work identically with both live catalogs (DerivaML) and downloaded
    bags (DerivaMLDatabase).

    Attributes:
        ml_schema: Name of the ML schema.
        domain_schemas: Frozenset of domain schema names.
        default_schema: Default schema for table creation.
        model: The underlying DatabaseModel.
        working_dir: Working directory path.
        cache_dir: Cache directory path.
    """

    def __init__(self, database_model: "DatabaseModel"):
        """Create a new DerivaMLDatabase.

        Args:
            database_model: The DatabaseModel containing the SQLite database.
        """
        self._database_model = database_model

    # ==================== Protocol Properties ====================

    @property
    def ml_schema(self) -> str:
        """Get the ML schema name."""
        return self._database_model.ml_schema

    @property
    def domain_schemas(self) -> frozenset[str]:
        """Get the domain schema names."""
        return self._database_model.domain_schemas

    @property
    def default_schema(self) -> str | None:
        """Get the default schema name."""
        return self._database_model.default_schema

    @property
    def model(self) -> "DatabaseModel":
        """Get the underlying database model."""
        return self._database_model

    @property
    def working_dir(self) -> Path:
        """Get the working directory path (the model's ``dbase_path``)."""
        return self._database_model.dbase_path

    @property
    def cache_dir(self) -> Path:
        """Get the cache directory path (same as working_dir for bags)."""
        return self._database_model.dbase_path

    @property
    def catalog_id(self) -> str:
        """Get the catalog ID (derived from bag path)."""
        return str(self._database_model.bag_path)

    @property
    def _dataset_table(self) -> Table:
        """Get the Dataset table from the model."""
        return self._database_model.dataset_table

    # ==================== Internal Helpers ====================

    @staticmethod
    def _dataset_type_table() -> str:
        """Return the name of the association table linking Dataset rows to their dataset types."""
        return f"Dataset_{MLVocab.dataset_type.value}"

    def _make_dataset_bag(self, record: dict[str, Any], dataset_types: list[str]) -> DatasetBag:
        """Build a DatasetBag from one row of the Dataset table.

        Shared by lookup_dataset() and find_datasets() so both construct bags
        the same way.

        Args:
            record: A row of the Dataset table; must contain a "RID" key.
            dataset_types: Dataset type terms associated with this dataset.

        Returns:
            DatasetBag bound to this catalog.
        """
        rid = record["RID"]
        # `or {}` guards a falsy (e.g. None) result when no execution is linked to the dataset.
        execution = self._database_model._get_dataset_execution(rid) or {}
        return DatasetBag(
            catalog=self,
            dataset_rid=rid,
            description=record.get("Description", ""),
            execution_rid=execution.get("Execution"),
            dataset_types=dataset_types,
        )

    # ==================== Read Operations (Supported) ====================

    def lookup_dataset(
        self, dataset: RID | DatasetSpec, deleted: bool = False
    ) -> DatasetBag:
        """Look up a dataset by RID or spec.

        Args:
            dataset: Dataset RID or DatasetSpec to look up.
            deleted: Whether to include deleted datasets (ignored for bags).

        Returns:
            DatasetBag for the specified dataset.

        Raises:
            DerivaMLException: If dataset not found in bag.
        """
        rid = dataset.rid if isinstance(dataset, DatasetSpec) else dataset

        # Validate the dataset exists (rid_lookup raises if the RID is unknown).
        self._database_model.rid_lookup(rid)

        dataset_record = next(
            (d for d in self._database_model._get_table_contents("Dataset") if d["RID"] == rid),
            None,
        )
        if not dataset_record:
            raise DerivaMLException(f"Dataset {rid} not found in bag")

        type_column = MLVocab.dataset_type.value
        ds_types = [
            row[type_column]
            for row in self._database_model._get_table_contents(self._dataset_type_table())
            if row["Dataset"] == rid
        ]
        return self._make_dataset_bag(dataset_record, ds_types)

    def find_datasets(self, deleted: bool = False) -> Iterable[DatasetBag]:
        """List all datasets in the bag.

        Args:
            deleted: Whether to include deleted datasets (ignored for bags).

        Returns:
            Iterable (a list) of DatasetBag objects, one per Dataset row.
        """
        type_column = MLVocab.dataset_type.value

        # Group association rows by dataset RID in a single pass rather than
        # rescanning every association row for each dataset (was O(n*m)).
        # Insertion order within each group matches the table's row order.
        types_by_dataset: dict[str, list[str]] = {}
        for row in self._database_model._get_table_contents(self._dataset_type_table()):
            types_by_dataset.setdefault(row["Dataset"], []).append(row[type_column])

        return [
            self._make_dataset_bag(record, types_by_dataset.get(record["RID"], []))
            for record in self._database_model._get_table_contents("Dataset")
        ]

    def lookup_term(self, table: str | Table, term_name: str) -> VocabularyTerm:
        """Look up a vocabulary term by name.

        Args:
            table: Vocabulary table to search.
            term_name: Name or synonym of the term.

        Returns:
            The matching VocabularyTerm.

        Raises:
            DerivaMLException: If table is not a vocabulary.
            DerivaMLInvalidTerm: If no term with that name or synonym exists.
        """
        table_obj = self._database_model.name_to_table(table) if isinstance(table, str) else table

        if not self._database_model.is_vocabulary(table_obj):
            raise DerivaMLException(f"The table {table} is not a controlled vocabulary")

        for term in self.get_table_as_dict(table_obj.name):
            # A missing/empty Synonyms value behaves like an empty list.
            synonyms = term.get("Synonyms") or []
            if term_name == term.get("Name") or term_name in synonyms:
                # Normalize non-list sequences (e.g. tuples) to a list before validation.
                term["Synonyms"] = synonyms if isinstance(synonyms, list) else list(synonyms)
                return VocabularyTerm.model_validate(term)

        raise DerivaMLInvalidTerm(table, term_name)

    def get_table_as_dataframe(self, table: str) -> pd.DataFrame:
        """Get table contents as a pandas DataFrame.

        Args:
            table: Name of the table to retrieve.

        Returns:
            DataFrame containing all table contents.
        """
        return pd.DataFrame(list(self.get_table_as_dict(table)))

    def get_table_as_dict(self, table: str) -> Generator[dict[str, Any], None, None]:
        """Get table contents as dictionaries.

        Args:
            table: Name of the table to retrieve.

        Returns:
            Generator yielding dictionaries for each row.
        """
        yield from self._database_model._get_table_contents(table)

    def list_dataset_element_types(self) -> list[Table]:
        """List the types of elements that can be in datasets.

        Returns:
            List of Table objects representing element types.
        """
        return self._database_model.list_dataset_element_types()

    def find_features(self, table: str | Table) -> Iterable[Feature]:
        """Find features associated with a table.

        Args:
            table: Table to find features for.

        Returns:
            Iterable of Feature objects.
        """
        return self._database_model.find_features(table)

    # ==================== Write Operations (Not Supported) ====================

    def create_dataset(
        self,
        execution_rid: RID | None = None,
        version: DatasetVersion | str | None = None,
        description: str = "",
        dataset_types: list[str] | None = None,
    ) -> DatasetBag:
        """Create a new dataset.

        Raises:
            DerivaMLException: Always, since bags are read-only.
        """
        raise DerivaMLException(
            "Cannot create datasets in a downloaded bag. "
            "Bags are immutable snapshots of catalog data."
        )

    def pathBuilder(self):
        """Get the catalog path builder.

        Raises:
            DerivaMLException: Always, since SQLite doesn't use pathBuilder.
        """
        raise DerivaMLException(
            "pathBuilder is not available for database-backed catalogs. "
            "Use get_table_as_dict() or get_table_as_dataframe() instead."
        )

    def catalog_snapshot(self, version_snapshot: str):
        """Create a catalog snapshot.

        Raises:
            DerivaMLException: Always, since bags are already snapshots.
        """
        raise DerivaMLException(
            "catalog_snapshot is not available for database-backed catalogs. "
            "Bags are already immutable snapshots."
        )

    def resolve_rid(self, rid: RID) -> dict[str, Any]:
        """Resolve a RID to its location.

        For database-backed catalogs, this validates that the RID exists
        in the bag and returns basic information about it.

        Args:
            rid: RID to resolve.

        Returns:
            Dictionary with the RID and whatever ``rid_lookup`` returned, under "version".

        Raises:
            DerivaMLException: If RID not found in bag.
        """
        version = self._database_model.rid_lookup(rid)
        return {"RID": rid, "version": version}
@@ -0,0 +1,14 @@
1
"""Backwards-compatible re-export of deriva-py's ERMrest model handles.

``TableHandle`` and ``ColumnHandle`` are defined in
``deriva.core.model_handles``; this module only re-exports them so that
existing ``deriva_ml.model.handles`` imports keep working. New code should
import them from ``deriva.core.model_handles`` directly.

Classes:
    ColumnHandle: Wrapper for ERMrest Column with simplified property access.
    TableHandle: Wrapper for ERMrest Table with simplified operations.
"""

from deriva.core.model_handles import ColumnHandle, TableHandle

# Public names exported by ``from deriva_ml.model.handles import *``.
__all__ = ["TableHandle", "ColumnHandle"]
deriva_ml/run_model.py ADDED
@@ -0,0 +1,319 @@
1
+ """Command-line interface for executing ML models with DerivaML tracking.
2
+
3
+ This module provides a CLI tool for running ML models using hydra-zen configuration
4
+ while automatically tracking the execution in a Deriva catalog. It handles:
5
+
6
+ - Configuration loading from a user's configs module
7
+ - Hydra-zen configuration composition with command-line overrides
8
+ - Execution tracking with workflow provenance
9
+ - Multirun/sweep support with parent-child execution nesting
10
+
11
+ Usage:
12
+ deriva-ml-run --host localhost --catalog 45 model_config=my_model
13
+ deriva-ml-run +experiment=my_experiment
14
+ deriva-ml-run --multirun model_config=m1,m2
15
+ deriva-ml-run --info # Show available Hydra config options
16
+
17
+ This parallels `deriva-ml-run-notebook` but for Python model functions instead
18
+ of Jupyter notebooks.
19
+
20
+ See Also:
21
+ - run_notebook: CLI for running Jupyter notebooks
22
+ - runner.run_model: The underlying function that executes models
23
+ """
24
+
25
+ import sys
26
+ from pathlib import Path
27
+
28
+ from deriva.core import BaseCLI
29
+ from hydra_zen import store, zen
30
+
31
+ from deriva_ml.execution import (
32
+ run_model,
33
+ load_configs,
34
+ get_multirun_config,
35
+ get_all_multirun_configs,
36
+ )
37
+
38
+
39
class DerivaMLRunCLI(BaseCLI):
    """Command-line interface for running ML models with DerivaML execution tracking.

    This CLI extends Deriva's BaseCLI to provide model execution capabilities using
    hydra-zen. It automatically loads configuration modules from the project's
    configs directory.

    The CLI supports:
    - Host and catalog arguments (optional, can use Hydra config defaults)
    - Hydra configuration overrides as positional arguments
    - --info flag to display available configuration options
    - --multirun flag for parameter sweeps
    - --config-dir to specify custom config location

    Attributes:
        parser: ArgumentParser instance with configured arguments.

    Example:
        >>> cli = DerivaMLRunCLI(
        ...     description="Run ML model",
        ...     epilog="See documentation for more details"
        ... )
        >>> cli.main()  # Parses args and runs model
    """

    def __init__(self, description: str, epilog: str, **kwargs) -> None:
        """Initialize the model runner CLI with command-line arguments.

        Sets up argument parsing for model execution, including host/catalog,
        config directory, and Hydra overrides.

        Args:
            description: Description text shown in --help output.
            epilog: Additional text shown after argument help.
            **kwargs: Additional keyword arguments passed to BaseCLI.
        """
        BaseCLI.__init__(self, description, epilog, **kwargs)

        self.parser.add_argument(
            "--catalog",
            type=str,
            default=None,
            help="Catalog number or identifier (optional if defined in Hydra config)",
        )

        self.parser.add_argument(
            "--config-dir",
            "-c",
            type=Path,
            default=Path("src/configs"),
            help="Path to the configs directory (default: src/configs)",
        )

        self.parser.add_argument(
            "--config-name",
            type=str,
            default="deriva_model",
            help="Name of the main hydra-zen config (default: deriva_model)",
        )

        self.parser.add_argument(
            "--info",
            action="store_true",
            help="Display available Hydra configuration groups and options.",
        )

        self.parser.add_argument(
            "--multirun", "-m",
            action="store_true",
            help="Run multiple configurations (Hydra multirun mode).",
        )

        self.parser.add_argument(
            "hydra_overrides",
            nargs="*",
            help="Hydra-zen configuration overrides (e.g., model_config=cifar10_quick)",
        )

    def main(self) -> int:
        """Parse command-line arguments and execute the model.

        This is the main entry point that orchestrates:
        1. Parsing command-line arguments
        2. Loading configuration modules
        3. Either showing config info or executing the model

        Returns:
            Exit code (0 for success, 1 for failure).
        """
        args = self.parse_cli()

        config_dir = args.config_dir.resolve()
        if not config_dir.exists():
            print(f"Error: Config directory not found: {config_dir}")
            return 1

        self._extend_sys_path(config_dir)

        # The configs directory is imported as a package named after the directory.
        if not self._load_config_module(config_dir.name):
            return 1

        if args.info:
            self._show_hydra_info()
            return 0

        expansion = self._expand_multirun_overrides(list(args.hydra_overrides or []))
        if expansion is None:
            return 1
        hydra_overrides, multirun_description, used_multirun_config = expansion
        # A +multirun=<name> override automatically enables multirun mode.
        use_multirun = args.multirun or used_multirun_config

        # Add host/catalog overrides if provided on command line
        if args.host:
            hydra_overrides.append(f"deriva_ml.hostname={args.host}")
        if args.catalog:
            hydra_overrides.append(f"deriva_ml.catalog_id={args.catalog}")

        # Pass the multirun description through to run_model, which uses it
        # for the parent execution. It is single-quoted for Hydra, with
        # embedded single quotes backslash-escaped.
        # NOTE(review): a description ending in a backslash would escape the
        # closing quote — confirm against Hydra's quoted-value escaping rules.
        if multirun_description:
            escaped_desc = multirun_description.replace("'", "\\'")
            hydra_overrides.append(f"description='{escaped_desc}'")

        # Finalize the hydra-zen store
        store.add_to_hydra_store()

        # Hydra reads its overrides from sys.argv, so temporarily replace it.
        hydra_argv = [sys.argv[0]] + hydra_overrides
        if use_multirun:
            hydra_argv.insert(1, "--multirun")

        original_argv = sys.argv
        sys.argv = hydra_argv
        try:
            zen(run_model).hydra_main(
                config_name=args.config_name,
                version_base="1.3",
                config_path=None,
            )
        finally:
            sys.argv = original_argv

        return 0

    @staticmethod
    def _extend_sys_path(config_dir: Path) -> None:
        """Make the configs package and the project root importable.

        Prepends the config directory's parent (e.g. ``src``) and then its
        grandparent (the project root), so the project root ends up first.
        """
        src_dir = config_dir.parent
        for directory in (src_dir, src_dir.parent):
            if directory.exists() and str(directory) not in sys.path:
                sys.path.insert(0, str(directory))

    @staticmethod
    def _load_config_module(config_module_name: str) -> bool:
        """Load the project's hydra-zen configs; return False (after printing an error) on failure.

        First tries ``load_configs``. If that reports nothing loaded, falls back
        to the legacy convention of calling ``load_all_configs()`` from the
        configs package. The module is imported with importlib rather than an
        ``exec`` of generated source, so directory names that are not valid
        identifiers do not cause an uncaught SyntaxError. Exceptions raised by
        ``load_all_configs()`` itself propagate.
        """
        import importlib

        if load_configs(config_module_name):
            return True

        try:
            module = importlib.import_module(config_module_name)
        except ImportError:
            module = None
        load_all_configs = getattr(module, "load_all_configs", None)
        if load_all_configs is None:
            print(f"Error: Could not load configs from '{config_module_name}'")
            print("Make sure the config directory contains an __init__.py with load_all_configs()")
            return False

        load_all_configs()
        return True

    @staticmethod
    def _expand_multirun_overrides(
        overrides: list,
    ) -> "tuple[list[str], str | None, bool] | None":
        """Expand ``+multirun=<name>`` overrides into the named config's overrides.

        Other overrides are kept in place, so they can override multirun
        config values. If several ``+multirun=`` entries are given, all of
        their overrides are included and the last one's description wins.

        Returns:
            ``(expanded_overrides, description, used_multirun_config)``, or
            None after printing an error if a multirun name is unknown.
        """
        expanded: list[str] = []
        description = None
        used_multirun_config = False

        for override in overrides:
            if not override.startswith("+multirun="):
                expanded.append(override)
                continue

            multirun_name = override.split("=", 1)[1]
            multirun_spec = get_multirun_config(multirun_name)
            if multirun_spec is None:
                available = get_all_multirun_configs()
                print(f"Error: Unknown multirun config '{multirun_name}'")
                if available:
                    print("Available multirun configs:")
                    for name in sorted(available.keys()):
                        print(f" - {name}")
                else:
                    print("No multirun configs registered. Define them in configs/multiruns.py")
                return None

            expanded.extend(multirun_spec.overrides)
            description = multirun_spec.description
            used_multirun_config = True

        return expanded, description, used_multirun_config

    @staticmethod
    def _show_hydra_info() -> None:
        """Display available Hydra configuration groups and options.

        Inspects the hydra-zen store and prints all registered configuration
        groups and their available options.
        """
        print("Available Hydra Configuration Groups:")
        print("=" * 50)

        try:
            # Entries without a group are listed under a synthetic "__root__" key.
            # NOTE(review): relies on the private hydra-zen `store._queue` attribute.
            groups: dict[str, list[str]] = {}
            for group, name in store._queue:
                bucket = groups.setdefault(group or "__root__", [])
                if name not in bucket:
                    bucket.append(name)

            for group in sorted(groups.keys()):
                if group == "__root__":
                    print("\nTop-level configs:")
                else:
                    print(f"\n{group}:")
                for name in sorted(groups[group]):
                    print(f" - {name}")

            # Show multirun configs if any are registered
            multirun_configs = get_all_multirun_configs()
            if multirun_configs:
                print("\nmultirun:")
                for name in sorted(multirun_configs.keys()):
                    spec = multirun_configs[name]
                    # Show first line of description or overrides summary
                    if spec.description:
                        first_line = spec.description.strip().split('\n')[0]
                        # Remove markdown heading markers for display
                        first_line = first_line.lstrip('#').strip()
                        if len(first_line) > 50:
                            first_line = first_line[:47] + "..."
                        print(f" - {name}: {first_line}")
                    else:
                        print(f" - {name}: {', '.join(spec.overrides[:2])}")

            print("\n" + "=" * 50)
            print("Usage: deriva-ml-run [options] <group>=<option> ...")
            print("Example: deriva-ml-run --host localhost --catalog 45 model_config=cifar10_quick")
            print("Example: deriva-ml-run +experiment=cifar10_quick")
            print("Example: deriva-ml-run +multirun=quick_vs_extended")
            print("Example: deriva-ml-run --multirun +experiment=cifar10_quick,cifar10_extended")

        except Exception as e:
            print(f"Error inspecting Hydra store: {e}")
293
+
294
+
295
def main() -> int:
    """Console entry point for ``deriva-ml-run``.

    Builds a DerivaMLRunCLI with the tool description and an epilog of
    example invocations, then delegates to its ``main()`` method.

    Returns:
        The exit code from DerivaMLRunCLI.main() (0 for success, 1 for failure).
    """
    usage_examples = (
        "deriva-ml-run model_config=my_model",
        "deriva-ml-run --host localhost --catalog 45 +experiment=cifar10_quick",
        "deriva-ml-run +multirun=quick_vs_extended",
        "deriva-ml-run +multirun=lr_sweep model_config.epochs=5",
        "deriva-ml-run --multirun +experiment=cifar10_quick,cifar10_extended",
        "deriva-ml-run --info",
    )
    # Each example is printed on its own line, prefixed with a space.
    epilog_text = "Examples:\n" + "".join(f" {example}\n" for example in usage_examples)

    runner = DerivaMLRunCLI(
        description="Run ML models with DerivaML execution tracking",
        epilog=epilog_text,
    )
    return runner.main()
316
+
317
+
318
if __name__ == "__main__":
    # Propagate the CLI's return value as the process exit status.
    raise SystemExit(main())