deriva-ml 1.17.9__py3-none-any.whl → 1.17.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deriva_ml/__init__.py +43 -1
- deriva_ml/asset/__init__.py +17 -0
- deriva_ml/asset/asset.py +357 -0
- deriva_ml/asset/aux_classes.py +100 -0
- deriva_ml/bump_version.py +254 -11
- deriva_ml/catalog/__init__.py +21 -0
- deriva_ml/catalog/clone.py +1199 -0
- deriva_ml/catalog/localize.py +426 -0
- deriva_ml/core/__init__.py +29 -0
- deriva_ml/core/base.py +817 -1067
- deriva_ml/core/config.py +169 -21
- deriva_ml/core/constants.py +120 -19
- deriva_ml/core/definitions.py +123 -13
- deriva_ml/core/enums.py +47 -73
- deriva_ml/core/ermrest.py +226 -193
- deriva_ml/core/exceptions.py +297 -14
- deriva_ml/core/filespec.py +99 -28
- deriva_ml/core/logging_config.py +225 -0
- deriva_ml/core/mixins/__init__.py +42 -0
- deriva_ml/core/mixins/annotation.py +915 -0
- deriva_ml/core/mixins/asset.py +384 -0
- deriva_ml/core/mixins/dataset.py +237 -0
- deriva_ml/core/mixins/execution.py +408 -0
- deriva_ml/core/mixins/feature.py +365 -0
- deriva_ml/core/mixins/file.py +263 -0
- deriva_ml/core/mixins/path_builder.py +145 -0
- deriva_ml/core/mixins/rid_resolution.py +204 -0
- deriva_ml/core/mixins/vocabulary.py +400 -0
- deriva_ml/core/mixins/workflow.py +322 -0
- deriva_ml/core/validation.py +389 -0
- deriva_ml/dataset/__init__.py +2 -1
- deriva_ml/dataset/aux_classes.py +20 -4
- deriva_ml/dataset/catalog_graph.py +575 -0
- deriva_ml/dataset/dataset.py +1242 -1008
- deriva_ml/dataset/dataset_bag.py +1311 -182
- deriva_ml/dataset/history.py +27 -14
- deriva_ml/dataset/upload.py +225 -38
- deriva_ml/demo_catalog.py +186 -105
- deriva_ml/execution/__init__.py +46 -2
- deriva_ml/execution/base_config.py +639 -0
- deriva_ml/execution/execution.py +545 -244
- deriva_ml/execution/execution_configuration.py +26 -11
- deriva_ml/execution/execution_record.py +592 -0
- deriva_ml/execution/find_caller.py +298 -0
- deriva_ml/execution/model_protocol.py +175 -0
- deriva_ml/execution/multirun_config.py +153 -0
- deriva_ml/execution/runner.py +595 -0
- deriva_ml/execution/workflow.py +224 -35
- deriva_ml/experiment/__init__.py +8 -0
- deriva_ml/experiment/experiment.py +411 -0
- deriva_ml/feature.py +6 -1
- deriva_ml/install_kernel.py +143 -6
- deriva_ml/interfaces.py +862 -0
- deriva_ml/model/__init__.py +99 -0
- deriva_ml/model/annotations.py +1278 -0
- deriva_ml/model/catalog.py +286 -60
- deriva_ml/model/database.py +144 -649
- deriva_ml/model/deriva_ml_database.py +308 -0
- deriva_ml/model/handles.py +14 -0
- deriva_ml/run_model.py +319 -0
- deriva_ml/run_notebook.py +507 -38
- deriva_ml/schema/__init__.py +18 -2
- deriva_ml/schema/annotations.py +62 -33
- deriva_ml/schema/create_schema.py +169 -69
- deriva_ml/schema/validation.py +601 -0
- {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/METADATA +4 -5
- deriva_ml-1.17.11.dist-info/RECORD +77 -0
- {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/WHEEL +1 -1
- {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/entry_points.txt +2 -0
- deriva_ml/protocols/dataset.py +0 -19
- deriva_ml/test.py +0 -94
- deriva_ml-1.17.9.dist-info/RECORD +0 -45
- {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/licenses/LICENSE +0 -0
- {deriva_ml-1.17.9.dist-info → deriva_ml-1.17.11.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from types import FrameType
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
try: # optional imports — used only when running in notebooks
|
|
10
|
+
from IPython.core.getipython import get_ipython # type: ignore
|
|
11
|
+
except Exception: # pragma: no cover - optional
|
|
12
|
+
|
|
13
|
+
def get_ipython(): # type: ignore
|
|
14
|
+
return None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
try: # optional — only available when inside a kernel
|
|
18
|
+
from ipykernel.connect import get_connection_file as _get_kernel_connection
|
|
19
|
+
except Exception: # pragma: no cover - optional
|
|
20
|
+
_get_kernel_connection = None # type: ignore
|
|
21
|
+
|
|
22
|
+
try: # optional — only available when jupyter-server is installed
|
|
23
|
+
from jupyter_server.serverapp import list_running_servers as _list_running_servers # type: ignore
|
|
24
|
+
except Exception: # pragma: no cover - optional
|
|
25
|
+
_list_running_servers = None # type: ignore
|
|
26
|
+
|
|
27
|
+
try: # optional — HTTP call to Jupyter server API
|
|
28
|
+
import requests # type: ignore
|
|
29
|
+
from requests import RequestException # type: ignore
|
|
30
|
+
except Exception: # pragma: no cover - optional
|
|
31
|
+
requests = None # type: ignore
|
|
32
|
+
RequestException = Exception # type: ignore
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _norm(p: str) -> str:
|
|
36
|
+
"""Normalize a path string using pathlib.
|
|
37
|
+
|
|
38
|
+
- Expands ~
|
|
39
|
+
- Resolves to absolute path
|
|
40
|
+
- Returns a string path
|
|
41
|
+
Note: We no longer apply os.path.normcase explicitly; pathlib's resolve
|
|
42
|
+
provides a consistent absolute path. This should be sufficient for our
|
|
43
|
+
use-cases across platforms.
|
|
44
|
+
"""
|
|
45
|
+
try:
|
|
46
|
+
return str(Path(p).expanduser().resolve())
|
|
47
|
+
except Exception:
|
|
48
|
+
# As a very last resort, return the original string
|
|
49
|
+
return p
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Treat certain pseudo filenames from IPython/Jupyter as user code so they
|
|
53
|
+
# can be selected as the calling location when appropriate (e.g., in REPL).
|
|
54
|
+
def _is_pseudo_user_filename(filename: str) -> bool:
|
|
55
|
+
"""Return True if filename looks like an IPython/Jupyter pseudo file.
|
|
56
|
+
|
|
57
|
+
Examples that should return True:
|
|
58
|
+
- "<ipython-input-7-abcdef>"
|
|
59
|
+
- "<jupyter-input-3-123456>"
|
|
60
|
+
- "<ipykernel_12345>"
|
|
61
|
+
|
|
62
|
+
Other pseudo files like "<stdin>" or "<string>" should return False here
|
|
63
|
+
so they can be treated by the generic pseudo-file handling below.
|
|
64
|
+
"""
|
|
65
|
+
if not (filename.startswith("<") and filename.endswith(">")):
|
|
66
|
+
return False
|
|
67
|
+
lower = filename.lower()
|
|
68
|
+
return lower.startswith("<ipython-input-") or lower.startswith("<jupyter-input-") or lower.startswith("<ipykernel_")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Names that frequently represent "system/tooling" frames rather than user code.
# Frames whose module name equals one of these prefixes (or starts with
# "<prefix>.") are skipped by _top_user_frame when looking for user code.
_SYSTEM_MODULE_PREFIXES = (
    # pytest + plugin stack
    "pytest",
    "_pytest",
    "pluggy",
    # IPython/Jupyter stack
    "IPython",
    "traitlets",
    "tornado",
    "jupyter_client",
    "jupyter_core",
    "ipykernel",
    # IDE/debugger stack (PyCharm)
    "pydevd",
    "_pydevd_bundle",
    "_pydev_bundle",
    # Python internals
    "importlib",
    "runpy",
    "inspect",
    "traceback",
    "contextlib",
    "asyncio",
    "threading",
    # DerivaML CLI runners - skip to find user's model code
    "deriva_ml.run_model",
    "deriva_ml.run_notebook",
    # Hydra/hydra-zen internals
    "hydra",
    "hydra_zen",
    "omegaconf",
)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# --- Helpers focused on determining the current "python model" (file) ---
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _top_user_frame() -> Optional[FrameType]:
    """Return the outermost (top-level) non-tooling frame from the current stack.

    This function traverses the call stack from the current execution point
    back to the entry point, filtering out known tooling (pytest, IDE helpers,
    Jupyter internals) and returns the highest-level frame that belongs to
    user code.

    Returns:
        The outermost frame that passed all tooling filters, or None if every
        frame on the stack was filtered out.
    """
    tooling_prefixes = _SYSTEM_MODULE_PREFIXES
    # Substring markers that identify tooling frames by file path rather
    # than module name (IDE helpers, CLI runners, Hydra internals).
    tooling_filename_parts = (
        "pydevconsole.py",  # PyCharm REPL console
        "/pydev/",  # PyCharm helpers path segment
        "/_pydevd_bundle/",
        "/_pydev_bundle/",
        "_pytest",
        "/pycharm/",
        # DerivaML CLI entry points - skip to find user's model code
        "/deriva_ml/run_model.py",
        "/deriva_ml/run_notebook.py",
        # Hydra/hydra-zen internals
        "/hydra/",
        "/hydra_zen/",
        "/omegaconf/",
    )

    f = inspect.currentframe()
    # Tracks the highest (closest to the entry point) user frame seen so far.
    last_user_frame: Optional[FrameType] = None

    if f is not None:
        f = f.f_back  # Skip the _top_user_frame itself

    while f is not None:
        filename = f.f_code.co_filename or ""
        mod_name = f.f_globals.get("__name__", "") or ""

        # 1. Treat IPython cell as user code
        if _is_pseudo_user_filename(filename):
            last_user_frame = f
            f = f.f_back
            continue

        # 2. Skip other pseudo files like <stdin>, <string>, etc., unless __main__
        if filename.startswith("<") and filename.endswith(">") and mod_name not in ("__main__", "__mp_main__"):
            f = f.f_back
            continue

        # 3. Skip known tooling frames by module prefix
        # (exact match or dotted submodule, e.g. "pytest" or "pytest.mark").
        if any(mod_name == p or mod_name.startswith(p + ".") for p in tooling_prefixes):
            f = f.f_back
            continue

        # 4. Skip known tooling frames by filename patterns
        if any(part in filename for part in tooling_filename_parts):
            f = f.f_back
            continue

        # 5. Skip frames that belong to this helper module (find_caller.py)
        try:
            cur = str(Path(filename).resolve())
            this = str(Path(__file__).resolve())
            if cur == this:
                f = f.f_back
                continue
        except Exception:
            # Path resolution can fail for odd filenames; treat the frame
            # as user code rather than dropping it.
            pass

        # If it passed all filters, it is a user frame.
        # We record it and keep going back to find an even "higher" one.
        last_user_frame = f
        f = f.f_back

    return last_user_frame
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _get_notebook_path() -> Optional[str]:
    """Best‑effort to obtain the current Jupyter notebook path.

    Queries every running Jupyter server's ``/api/sessions`` endpoint and
    matches the current kernel ID against the listed sessions. Requires the
    optional ``ipykernel``, ``jupyter_server`` and ``requests`` imports to be
    available (see the module-level try/except fallbacks).

    Returns absolute path string if discoverable, else None.
    """
    ip = get_ipython()
    if ip is None:
        # Not running under IPython at all — cannot be in a notebook.
        return None

    # Must be running inside a kernel with a connection file
    if _get_kernel_connection is None:
        return None
    try:
        connection_file = Path(_get_kernel_connection()).name  # type: ignore[operator]
    except Exception:
        # No connection file => not inside a kernel (e.g. plain IPython shell).
        return None

    # Need jupyter-server and requests to query sessions
    if _list_running_servers is None or requests is None:
        return None

    # Extract kernel ID from connection filename.
    # Standard Jupyter format: "kernel-<kernel_id>.json"
    # PyCharm/other formats may vary: "<kernel_id>.json" or other patterns
    kernel_id = None
    if connection_file.startswith("kernel-") and "-" in connection_file:
        # Standard format: kernel-<uuid>.json
        parts = connection_file.split("-", 1)
        if len(parts) > 1:
            kernel_id = parts[1].rsplit(".", 1)[0]
    else:
        # Fallback: assume filename (without extension) is the kernel ID
        kernel_id = connection_file.rsplit(".", 1)[0]

    if not kernel_id:
        return None

    try:
        servers = list(_list_running_servers())  # type: ignore[func-returns-value]
    except Exception:
        return None

    for server in servers:
        try:
            token = server.get("token", "")
            headers = {"Authorization": f"token {token}"} if token else {}
            url = server["url"] + "api/sessions"
            resp = requests.get(url, headers=headers, timeout=3)  # type: ignore[attr-defined]
            resp.raise_for_status()
            for sess in resp.json():
                if sess.get("kernel", {}).get("id") == kernel_id:
                    # NOTE(review): assumes the session payload nests the path
                    # under a "notebook" key; some jupyter-server versions may
                    # expose a top-level "path" instead — confirm against the
                    # supported server versions.
                    rel = sess.get("notebook", {}).get("path")
                    if rel:
                        # Session paths are relative to the server root.
                        root_dir = server.get("root_dir") or server.get("notebook_dir")
                        if root_dir:
                            return str(Path(root_dir) / rel)
        except RequestException:
            # Server unreachable or auth failure — try the next server.
            continue
        except Exception:
            continue
    return None
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _get_calling_module() -> str:
    """Return the relevant source filename for the current execution context.

    Behavior:
    1) In Jupyter Notebook/Hub: returns the .ipynb file path.
    2) In pytest, any REPL, an IDE, or a plain script: returns the filename
       of the outermost user frame on the stack (scripts are covered here
       because the script's module-level frame is a user frame).
    3) If no user frame is found, falls back to __main__.__file__ or
       sys.argv[0] (skipping known IDE/CLI helper scripts).
    4) If executing code from an installed package in a venv, still returns
       that package module file (we do NOT exclude site-packages).

    Returns:
        A normalized absolute path string, or "<cwd>/REPL" as a final marker
        when no source file can be determined.
    """
    # 1) Jupyter notebook
    nb = _get_notebook_path()
    if nb:
        return str(Path(nb))

    def _is_tooling_script_path(p: str) -> bool:
        """True if *p* is a known IDE/console helper script or CLI runner."""
        # Normalize path to forward slashes and lowercase for robust substring checks
        pn = p.replace("\\", "/").casefold()
        tooling_markers = (
            "pydevconsole.py",
            "/pydev/",
            "/_pydevd_bundle/",
            "/_pydev_bundle/",
            # DerivaML CLI entry points - skip to find user's model code
            "/deriva_ml/run_model.py",
            "/deriva_ml/run_notebook.py",
        )
        return any(m in pn for m in tooling_markers)

    # 2) Outermost user frame (covers scripts, pytest, REPLs, IDEs).
    # BUG FIX: the original called _top_user_frame() a second time after the
    # __main__/argv fallbacks below; since the stack is unchanged within this
    # call, that second lookup could only ever repeat the first result (None)
    # and was dead code. One call suffices.
    f = _top_user_frame()
    if f is not None:
        return _norm(f.f_code.co_filename)

    # 3) No user frame found: fall back to __main__.__file__ or argv[0],
    # skipping helper scripts injected by IDEs/CLIs.
    main_mod = sys.modules.get("__main__")
    main_file = getattr(main_mod, "__file__", None)

    if isinstance(main_file, str) and main_file:
        if not _is_tooling_script_path(main_file):
            return _norm(main_file)
    if sys.argv and sys.argv[0] and sys.argv[0] != "-c":
        if not _is_tooling_script_path(sys.argv[0]):
            return _norm(sys.argv[0])

    # Fallback: <stdin> or current working directory marker
    return str(Path.cwd() / "REPL")
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DerivaML Model Protocol
|
|
3
|
+
=======================
|
|
4
|
+
|
|
5
|
+
This module defines the protocol (interface) that model functions must follow
|
|
6
|
+
to work with DerivaML's execution framework and the run_model() function.
|
|
7
|
+
|
|
8
|
+
The DerivaMLModel protocol specifies that models must accept two special
|
|
9
|
+
keyword arguments that are injected at runtime:
|
|
10
|
+
|
|
11
|
+
- ml_instance: The DerivaML (or subclass) instance for catalog operations
|
|
12
|
+
- execution: The Execution context for managing inputs, outputs, and provenance
|
|
13
|
+
|
|
14
|
+
All other parameters are model-specific and configured via Hydra.
|
|
15
|
+
|
|
16
|
+
Example
|
|
17
|
+
-------
|
|
18
|
+
A compliant model function:
|
|
19
|
+
|
|
20
|
+
def train_classifier(
|
|
21
|
+
# Model-specific parameters (configured via Hydra)
|
|
22
|
+
epochs: int = 10,
|
|
23
|
+
learning_rate: float = 0.001,
|
|
24
|
+
batch_size: int = 32,
|
|
25
|
+
# Runtime parameters (injected by run_model)
|
|
26
|
+
ml_instance: DerivaML = None,
|
|
27
|
+
execution: Execution = None,
|
|
28
|
+
) -> None:
|
|
29
|
+
'''Train a classifier within the DerivaML execution context.'''
|
|
30
|
+
|
|
31
|
+
# Download input datasets
|
|
32
|
+
for dataset_spec in execution.datasets:
|
|
33
|
+
bag = execution.download_dataset_bag(dataset_spec)
|
|
34
|
+
images = load_images_from_bag(bag)
|
|
35
|
+
|
|
36
|
+
# Train the model
|
|
37
|
+
model = MyClassifier()
|
|
38
|
+
for epoch in range(epochs):
|
|
39
|
+
train_one_epoch(model, images, learning_rate, batch_size)
|
|
40
|
+
|
|
41
|
+
# Save outputs (will be uploaded to catalog)
|
|
42
|
+
model_path = execution.asset_file_path("Model", "model.pt")
|
|
43
|
+
torch.save(model.state_dict(), model_path)
|
|
44
|
+
|
|
45
|
+
metrics_path = execution.asset_file_path("Execution_Metadata", "metrics.json")
|
|
46
|
+
with open(metrics_path, "w") as f:
|
|
47
|
+
json.dump({"final_accuracy": 0.95}, f)
|
|
48
|
+
|
|
49
|
+
Registering with Hydra-Zen
|
|
50
|
+
--------------------------
|
|
51
|
+
Wrap your model with builds() and zen_partial=True:
|
|
52
|
+
|
|
53
|
+
from hydra_zen import builds, store
|
|
54
|
+
|
|
55
|
+
TrainClassifierConfig = builds(
|
|
56
|
+
train_classifier,
|
|
57
|
+
epochs=10,
|
|
58
|
+
learning_rate=0.001,
|
|
59
|
+
batch_size=32,
|
|
60
|
+
zen_partial=True, # Creates a partial function
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
# Register in the model_config group
|
|
64
|
+
store(TrainClassifierConfig, group="model_config", name="default_model")
|
|
65
|
+
|
|
66
|
+
# Create variants with different defaults
|
|
67
|
+
store(TrainClassifierConfig, epochs=50, group="model_config", name="extended")
|
|
68
|
+
store(TrainClassifierConfig, epochs=5, group="model_config", name="quick")
|
|
69
|
+
|
|
70
|
+
Type Checking
|
|
71
|
+
-------------
|
|
72
|
+
Use the DerivaMLModel protocol for type hints in utilities:
|
|
73
|
+
|
|
74
|
+
from deriva_ml.execution.model_protocol import DerivaMLModel
|
|
75
|
+
|
|
76
|
+
def validate_model(model: DerivaMLModel) -> bool:
|
|
77
|
+
'''Check if a callable conforms to the model protocol.'''
|
|
78
|
+
return isinstance(model, DerivaMLModel)
|
|
79
|
+
|
|
80
|
+
The protocol uses @runtime_checkable, so isinstance() checks work at runtime.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
from __future__ import annotations
|
|
84
|
+
|
|
85
|
+
from typing import Protocol, Any, runtime_checkable, TYPE_CHECKING
|
|
86
|
+
|
|
87
|
+
if TYPE_CHECKING:
|
|
88
|
+
from deriva_ml import DerivaML
|
|
89
|
+
from deriva_ml.execution.execution import Execution
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@runtime_checkable
class DerivaMLModel(Protocol):
    """Structural type for model functions usable with DerivaML's run_model().

    A conforming callable accepts two keyword arguments that run_model()
    injects at runtime — ``ml_instance`` (the DerivaML connection, or a
    domain-specific subclass such as EyeAI) and ``execution`` (the Execution
    context) — plus any model-specific parameters configured via Hydra.

    A conforming model is responsible for:
      1. Downloading input datasets via ``execution.download_dataset_bag()``.
      2. Performing the ML computation (training, inference, etc.).
      3. Registering output files via ``execution.asset_file_path()``; those
         files are uploaded to the catalog after the model completes.

    The protocol is decorated with ``@runtime_checkable`` so ``isinstance()``
    checks work at runtime.

    NOTE(review): a runtime ``isinstance`` check against a Protocol verifies
    only that the object defines ``__call__`` (i.e. is callable), not that
    its signature actually matches — confirm signatures via static typing.

    Examples
    --------
    A compliant model function:

        def my_model(
            epochs: int = 10,
            ml_instance: DerivaML = None,
            execution: Execution = None,
        ) -> None:
            ...

    Checking protocol compliance:

        >>> from deriva_ml.execution.model_protocol import DerivaMLModel
        >>> isinstance(my_model, DerivaMLModel)
        True
    """

    def __call__(
        self,
        *args: Any,
        ml_instance: "DerivaML",
        execution: "Execution",
        **kwargs: Any,
    ) -> None:
        """Execute the model within a DerivaML execution context.

        Parameters
        ----------
        *args : Any
            Positional arguments (typically not used; prefer keyword args).
        ml_instance : DerivaML
            The DerivaML instance (or subclass) connected to the catalog;
            use it for catalog operations not exposed by the execution.
        execution : Execution
            The execution context manager providing input datasets,
            dataset download, output-asset registration, and the local
            working directory.
        **kwargs : Any
            Model-specific parameters configured via Hydra.

        Returns
        -------
        None
            Results are captured through registered asset files, created
            datasets, and execution status updates — not return values.
        """
        ...
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""Multirun configuration for DerivaML experiments.
|
|
2
|
+
|
|
3
|
+
This module provides a way to define named multirun configurations that bundle
|
|
4
|
+
together Hydra overrides and a description. This allows you to document complex
|
|
5
|
+
experiment sweeps in code rather than on the command line.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
# In configs/multiruns.py
|
|
9
|
+
from deriva_ml.execution import multirun_config
|
|
10
|
+
|
|
11
|
+
multirun_config(
|
|
12
|
+
"quick_vs_extended",
|
|
13
|
+
overrides=[
|
|
14
|
+
"+experiment=cifar10_quick,cifar10_extended",
|
|
15
|
+
],
|
|
16
|
+
description="## Quick vs Extended Comparison\\n\\nComparing training configs...",
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
multirun_config(
|
|
20
|
+
"lr_sweep",
|
|
21
|
+
overrides=[
|
|
22
|
+
"+experiment=cifar10_lr_sweep",
|
|
23
|
+
"model_config.learning_rate=0.0001,0.001,0.01,0.1",
|
|
24
|
+
],
|
|
25
|
+
description="## Learning Rate Sweep\\n\\nExploring optimal learning rates...",
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
Then run with:
|
|
29
|
+
deriva-ml-run +multirun=quick_vs_extended
|
|
30
|
+
deriva-ml-run +multirun=lr_sweep model_config.epochs=5 # Can still override
|
|
31
|
+
|
|
32
|
+
Benefits:
|
|
33
|
+
- Explicit declaration of multirun experiments
|
|
34
|
+
- Rich markdown descriptions for parent executions
|
|
35
|
+
- Reproducible sweeps documented in code
|
|
36
|
+
- Same Hydra override syntax as command line
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
from dataclasses import dataclass, field
|
|
40
|
+
from typing import Any
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class MultirunSpec:
    """A named bundle of Hydra overrides describing one multirun experiment.

    Attributes:
        name: Unique identifier for this multirun configuration.
        overrides: Hydra override strings, using the same syntax as the
            command line, e.g. "+experiment=cifar10_quick,cifar10_extended"
            or "model_config.learning_rate=0.0001,0.001,0.01".
        description: Markdown description attached to the parent execution
            (headers, tables, bold, etc. are all supported).
    """

    name: str
    overrides: list[str] = field(default_factory=list)
    description: str = ""
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Global registry of multirun configurations.
# Maps configuration name -> MultirunSpec; populated by multirun_config()
# and read by the get_*/list_* accessors below.
_multirun_registry: dict[str, MultirunSpec] = {}
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def multirun_config(
    name: str,
    overrides: list[str],
    description: str = "",
) -> MultirunSpec:
    """Register a named multirun configuration and return it.

    The registered spec can then be invoked with
    ``deriva-ml-run +multirun=<name>``. Registering the same name twice
    replaces the earlier spec.

    Args:
        name: Unique name for this multirun configuration, used to invoke it
            via ``+multirun=<name>``.
        overrides: Hydra override strings — the same overrides you would pass
            on the command line after ``--multirun``. Examples:
            "+experiment=cifar10_quick,cifar10_extended" (run several
            experiments), "model_config.learning_rate=0.0001,0.001,0.01"
            (sweep a parameter), "datasets=small,medium,large".
        description: Markdown description for the parent execution. Because
            it is defined in Python rather than on the command line, it can
            document objectives, methodology, and expected outcomes in full
            markdown (headers, tables, bold, ...).

    Returns:
        The registered MultirunSpec instance.

    Example:
        >>> multirun_config(
        ...     "lr_sweep",
        ...     overrides=["model_config.learning_rate=0.0001,0.001,0.01"],
        ...     description="## Learning Rate Sweep",
        ... )  # doctest: +SKIP
    """
    registered = MultirunSpec(name=name, overrides=overrides, description=description)
    _multirun_registry[name] = registered
    return registered
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def get_multirun_config(name: str) -> MultirunSpec | None:
    """Return the MultirunSpec registered under *name*, or None if absent.

    Args:
        name: The name of the multirun configuration.

    Returns:
        The matching MultirunSpec, or None when no such registration exists.
    """
    try:
        return _multirun_registry[name]
    except KeyError:
        return None
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def list_multirun_configs() -> list[str]:
    """Return the names of every registered multirun configuration.

    Returns:
        List of registered multirun config names, in registration order.
    """
    return [config_name for config_name in _multirun_registry]
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def get_all_multirun_configs() -> dict[str, MultirunSpec]:
    """Return a shallow copy of the multirun registry.

    Returns:
        Dictionary mapping names to MultirunSpec instances. Mutating the
        returned dict does not affect the registry itself.
    """
    return {spec_name: spec for spec_name, spec in _multirun_registry.items()}
|