deriva-ml 1.17.10__py3-none-any.whl → 1.17.11__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (74)
  1. deriva_ml/__init__.py +43 -1
  2. deriva_ml/asset/__init__.py +17 -0
  3. deriva_ml/asset/asset.py +357 -0
  4. deriva_ml/asset/aux_classes.py +100 -0
  5. deriva_ml/bump_version.py +254 -11
  6. deriva_ml/catalog/__init__.py +21 -0
  7. deriva_ml/catalog/clone.py +1199 -0
  8. deriva_ml/catalog/localize.py +426 -0
  9. deriva_ml/core/__init__.py +29 -0
  10. deriva_ml/core/base.py +817 -1067
  11. deriva_ml/core/config.py +169 -21
  12. deriva_ml/core/constants.py +120 -19
  13. deriva_ml/core/definitions.py +123 -13
  14. deriva_ml/core/enums.py +47 -73
  15. deriva_ml/core/ermrest.py +226 -193
  16. deriva_ml/core/exceptions.py +297 -14
  17. deriva_ml/core/filespec.py +99 -28
  18. deriva_ml/core/logging_config.py +225 -0
  19. deriva_ml/core/mixins/__init__.py +42 -0
  20. deriva_ml/core/mixins/annotation.py +915 -0
  21. deriva_ml/core/mixins/asset.py +384 -0
  22. deriva_ml/core/mixins/dataset.py +237 -0
  23. deriva_ml/core/mixins/execution.py +408 -0
  24. deriva_ml/core/mixins/feature.py +365 -0
  25. deriva_ml/core/mixins/file.py +263 -0
  26. deriva_ml/core/mixins/path_builder.py +145 -0
  27. deriva_ml/core/mixins/rid_resolution.py +204 -0
  28. deriva_ml/core/mixins/vocabulary.py +400 -0
  29. deriva_ml/core/mixins/workflow.py +322 -0
  30. deriva_ml/core/validation.py +389 -0
  31. deriva_ml/dataset/__init__.py +2 -1
  32. deriva_ml/dataset/aux_classes.py +20 -4
  33. deriva_ml/dataset/catalog_graph.py +575 -0
  34. deriva_ml/dataset/dataset.py +1242 -1008
  35. deriva_ml/dataset/dataset_bag.py +1311 -182
  36. deriva_ml/dataset/history.py +27 -14
  37. deriva_ml/dataset/upload.py +225 -38
  38. deriva_ml/demo_catalog.py +126 -110
  39. deriva_ml/execution/__init__.py +46 -2
  40. deriva_ml/execution/base_config.py +639 -0
  41. deriva_ml/execution/execution.py +543 -242
  42. deriva_ml/execution/execution_configuration.py +26 -11
  43. deriva_ml/execution/execution_record.py +592 -0
  44. deriva_ml/execution/find_caller.py +298 -0
  45. deriva_ml/execution/model_protocol.py +175 -0
  46. deriva_ml/execution/multirun_config.py +153 -0
  47. deriva_ml/execution/runner.py +595 -0
  48. deriva_ml/execution/workflow.py +223 -34
  49. deriva_ml/experiment/__init__.py +8 -0
  50. deriva_ml/experiment/experiment.py +411 -0
  51. deriva_ml/feature.py +6 -1
  52. deriva_ml/install_kernel.py +143 -6
  53. deriva_ml/interfaces.py +862 -0
  54. deriva_ml/model/__init__.py +99 -0
  55. deriva_ml/model/annotations.py +1278 -0
  56. deriva_ml/model/catalog.py +286 -60
  57. deriva_ml/model/database.py +144 -649
  58. deriva_ml/model/deriva_ml_database.py +308 -0
  59. deriva_ml/model/handles.py +14 -0
  60. deriva_ml/run_model.py +319 -0
  61. deriva_ml/run_notebook.py +507 -38
  62. deriva_ml/schema/__init__.py +18 -2
  63. deriva_ml/schema/annotations.py +62 -33
  64. deriva_ml/schema/create_schema.py +169 -69
  65. deriva_ml/schema/validation.py +601 -0
  66. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/METADATA +4 -4
  67. deriva_ml-1.17.11.dist-info/RECORD +77 -0
  68. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/WHEEL +1 -1
  69. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/entry_points.txt +1 -0
  70. deriva_ml/protocols/dataset.py +0 -19
  71. deriva_ml/test.py +0 -94
  72. deriva_ml-1.17.10.dist-info/RECORD +0 -45
  73. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/licenses/LICENSE +0 -0
  74. {deriva_ml-1.17.10.dist-info → deriva_ml-1.17.11.dist-info}/top_level.txt +0 -0
deriva_ml/execution/runner.py (new file)
@@ -0,0 +1,595 @@
+"""
+Deriva ML Model Runner
+======================
+
+Generic model runner for executing ML workflows within DerivaML execution contexts.
+
+This module provides the infrastructure to run ML models with full provenance tracking,
+configuration management via Hydra, and support for parameter sweeps.
+
+Key Features
+------------
+- **Automatic execution context**: Creates execution records in the catalog
+- **Multirun/sweep support**: Parent-child execution nesting for parameter sweeps
+- **Hydra configuration**: Composable configs with command-line overrides
+- **Subclass support**: Works with DerivaML subclasses (EyeAI, GUDMAP, etc.)
+- **Provenance tracking**: Links inputs, outputs, and configuration
+
+Model Protocol
+--------------
+Models must follow this signature pattern to work with run_model:
+
+    def my_model(
+        param1: int = 10,
+        param2: float = 0.01,
+        # ... other model parameters ...
+        ml_instance: DerivaML = None,  # Injected at runtime
+        execution: Execution = None,  # Injected at runtime
+    ) -> None:
+        '''Train/run the model within the execution context.'''
+        # Access input datasets
+        for dataset in execution.datasets:
+            bag = execution.download_dataset_bag(dataset)
+            # ... process data ...
+
+        # Register output files
+        model_path = execution.asset_file_path("Model", "model.pt")
+        torch.save(model.state_dict(), model_path)
+
+        metrics_path = execution.asset_file_path("Execution_Metadata", "metrics.json")
+        with open(metrics_path, "w") as f:
+            json.dump({"accuracy": 0.95}, f)
+
+The `ml_instance` and `execution` parameters are injected by run_model at runtime.
+All other parameters are configured via Hydra.
+
+Quick Start
+-----------
+1. Create your model function following the protocol above.
+
+2. Create a hydra-zen configuration for your model:
+
+    from hydra_zen import builds, store
+
+    # Wrap model with builds() and zen_partial=True for deferred execution
+    MyModelConfig = builds(my_model, param1=10, param2=0.01, zen_partial=True)
+    store(MyModelConfig, group="model_config", name="default_model")
+
+3. Set up the main runner script:
+
+    from deriva_ml import DerivaML
+    from deriva_ml.execution import run_model, create_model_config
+    from hydra_zen import store, zen
+
+    # Create the main config (uses DerivaML by default)
+    deriva_model = create_model_config(DerivaML)
+    store(deriva_model, name="deriva_model")
+
+    # Load your config modules
+    store.add_to_hydra_store()
+
+    # Launch
+    if __name__ == "__main__":
+        zen(run_model).hydra_main(config_name="deriva_model", version_base="1.3")
+
+4. Run from command line:
+
+    python my_runner.py                           # Run with defaults
+    python my_runner.py model_config.param1=20    # Override parameter
+    python my_runner.py dry_run=true              # Test without catalog writes
+    python my_runner.py --multirun model_config.param1=10,20,30  # Parameter sweep
+
+Domain Subclasses
+-----------------
+For domain-specific classes like EyeAI:
+
+    from eye_ai import EyeAI
+
+    # Create config with EyeAI instead of DerivaML
+    deriva_model = create_model_config(EyeAI, description="EyeAI analysis")
+
+    # Your model receives an EyeAI instance:
+    def my_eyeai_model(
+        ...,
+        ml_instance: EyeAI = None,  # Now an EyeAI instance
+        execution: Execution = None,
+    ):
+        # Access EyeAI-specific methods
+        ml_instance.some_eyeai_method()
+
+Multirun Parameter Sweeps
+-------------------------
+When using Hydra's multirun mode (--multirun or -m), run_model automatically:
+
+1. Creates a parent execution to group all sweep jobs
+2. Links each child execution to the parent with sequence ordering
+3. Records sweep configuration in the parent's description
+
+Example sweep:
+
+    python my_runner.py --multirun model_config.learning_rate=0.001,0.01,0.1
+
+This creates:
+- Parent execution: "Multirun sweep: ..." (contains sweep metadata)
+- Child 0: learning_rate=0.001 (sequence=0)
+- Child 1: learning_rate=0.01 (sequence=1)
+- Child 2: learning_rate=0.1 (sequence=2)
+
+Query nested executions via the catalog or MCP tools:
+- list_nested_executions(parent_rid)
+- list_parent_executions(child_rid)
+
+Configuration Groups
+--------------------
+The default hydra_defaults in create_model_config() expect these config groups:
+
+- deriva_ml: Connection settings (hostname, catalog_id, credentials)
+- datasets: Dataset specifications (RIDs, versions)
+- assets: Asset RIDs (model weights, etc.)
+- workflow: Workflow definition (name, type, description)
+- model_config: Model parameters (your model's config)
+
+Each group should have at least a "default_*" entry. Override at runtime:
+
+    python my_runner.py deriva_ml=production datasets=full_training
+
+See Also
+--------
+- DerivaMLModel protocol: defines the expected model signature
+- ExecutionConfiguration: bundles inputs for an execution
+- Execution: context manager for execution lifecycle
+"""
+
+from __future__ import annotations
+
+import atexit
+import logging
+from pathlib import Path
+from typing import Any, TypeVar, TYPE_CHECKING
+
+from hydra.core.hydra_config import HydraConfig
+from hydra_zen import builds
+
+if TYPE_CHECKING:
+    from deriva_ml import DerivaML
+    from deriva_ml.core.config import DerivaMLConfig
+    from deriva_ml.dataset import DatasetSpec
+    from deriva_ml.execution import ExecutionConfiguration, Workflow
+    from deriva_ml.core.definitions import RID
+
+
+# Type variable for DerivaML and its subclasses
+T = TypeVar("T", bound="DerivaML")
+
+
+# =============================================================================
+# Multirun State Management
+# =============================================================================
+
+class MultirunState:
+    """Manages state for multirun (sweep) executions.
+
+    In multirun mode, we create a parent execution that groups all the
+    individual sweep jobs. This class holds the shared state needed to
+    coordinate between jobs.
+
+    Attributes:
+        parent_execution_rid: RID of the parent execution (created on first job)
+        parent_execution: The parent Execution object
+        ml_instance: Shared DerivaML instance
+        job_sequence: Counter for ordering child executions
+        sweep_dir: Path to the sweep output directory
+    """
+    parent_execution_rid: str | None = None
+    parent_execution: Any = None
+    ml_instance: Any = None  # DerivaML or subclass
+    job_sequence: int = 0
+    sweep_dir: Path | None = None
+
+
+# Global instance - persists across jobs in a multirun
+_multirun_state = MultirunState()
+
+
+def _is_multirun() -> bool:
+    """Check if we're running in Hydra multirun mode."""
+    try:
+        hydra_cfg = HydraConfig.get()
+        # RunMode.MULTIRUN has value 2
+        return hydra_cfg.mode.value == 2
+    except Exception:
+        return False
+
+
+def _get_job_num() -> int:
+    """Get the current job number in a multirun."""
+    try:
+        hydra_cfg = HydraConfig.get()
+        return hydra_cfg.job.num
+    except Exception:
+        return 0
+
+
+def _complete_parent_execution() -> None:
+    """Complete the parent execution at the end of a multirun sweep.
+
+    This is registered as an atexit handler to ensure the parent execution
+    is properly completed and its outputs uploaded when the process exits.
+    """
+    global _multirun_state
+
+    if _multirun_state.parent_execution is None:
+        return
+
+    try:
+        parent = _multirun_state.parent_execution
+
+        # Stop the parent execution timing
+        parent.execution_stop()
+
+        # Upload any outputs and clean up
+        parent.upload_execution_outputs()
+
+        logging.info(
+            f"Completed parent execution: {_multirun_state.parent_execution_rid} "
+            f"({_multirun_state.job_sequence} child jobs)"
+        )
+    except Exception as e:
+        logging.warning(f"Failed to complete parent execution: {e}")
+    finally:
+        # Clear the state
+        reset_multirun_state()
+
+
+# Track if atexit handler is registered
+_atexit_registered = False
+
+
+def _create_parent_execution(
+    ml_instance: "DerivaML",
+    workflow: "Workflow",
+    description: str,
+    dry_run: bool = False,
+) -> None:
+    """Create the parent execution for a multirun sweep.
+
+    This is called on the first job of a multirun to create the parent
+    execution that will group all child executions together.
+
+    Args:
+        ml_instance: The DerivaML (or subclass) instance.
+        workflow: The workflow to associate with the parent execution.
+        description: Description for the parent execution. When using multirun_config,
+            this is the rich markdown description from the config.
+        dry_run: If True, don't write to the catalog.
+    """
+    global _multirun_state, _atexit_registered
+
+    # Import here to avoid circular imports
+    from deriva_ml.execution import ExecutionConfiguration
+
+    # Use the description directly - it comes from multirun_config or the CLI
+    parent_description = description
+
+    # Create parent execution configuration (no datasets - those are for children)
+    parent_config = ExecutionConfiguration(
+        description=parent_description,
+    )
+
+    # Create the parent execution
+    parent_execution = ml_instance.create_execution(
+        parent_config,
+        workflow=workflow,
+        dry_run=dry_run,
+    )
+
+    # Start the parent execution
+    parent_execution.execution_start()
+
+    # Store in global state
+    _multirun_state.parent_execution = parent_execution
+    _multirun_state.parent_execution_rid = parent_execution.execution_rid
+    _multirun_state.ml_instance = ml_instance
+
+    # Register atexit handler to complete parent execution when process exits
+    if not _atexit_registered:
+        atexit.register(_complete_parent_execution)
+        _atexit_registered = True
+
+    logging.info(f"Created parent execution: {parent_execution.execution_rid}")
+
+
+def run_model(
+    deriva_ml: "DerivaMLConfig",
+    datasets: list["DatasetSpec"],
+    assets: list["RID"],
+    description: str,
+    workflow: "Workflow",
+    model_config: Any,
+    dry_run: bool = False,
+    ml_class: type["DerivaML"] | None = None,
+) -> None:
+    """
+    Execute a machine learning model within a DerivaML execution context.
+
+    This function serves as the main entry point called by hydra-zen after
+    configuration resolution. It orchestrates the complete execution lifecycle:
+    connecting to Deriva, creating an execution record, running the model,
+    and uploading results.
+
+    In multirun mode, this function also:
+    - Creates a parent execution on the first job to group all sweep jobs
+    - Links each child execution to the parent with sequence ordering
+
+    Parameters
+    ----------
+    deriva_ml : DerivaMLConfig
+        Configuration for the DerivaML connection. Contains server URL,
+        catalog ID, credentials, and other connection parameters.
+
+    datasets : list[DatasetSpec]
+        Specifications for datasets to use in this execution. Each DatasetSpec
+        identifies a dataset in the Deriva catalog to be made available to
+        the model.
+
+    assets : list[RID]
+        Resource IDs (RIDs) of assets to include in the execution. Typically
+        used for model weight files, pretrained checkpoints, or other
+        artifacts needed by the model.
+
+    description : str
+        Human-readable description of this execution run. Stored in the
+        Deriva catalog for provenance tracking. In multirun mode, this is
+        also used for the parent execution if running via multirun_config.
+
+    workflow : Workflow
+        The workflow definition to associate with this execution. Defines
+        the computational pipeline and its metadata.
+
+    model_config : Any
+        A hydra-zen callable that wraps the actual model code. When called
+        with `ml_instance` and `execution` arguments, it runs the model
+        training or inference logic.
+
+    dry_run : bool, optional
+        If True, create the execution record but skip actual model execution.
+        Useful for testing configuration without running expensive computations.
+        Default is False.
+
+    ml_class : type[DerivaML], optional
+        The DerivaML class (or subclass) to instantiate. If None, uses the
+        base DerivaML class. Use this to instantiate domain-specific classes
+        like EyeAI or GUDMAP.
+
+    Returns
+    -------
+    None
+        Results are uploaded to the Deriva catalog as execution outputs.
+
+    Examples
+    --------
+    This function is typically not called directly, but through hydra:
+
+        # From command line:
+        python deriva_run.py +experiment=cifar10_cnn dry_run=True
+
+        # Multirun (creates parent + child executions):
+        python deriva_run.py --multirun +experiment=cifar10_quick,cifar10_extended
+
+        # With a custom DerivaML subclass (in your script):
+        from functools import partial
+        run_model_eyeai = partial(run_model, ml_class=EyeAI)
+    """
+    global _multirun_state
+
+    # Import here to avoid circular imports
+    from deriva_ml import DerivaML
+    from deriva_ml.execution import ExecutionConfiguration
+
+    # ---------------------------------------------------------------------------
+    # Clear hydra's logging configuration
+    # ---------------------------------------------------------------------------
+    # Hydra sets up its own logging handlers which can interfere with DerivaML's
+    # logging. Remove them to ensure consistent log output.
+    root = logging.getLogger()
+    for handler in root.handlers[:]:
+        root.removeHandler(handler)
+
+    # ---------------------------------------------------------------------------
+    # Connect to the Deriva catalog
+    # ---------------------------------------------------------------------------
+    # Use the provided ml_class or default to DerivaML
+    if ml_class is None:
+        ml_class = DerivaML
+
+    ml_instance = ml_class.instantiate(deriva_ml)
+
+    # ---------------------------------------------------------------------------
+    # Handle multirun mode - create parent execution on first job
+    # ---------------------------------------------------------------------------
+    is_multirun = _is_multirun()
+    if is_multirun and _multirun_state.parent_execution is None:
+        _create_parent_execution(ml_instance, workflow, description, dry_run)
+
+    # ---------------------------------------------------------------------------
+    # Capture Hydra runtime choices for provenance
+    # ---------------------------------------------------------------------------
+    # The choices dict maps config group names to the selected config names
+    # e.g., {"model_config": "cifar10_quick", "datasets": "cifar10_training"}
+    # Filter out None values (some Hydra internal groups have None choices)
+    config_choices: dict[str, str] = {}
+    try:
+        hydra_cfg = HydraConfig.get()
+        config_choices = {k: v for k, v in hydra_cfg.runtime.choices.items() if v is not None}
+    except Exception:
+        pass  # HydraConfig not available outside Hydra context
+
+    # ---------------------------------------------------------------------------
+    # Create the execution context
+    # ---------------------------------------------------------------------------
+    # The ExecutionConfiguration bundles together all the inputs for this run:
+    # which datasets to use, which assets (model weights, etc.), and metadata.
+
+    # In multirun mode, enhance the description with job info
+    job_description = description
+    if is_multirun:
+        job_num = _get_job_num()
+        job_description = f"[Job {job_num}] {description}"
+
+    execution_config = ExecutionConfiguration(
+        datasets=datasets,
+        assets=assets,
+        description=job_description,
+        config_choices=config_choices,
+    )
+
+    # Create the execution record in the catalog. This generates a unique
+    # execution ID and sets up the working directories for this run.
+    execution = ml_instance.create_execution(
+        execution_config,
+        workflow=workflow,
+        dry_run=dry_run
+    )
+
+    # ---------------------------------------------------------------------------
+    # Link to parent execution in multirun mode
+    # ---------------------------------------------------------------------------
+    if is_multirun and _multirun_state.parent_execution is not None:
+        if not dry_run:
+            try:
+                # Get the current job sequence from the global state
+                job_sequence = _multirun_state.job_sequence
+                _multirun_state.parent_execution.add_nested_execution(
+                    execution,
+                    sequence=job_sequence
+                )
+                logging.info(
+                    f"Linked execution {execution.execution_rid} to parent "
+                    f"{_multirun_state.parent_execution_rid} (sequence={job_sequence})"
+                )
+                # Increment the sequence for the next job
+                _multirun_state.job_sequence += 1
+            except Exception as e:
+                logging.warning(f"Failed to link execution to parent: {e}")
+
+    # ---------------------------------------------------------------------------
+    # Run the model within the execution context
+    # ---------------------------------------------------------------------------
+    # The context manager handles setup (downloading datasets, creating output
+    # directories) and teardown (recording completion status, timing).
+    with execution.execute() as exec_context:
+        if dry_run:
+            # In dry run mode, skip model execution but still test the setup
+            logging.info("Dry run mode: skipping model execution")
+        else:
+            # Invoke the model configuration callable. The model_config is a
+            # hydra-zen wrapped function that has been partially configured with
+            # all model-specific parameters (e.g., learning rate, batch size).
+            # We provide the runtime context here.
+            model_config(ml_instance=ml_instance, execution=exec_context)
+
+    # ---------------------------------------------------------------------------
+    # Upload results to the catalog
+    # ---------------------------------------------------------------------------
+    # After the model completes, upload any output files (metrics, predictions,
+    # model checkpoints) to the Deriva catalog for permanent storage.
+    if not dry_run:
+        uploaded_assets = execution.upload_execution_outputs()
+
+        # Print summary of uploaded assets
+        total_files = sum(len(files) for files in uploaded_assets.values())
+        if total_files > 0:
+            print(f"\nUploaded {total_files} asset(s) to catalog:")
+            for asset_type, files in uploaded_assets.items():
+                for f in files:
+                    print(f" - {asset_type}: {f}")


+def create_model_config(
+    ml_class: type["DerivaML"] | None = None,
+    description: str = "Model execution",
+    hydra_defaults: list | None = None,
+) -> Any:
+    """Create a hydra-zen configuration for run_model.
+
+    This helper creates a properly configured hydra-zen builds() for run_model
+    with the specified DerivaML class bound via partial application.
+
+    Parameters
+    ----------
+    ml_class : type[DerivaML], optional
+        The DerivaML class (or subclass) to use. If None, uses the base DerivaML.
+
+    description : str, optional
+        Default description for executions. Can be overridden at runtime.
+
+    hydra_defaults : list, optional
+        Custom hydra defaults. If None, uses standard defaults for deriva_ml,
+        datasets, assets, workflow, and model_config groups.
+
+    Returns
+    -------
+    Any
+        A hydra-zen builds() configuration ready to be registered with store().
+
+    Examples
+    --------
+    Basic usage with DerivaML:
+
+    >>> from deriva_ml.execution.runner import create_model_config
+    >>> model_config = create_model_config()
+    >>> store(model_config, name="deriva_model")
+
+    With a custom subclass:
+
+    >>> from eye_ai import EyeAI
+    >>> model_config = create_model_config(EyeAI, description="EyeAI analysis")
+    >>> store(model_config, name="eyeai_model")
+
+    With custom hydra defaults:
+
+    >>> model_config = create_model_config(
+    ...     hydra_defaults=[
+    ...         "_self_",
+    ...         {"deriva_ml": "production"},
+    ...         {"datasets": "full_dataset"},
+    ...     ]
+    ... )
+    """
+    from functools import partial
+
+    if hydra_defaults is None:
+        hydra_defaults = [
+            "_self_",
+            {"deriva_ml": "default_deriva"},
+            {"datasets": "default_dataset"},
+            {"assets": "default_asset"},
+            {"workflow": "default_workflow"},
+            {"model_config": "default_model"},
+        ]
+
+    # Create a partial function with ml_class bound
+    if ml_class is not None:
+        run_func = partial(run_model, ml_class=ml_class)
+    else:
+        run_func = run_model
+
+    return builds(
+        run_func,
+        description=description,
+        populate_full_signature=True,
+        hydra_defaults=hydra_defaults,
+    )


+def reset_multirun_state() -> None:
+    """Reset the global multirun state.
+
+    This is primarily useful for testing to ensure clean state between tests.
+    """
+    global _multirun_state
+    _multirun_state.parent_execution_rid = None
+    _multirun_state.parent_execution = None
+    _multirun_state.ml_instance = None
+    _multirun_state.job_sequence = 0
+    _multirun_state.sweep_dir = None
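
For orientation, the Quick Start steps in the new runner module above compose into a single runner script roughly as sketched below. This is a minimal sketch assembled from the docstring examples, not code shipped in the wheel: the model function my_model and its parameters param1/param2 are illustrative, and a complete script would also register configs for the other expected groups (deriva_ml, datasets, assets, workflow). Only calls shown in the module docstring (builds, store, zen, create_model_config, run_model) are used.

    from hydra_zen import builds, store, zen

    from deriva_ml import DerivaML
    from deriva_ml.execution import create_model_config, run_model


    def my_model(
        param1: int = 10,
        param2: float = 0.01,
        ml_instance: DerivaML = None,  # injected by run_model at runtime
        execution=None,                # injected by run_model at runtime
    ) -> None:
        # Illustrative body: read inputs from the execution context and write
        # outputs through execution.asset_file_path(...), as described above.
        for dataset in execution.datasets:
            bag = execution.download_dataset_bag(dataset)
            # ... train / evaluate on the bag ...


    # zen_partial=True defers the call so run_model can supply ml_instance
    # and execution when it finally invokes the model.
    MyModelConfig = builds(my_model, param1=10, param2=0.01, zen_partial=True)
    store(MyModelConfig, group="model_config", name="default_model")

    # Top-level config built around run_model, bound to the base DerivaML class.
    deriva_model = create_model_config(DerivaML)
    store(deriva_model, name="deriva_model")

    # Push the registered configs into Hydra's config store before launching.
    store.add_to_hydra_store()

    if __name__ == "__main__":
        zen(run_model).hydra_main(config_name="deriva_model", version_base="1.3")

Run it as shown in the docstring, e.g. python my_runner.py model_config.param1=20, or with --multirun to get the parent/child execution nesting described above.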