deriva-ml 1.17.9__py3-none-any.whl → 1.17.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. deriva_ml/__init__.py +43 -1
  2. deriva_ml/asset/__init__.py +17 -0
  3. deriva_ml/asset/asset.py +357 -0
  4. deriva_ml/asset/aux_classes.py +100 -0
  5. deriva_ml/bump_version.py +254 -11
  6. deriva_ml/catalog/__init__.py +21 -0
  7. deriva_ml/catalog/clone.py +1199 -0
  8. deriva_ml/catalog/localize.py +426 -0
  9. deriva_ml/core/__init__.py +29 -0
  10. deriva_ml/core/base.py +817 -1067
  11. deriva_ml/core/config.py +169 -21
  12. deriva_ml/core/constants.py +120 -19
  13. deriva_ml/core/definitions.py +123 -13
  14. deriva_ml/core/enums.py +47 -73
  15. deriva_ml/core/ermrest.py +226 -193
  16. deriva_ml/core/exceptions.py +297 -14
  17. deriva_ml/core/filespec.py +99 -28
  18. deriva_ml/core/logging_config.py +225 -0
  19. deriva_ml/core/mixins/__init__.py +42 -0
  20. deriva_ml/core/mixins/annotation.py +915 -0
  21. deriva_ml/core/mixins/asset.py +384 -0
  22. deriva_ml/core/mixins/dataset.py +237 -0
  23. deriva_ml/core/mixins/execution.py +408 -0
  24. deriva_ml/core/mixins/feature.py +365 -0
  25. deriva_ml/core/mixins/file.py +263 -0
  26. deriva_ml/core/mixins/path_builder.py +145 -0
  27. deriva_ml/core/mixins/rid_resolution.py +204 -0
  28. deriva_ml/core/mixins/vocabulary.py +400 -0
  29. deriva_ml/core/mixins/workflow.py +322 -0
  30. deriva_ml/core/validation.py +389 -0
  31. deriva_ml/dataset/__init__.py +2 -1
  32. deriva_ml/dataset/aux_classes.py +20 -4
  33. deriva_ml/dataset/catalog_graph.py +575 -0
  34. deriva_ml/dataset/dataset.py +1242 -1008
  35. deriva_ml/dataset/dataset_bag.py +1311 -182
  36. deriva_ml/dataset/history.py +27 -14
  37. deriva_ml/dataset/upload.py +225 -38
  38. deriva_ml/demo_catalog.py +186 -105
  39. deriva_ml/execution/__init__.py +46 -2
  40. deriva_ml/execution/base_config.py +639 -0
  41. deriva_ml/execution/execution.py +545 -244
  42. deriva_ml/execution/execution_configuration.py +26 -11
  43. deriva_ml/execution/execution_record.py +592 -0
  44. deriva_ml/execution/find_caller.py +298 -0
  45. deriva_ml/execution/model_protocol.py +175 -0
  46. deriva_ml/execution/multirun_config.py +153 -0
  47. deriva_ml/execution/runner.py +595 -0
  48. deriva_ml/execution/workflow.py +224 -35
  49. deriva_ml/experiment/__init__.py +8 -0
  50. deriva_ml/experiment/experiment.py +411 -0
  51. deriva_ml/feature.py +6 -1
  52. deriva_ml/install_kernel.py +143 -6
  53. deriva_ml/interfaces.py +862 -0
  54. deriva_ml/model/__init__.py +99 -0
  55. deriva_ml/model/annotations.py +1278 -0
  56. deriva_ml/model/catalog.py +286 -60
  57. deriva_ml/model/database.py +144 -649
  58. deriva_ml/model/deriva_ml_database.py +308 -0
  59. deriva_ml/model/handles.py +14 -0
  60. deriva_ml/run_model.py +319 -0
  61. deriva_ml/run_notebook.py +507 -38
  62. deriva_ml/schema/__init__.py +18 -2
  63. deriva_ml/schema/annotations.py +62 -33
  64. deriva_ml/schema/create_schema.py +169 -69
  65. deriva_ml/schema/validation.py +601 -0
  66. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/METADATA +4 -5
  67. deriva_ml-1.17.11.dist-info/RECORD +77 -0
  68. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/WHEEL +1 -1
  69. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/entry_points.txt +2 -0
  70. deriva_ml/protocols/dataset.py +0 -19
  71. deriva_ml/test.py +0 -94
  72. deriva_ml-1.17.9.dist-info/RECORD +0 -45
  73. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/licenses/LICENSE +0 -0
  74. {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/top_level.txt +0 -0
deriva_ml/execution/find_caller.py
@@ -0,0 +1,298 @@
+ from __future__ import annotations
+
+ import inspect
+ import sys
+ from pathlib import Path
+ from types import FrameType
+ from typing import Optional
+
+ try:  # optional imports — used only when running in notebooks
+     from IPython.core.getipython import get_ipython  # type: ignore
+ except Exception:  # pragma: no cover - optional
+
+     def get_ipython():  # type: ignore
+         return None
+
+
+ try:  # optional — only available when inside a kernel
+     from ipykernel.connect import get_connection_file as _get_kernel_connection
+ except Exception:  # pragma: no cover - optional
+     _get_kernel_connection = None  # type: ignore
+
+ try:  # optional — only available when jupyter-server is installed
+     from jupyter_server.serverapp import list_running_servers as _list_running_servers  # type: ignore
+ except Exception:  # pragma: no cover - optional
+     _list_running_servers = None  # type: ignore
+
+ try:  # optional — HTTP call to Jupyter server API
+     import requests  # type: ignore
+     from requests import RequestException  # type: ignore
+ except Exception:  # pragma: no cover - optional
+     requests = None  # type: ignore
+     RequestException = Exception  # type: ignore
+
+
+ def _norm(p: str) -> str:
+     """Normalize a path string using pathlib.
+
+     - Expands ~
+     - Resolves to absolute path
+     - Returns a string path
+     Note: We no longer apply os.path.normcase explicitly; pathlib's resolve
+     provides a consistent absolute path. This should be sufficient for our
+     use-cases across platforms.
+     """
+     try:
+         return str(Path(p).expanduser().resolve())
+     except Exception:
+         # As a very last resort, return the original string
+         return p
+
+
+ # Treat certain pseudo filenames from IPython/Jupyter as user code so they
+ # can be selected as the calling location when appropriate (e.g., in REPL).
+ def _is_pseudo_user_filename(filename: str) -> bool:
+     """Return True if filename looks like an IPython/Jupyter pseudo file.
+
+     Examples that should return True:
+     - "<ipython-input-7-abcdef>"
+     - "<jupyter-input-3-123456>"
+     - "<ipykernel_12345>"
+
+     Other pseudo files like "<stdin>" or "<string>" should return False here
+     so they can be treated by the generic pseudo-file handling below.
+     """
+     if not (filename.startswith("<") and filename.endswith(">")):
+         return False
+     lower = filename.lower()
+     return lower.startswith("<ipython-input-") or lower.startswith("<jupyter-input-") or lower.startswith("<ipykernel_")
+
+
+ # Names that frequently represent "system/tooling" frames rather than user code
+ _SYSTEM_MODULE_PREFIXES = (
+     # pytest + plugin stack
+     "pytest",
+     "_pytest",
+     "pluggy",
+     # IPython/Jupyter stack
+     "IPython",
+     "traitlets",
+     "tornado",
+     "jupyter_client",
+     "jupyter_core",
+     "ipykernel",
+     # IDE/debugger stack (PyCharm)
+     "pydevd",
+     "_pydevd_bundle",
+     "_pydev_bundle",
+     # Python internals
+     "importlib",
+     "runpy",
+     "inspect",
+     "traceback",
+     "contextlib",
+     "asyncio",
+     "threading",
+     # DerivaML CLI runners - skip to find user's model code
+     "deriva_ml.run_model",
+     "deriva_ml.run_notebook",
+     # Hydra/hydra-zen internals
+     "hydra",
+     "hydra_zen",
+     "omegaconf",
+ )
+
+
+ # --- Helpers focused on determining the current "python model" (file) ---
+
+
+ def _top_user_frame() -> Optional[FrameType]:
+     """Return the outermost (top-level) non-tooling frame from the current stack.
+
+     This function traverses the call stack from the current execution point
+     back to the entry point, filtering out known tooling (pytest, IDE helpers,
+     Jupyter internals) and returns the highest-level frame that belongs to
+     user code.
+     """
+     tooling_prefixes = _SYSTEM_MODULE_PREFIXES
+     tooling_filename_parts = (
+         "pydevconsole.py",  # PyCharm REPL console
+         "/pydev/",  # PyCharm helpers path segment
+         "/_pydevd_bundle/",
+         "/_pydev_bundle/",
+         "_pytest",
+         "/pycharm/",
+         # DerivaML CLI entry points - skip to find user's model code
+         "/deriva_ml/run_model.py",
+         "/deriva_ml/run_notebook.py",
+         # Hydra/hydra-zen internals
+         "/hydra/",
+         "/hydra_zen/",
+         "/omegaconf/",
+     )
+
+     f = inspect.currentframe()
+     last_user_frame = None
+
+     if f is not None:
+         f = f.f_back  # Skip the _top_user_frame itself
+
+     while f is not None:
+         filename = f.f_code.co_filename or ""
+         mod_name = f.f_globals.get("__name__", "") or ""
+
+         # 1. Treat IPython cell as user code
+         if _is_pseudo_user_filename(filename):
+             last_user_frame = f
+             f = f.f_back
+             continue
+
+         # 2. Skip other pseudo files like <stdin>, <string>, etc., unless __main__
+         if filename.startswith("<") and filename.endswith(">") and mod_name not in ("__main__", "__mp_main__"):
+             f = f.f_back
+             continue
+
+         # 3. Skip known tooling frames by module prefix
+         if any(mod_name == p or mod_name.startswith(p + ".") for p in tooling_prefixes):
+             f = f.f_back
+             continue
+
+         # 4. Skip known tooling frames by filename patterns
+         if any(part in filename for part in tooling_filename_parts):
+             f = f.f_back
+             continue
+
+         # 5. Skip frames that belong to this helper module (find_caller.py)
+         try:
+             cur = str(Path(filename).resolve())
+             this = str(Path(__file__).resolve())
+             if cur == this:
+                 f = f.f_back
+                 continue
+         except Exception:
+             pass
+
+         # If it passed all filters, it is a user frame.
+         # We record it and keep going back to find an even "higher" one.
+         last_user_frame = f
+         f = f.f_back
+
+     return last_user_frame
+
+
+ def _get_notebook_path() -> Optional[str]:
+     """Best-effort to obtain the current Jupyter notebook path.
+
+     Returns absolute path string if discoverable, else None.
+     """
+     ip = get_ipython()
+     if ip is None:
+         return None
+
+     # Must be running inside a kernel with a connection file
+     if _get_kernel_connection is None:
+         return None
+     try:
+         connection_file = Path(_get_kernel_connection()).name  # type: ignore[operator]
+     except Exception:
+         return None
+
+     # Need jupyter-server and requests to query sessions
+     if _list_running_servers is None or requests is None:
+         return None
+
+     # Extract kernel ID from connection filename.
+     # Standard Jupyter format: "kernel-<kernel_id>.json"
+     # PyCharm/other formats may vary: "<kernel_id>.json" or other patterns
+     kernel_id = None
+     if connection_file.startswith("kernel-") and "-" in connection_file:
+         # Standard format: kernel-<uuid>.json
+         parts = connection_file.split("-", 1)
+         if len(parts) > 1:
+             kernel_id = parts[1].rsplit(".", 1)[0]
+     else:
+         # Fallback: assume filename (without extension) is the kernel ID
+         kernel_id = connection_file.rsplit(".", 1)[0]
+
+     if not kernel_id:
+         return None
+
+     try:
+         servers = list(_list_running_servers())  # type: ignore[func-returns-value]
+     except Exception:
+         return None
+
+     for server in servers:
+         try:
+             token = server.get("token", "")
+             headers = {"Authorization": f"token {token}"} if token else {}
+             url = server["url"] + "api/sessions"
+             resp = requests.get(url, headers=headers, timeout=3)  # type: ignore[attr-defined]
+             resp.raise_for_status()
+             for sess in resp.json():
+                 if sess.get("kernel", {}).get("id") == kernel_id:
+                     rel = sess.get("notebook", {}).get("path")
+                     if rel:
+                         root_dir = server.get("root_dir") or server.get("notebook_dir")
+                         if root_dir:
+                             return str(Path(root_dir) / rel)
+         except RequestException:
+             continue
+         except Exception:
+             continue
+     return None
+
+
+ def _get_calling_module() -> str:
+     """Return the relevant source filename for the current execution context.
+
+     Behavior:
+     1) In Jupyter Notebook/Hub: returns the .ipynb file path.
+     2) In a script: returns the script filename.
+     3) In pytest or any REPL (PyCharm or regular): returns the filename that
+        contains the function currently executing (nearest user frame).
+     4) If executing code from an installed package in a venv, still returns that
+        package module file (we do NOT exclude site-packages).
+     """
+     # 1) Jupyter notebook
+     nb = _get_notebook_path()
+     if nb:
+         return str(Path(nb))
+
+     # 2) If running as a script (python myscript.py), prefer __main__.__file__ or argv[0]
+     def _is_tooling_script_path(p: str) -> bool:
+         # Normalize path to forward slashes and lowercase for robust substring checks
+         pn = p.replace("\\", "/").casefold()
+         # Detect common IDE/console helper scripts and CLI runners
+         tooling_markers = (
+             "pydevconsole.py",
+             "/pydev/",
+             "/_pydevd_bundle/",
+             "/_pydev_bundle/",
+             # DerivaML CLI entry points - skip to find user's model code
+             "/deriva_ml/run_model.py",
+             "/deriva_ml/run_notebook.py",
+         )
+         return any(m in pn for m in tooling_markers)
+
+     f = _top_user_frame()
+     if f is not None:
+         return _norm(f.f_code.co_filename)
+     main_mod = sys.modules.get("__main__")
+     main_file = getattr(main_mod, "__file__", None)
+
+     if isinstance(main_file, str) and main_file:
+         if not _is_tooling_script_path(main_file):
+             return _norm(main_file)
+     if sys.argv and sys.argv[0] and sys.argv[0] != "-c":
+         if not _is_tooling_script_path(sys.argv[0]):
+             return _norm(sys.argv[0])
+
+     # 3) Pytest/REPL/IDE: use nearest user frame
+     f = _top_user_frame()
+
+     if f is not None:
+         return _norm(f.f_code.co_filename)
+
+     # Fallback: <stdin> or current working directory marker
+     return str(Path.cwd() / "REPL")
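
The helpers above are internal (underscore-prefixed), but the behavior documented in `_get_calling_module` is easy to exercise. A minimal sketch, assuming deriva-ml 1.17.11 is installed; the script name and the use of the private helper are illustrative only, not a supported public API:

    # probe_caller.py -- hypothetical probe script.
    from deriva_ml.execution.find_caller import _get_calling_module

    if __name__ == "__main__":
        # Run as "python probe_caller.py": prints this script's absolute path.
        # Inside a Jupyter kernel it would instead resolve to the notebook's
        # .ipynb path; in a bare REPL it falls back to "<cwd>/REPL".
        print(_get_calling_module())
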
deriva_ml/execution/model_protocol.py
@@ -0,0 +1,175 @@
+ """
+ DerivaML Model Protocol
+ =======================
+
+ This module defines the protocol (interface) that model functions must follow
+ to work with DerivaML's execution framework and the run_model() function.
+
+ The DerivaMLModel protocol specifies that models must accept two special
+ keyword arguments that are injected at runtime:
+
+ - ml_instance: The DerivaML (or subclass) instance for catalog operations
+ - execution: The Execution context for managing inputs, outputs, and provenance
+
+ All other parameters are model-specific and configured via Hydra.
+
+ Example
+ -------
+ A compliant model function:
+
+     def train_classifier(
+         # Model-specific parameters (configured via Hydra)
+         epochs: int = 10,
+         learning_rate: float = 0.001,
+         batch_size: int = 32,
+         # Runtime parameters (injected by run_model)
+         ml_instance: DerivaML = None,
+         execution: Execution = None,
+     ) -> None:
+         '''Train a classifier within the DerivaML execution context.'''
+
+         # Download input datasets
+         for dataset_spec in execution.datasets:
+             bag = execution.download_dataset_bag(dataset_spec)
+             images = load_images_from_bag(bag)
+
+         # Train the model
+         model = MyClassifier()
+         for epoch in range(epochs):
+             train_one_epoch(model, images, learning_rate, batch_size)
+
+         # Save outputs (will be uploaded to catalog)
+         model_path = execution.asset_file_path("Model", "model.pt")
+         torch.save(model.state_dict(), model_path)
+
+         metrics_path = execution.asset_file_path("Execution_Metadata", "metrics.json")
+         with open(metrics_path, "w") as f:
+             json.dump({"final_accuracy": 0.95}, f)
+
+ Registering with Hydra-Zen
+ --------------------------
+ Wrap your model with builds() and zen_partial=True:
+
+     from hydra_zen import builds, store
+
+     TrainClassifierConfig = builds(
+         train_classifier,
+         epochs=10,
+         learning_rate=0.001,
+         batch_size=32,
+         zen_partial=True,  # Creates a partial function
+     )
+
+     # Register in the model_config group
+     store(TrainClassifierConfig, group="model_config", name="default_model")
+
+     # Create variants with different defaults
+     store(TrainClassifierConfig, epochs=50, group="model_config", name="extended")
+     store(TrainClassifierConfig, epochs=5, group="model_config", name="quick")
+
+ Type Checking
+ -------------
+ Use the DerivaMLModel protocol for type hints in utilities:
+
+     from deriva_ml.execution.model_protocol import DerivaMLModel
+
+     def validate_model(model: DerivaMLModel) -> bool:
+         '''Check if a callable conforms to the model protocol.'''
+         return isinstance(model, DerivaMLModel)
+
+ The protocol uses @runtime_checkable, so isinstance() checks work at runtime.
+ """
+
+ from __future__ import annotations
+
+ from typing import Protocol, Any, runtime_checkable, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from deriva_ml import DerivaML
+     from deriva_ml.execution.execution import Execution
+
+
+ @runtime_checkable
+ class DerivaMLModel(Protocol):
+     """Protocol for model functions compatible with DerivaML's run_model().
+
+     A model function must accept keyword arguments `ml_instance` and `execution`
+     that are injected at runtime by run_model(). All other parameters are
+     configured via Hydra and passed through the model_config.
+
+     The model function is responsible for:
+     1. Downloading input datasets via execution.download_dataset_bag()
+     2. Performing the ML computation (training, inference, etc.)
+     3. Registering output files via execution.asset_file_path()
+
+     Output files registered with asset_file_path() are automatically uploaded
+     to the catalog after the model completes.
+
+     Attributes
+     ----------
+     This protocol defines a callable signature, not attributes.
+
+     Examples
+     --------
+     Basic model function:
+
+         def my_model(
+             epochs: int = 10,
+             ml_instance: DerivaML = None,
+             execution: Execution = None,
+         ) -> None:
+             # Training logic here
+             pass
+
+     With domain-specific DerivaML subclass:
+
+         def eyeai_model(
+             threshold: float = 0.5,
+             ml_instance: EyeAI = None,  # EyeAI is a DerivaML subclass
+             execution: Execution = None,
+         ) -> None:
+             # Can use EyeAI-specific methods
+             ml_instance.some_eyeai_method()
+
+     Checking protocol compliance:
+
+         >>> from deriva_ml.execution.model_protocol import DerivaMLModel
+         >>> isinstance(my_model, DerivaMLModel)
+         True
+     """
+
+     def __call__(
+         self,
+         *args: Any,
+         ml_instance: "DerivaML",
+         execution: "Execution",
+         **kwargs: Any,
+     ) -> None:
+         """Execute the model within a DerivaML execution context.
+
+         Parameters
+         ----------
+         *args : Any
+             Positional arguments (typically not used; prefer keyword args).
+         ml_instance : DerivaML
+             The DerivaML instance (or subclass like EyeAI) connected to the
+             catalog. Use this for catalog operations not available through
+             the execution context.
+         execution : Execution
+             The execution context manager. Provides:
+             - execution.datasets: List of input DatasetSpec objects
+             - execution.download_dataset_bag(): Download dataset as BDBag
+             - execution.asset_file_path(): Register output file for upload
+             - execution.working_dir: Path to local working directory
+         **kwargs : Any
+             Model-specific parameters configured via Hydra.
+
+         Returns
+         -------
+         None
+             Models should not return values. Results are captured through:
+             - Files registered with asset_file_path() (uploaded to catalog)
+             - Datasets created with execution.create_dataset()
+             - Status updates via execution.update_status()
+         """
+         ...
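
Because DerivaMLModel is a non-data, @runtime_checkable protocol, isinstance() only verifies that the required member (__call__) is present, which is what makes the docstring's compliance check work for plain functions. A minimal, self-contained sketch, assuming deriva-ml 1.17.11 is installed; my_model and its trivial body are hypothetical:

    from __future__ import annotations

    from typing import Any

    from deriva_ml.execution.model_protocol import DerivaMLModel


    def my_model(
        epochs: int = 10,
        ml_instance: Any = None,  # injected by run_model() at runtime
        execution: Any = None,    # injected by run_model() at runtime
    ) -> None:
        """Placeholder body; a real model would train and register outputs."""


    # Non-data runtime_checkable protocols check member presence only,
    # so any callable satisfies this structural check.
    assert isinstance(my_model, DerivaMLModel)
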
deriva_ml/execution/multirun_config.py
@@ -0,0 +1,153 @@
+ """Multirun configuration for DerivaML experiments.
+
+ This module provides a way to define named multirun configurations that bundle
+ together Hydra overrides and a description. This allows you to document complex
+ experiment sweeps in code rather than on the command line.
+
+ Usage:
+     # In configs/multiruns.py
+     from deriva_ml.execution import multirun_config
+
+     multirun_config(
+         "quick_vs_extended",
+         overrides=[
+             "+experiment=cifar10_quick,cifar10_extended",
+         ],
+         description="## Quick vs Extended Comparison\\n\\nComparing training configs...",
+     )
+
+     multirun_config(
+         "lr_sweep",
+         overrides=[
+             "+experiment=cifar10_lr_sweep",
+             "model_config.learning_rate=0.0001,0.001,0.01,0.1",
+         ],
+         description="## Learning Rate Sweep\\n\\nExploring optimal learning rates...",
+     )
+
+ Then run with:
+     deriva-ml-run +multirun=quick_vs_extended
+     deriva-ml-run +multirun=lr_sweep model_config.epochs=5  # Can still override
+
+ Benefits:
+     - Explicit declaration of multirun experiments
+     - Rich markdown descriptions for parent executions
+     - Reproducible sweeps documented in code
+     - Same Hydra override syntax as command line
+ """
+
+ from dataclasses import dataclass, field
+ from typing import Any
+
+
+ @dataclass
+ class MultirunSpec:
+     """Specification for a multirun experiment.
+
+     Attributes:
+         name: Unique identifier for this multirun configuration.
+         overrides: List of Hydra override strings (same syntax as command line).
+             Examples:
+             - "+experiment=cifar10_quick,cifar10_extended"
+             - "model_config.learning_rate=0.0001,0.001,0.01"
+             - "model_config.epochs=5,10,25,50"
+         description: Rich description for the parent execution. Supports full
+             markdown formatting (headers, tables, bold, etc.).
+     """
+     name: str
+     overrides: list[str] = field(default_factory=list)
+     description: str = ""
+
+
+ # Global registry of multirun configurations
+ _multirun_registry: dict[str, MultirunSpec] = {}
+
+
+ def multirun_config(
+     name: str,
+     overrides: list[str],
+     description: str = "",
+ ) -> MultirunSpec:
+     """Register a named multirun configuration.
+
+     This function registers a multirun specification that can be invoked with
+     `deriva-ml-run +multirun=<name>`. The overrides use the same syntax as
+     Hydra command-line overrides.
+
+     Args:
+         name: Unique name for this multirun configuration. Used to invoke it
+             via `+multirun=<name>`.
+         overrides: List of Hydra override strings. These are the same overrides
+             you would pass on the command line after `--multirun`. Examples:
+             - "+experiment=cifar10_quick,cifar10_extended" - run multiple experiments
+             - "model_config.learning_rate=0.0001,0.001,0.01" - sweep a parameter
+             - "datasets=small,medium,large" - sweep datasets
+         description: Rich description for the parent execution. This supports
+             full markdown formatting since it's defined in Python, not on the
+             command line. Use this to document:
+             - What experiments are being compared and why
+             - Expected outcomes
+             - Methodology and metrics to analyze
+
+     Returns:
+         The registered MultirunSpec instance.
+
+     Example:
+         >>> from deriva_ml.execution import multirun_config
+         >>>
+         >>> multirun_config(
+         ...     "lr_sweep",
+         ...     overrides=[
+         ...         "+experiment=cifar10_lr_sweep",
+         ...         "model_config.learning_rate=0.0001,0.001,0.01,0.1",
+         ...     ],
+         ...     description='''## Learning Rate Sweep
+         ...
+         ...     **Objective:** Find optimal learning rate for CIFAR-10 CNN.
+         ...
+         ...     | Learning Rate | Expected Behavior |
+         ...     |--------------|-------------------|
+         ...     | 0.0001 | Slow convergence |
+         ...     | 0.001 | Standard baseline |
+         ...     | 0.01 | Fast, may overshoot |
+         ...     | 0.1 | Likely unstable |
+         ...     ''',
+         ... )
+     """
+     spec = MultirunSpec(
+         name=name,
+         overrides=overrides,
+         description=description,
+     )
+     _multirun_registry[name] = spec
+     return spec
+
+
+ def get_multirun_config(name: str) -> MultirunSpec | None:
+     """Look up a registered multirun configuration by name.
+
+     Args:
+         name: The name of the multirun configuration.
+
+     Returns:
+         The MultirunSpec if found, None otherwise.
+     """
+     return _multirun_registry.get(name)
+
+
+ def list_multirun_configs() -> list[str]:
+     """List all registered multirun configuration names.
+
+     Returns:
+         List of registered multirun config names.
+     """
+     return list(_multirun_registry.keys())
+
+
+ def get_all_multirun_configs() -> dict[str, MultirunSpec]:
+     """Get all registered multirun configurations.
+
+     Returns:
+         Dictionary mapping names to MultirunSpec instances.
+     """
+     return dict(_multirun_registry)
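
To round out the registry API above, a short sketch of registering a sweep and reading it back, assuming deriva-ml 1.17.11 is installed; the name "epoch_sweep" and its overrides are hypothetical, and the override string simply reuses the Hydra syntax documented in the module docstring:

    from deriva_ml.execution.multirun_config import (
        get_multirun_config,
        list_multirun_configs,
        multirun_config,
    )

    # Register a hypothetical sweep; it could later be invoked with
    # `deriva-ml-run +multirun=epoch_sweep`.
    multirun_config(
        "epoch_sweep",
        overrides=["model_config.epochs=5,10,25"],
        description="## Epoch Sweep\n\nHow many epochs are enough for the baseline?",
    )

    # Look the registration back up from the module-level registry.
    spec = get_multirun_config("epoch_sweep")
    assert spec is not None and spec.overrides == ["model_config.epochs=5,10,25"]
    print(list_multirun_configs())  # ['epoch_sweep']
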