expops 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. expops-0.1.3.dist-info/METADATA +826 -0
  2. expops-0.1.3.dist-info/RECORD +86 -0
  3. expops-0.1.3.dist-info/WHEEL +5 -0
  4. expops-0.1.3.dist-info/entry_points.txt +3 -0
  5. expops-0.1.3.dist-info/licenses/LICENSE +674 -0
  6. expops-0.1.3.dist-info/top_level.txt +1 -0
  7. mlops/__init__.py +0 -0
  8. mlops/__main__.py +11 -0
  9. mlops/_version.py +34 -0
  10. mlops/adapters/__init__.py +12 -0
  11. mlops/adapters/base.py +86 -0
  12. mlops/adapters/config_schema.py +89 -0
  13. mlops/adapters/custom/__init__.py +3 -0
  14. mlops/adapters/custom/custom_adapter.py +447 -0
  15. mlops/adapters/plugin_manager.py +113 -0
  16. mlops/adapters/sklearn/__init__.py +3 -0
  17. mlops/adapters/sklearn/adapter.py +94 -0
  18. mlops/cluster/__init__.py +3 -0
  19. mlops/cluster/controller.py +496 -0
  20. mlops/cluster/process_runner.py +91 -0
  21. mlops/cluster/providers.py +258 -0
  22. mlops/core/__init__.py +95 -0
  23. mlops/core/custom_model_base.py +38 -0
  24. mlops/core/dask_networkx_executor.py +1265 -0
  25. mlops/core/executor_worker.py +1239 -0
  26. mlops/core/experiment_tracker.py +81 -0
  27. mlops/core/graph_types.py +64 -0
  28. mlops/core/networkx_parser.py +135 -0
  29. mlops/core/payload_spill.py +278 -0
  30. mlops/core/pipeline_utils.py +162 -0
  31. mlops/core/process_hashing.py +216 -0
  32. mlops/core/step_state_manager.py +1298 -0
  33. mlops/core/step_system.py +956 -0
  34. mlops/core/workspace.py +99 -0
  35. mlops/environment/__init__.py +10 -0
  36. mlops/environment/base.py +43 -0
  37. mlops/environment/conda_manager.py +307 -0
  38. mlops/environment/factory.py +70 -0
  39. mlops/environment/pyenv_manager.py +146 -0
  40. mlops/environment/setup_env.py +31 -0
  41. mlops/environment/system_manager.py +66 -0
  42. mlops/environment/utils.py +105 -0
  43. mlops/environment/venv_manager.py +134 -0
  44. mlops/main.py +527 -0
  45. mlops/managers/project_manager.py +400 -0
  46. mlops/managers/reproducibility_manager.py +575 -0
  47. mlops/platform.py +996 -0
  48. mlops/reporting/__init__.py +16 -0
  49. mlops/reporting/context.py +187 -0
  50. mlops/reporting/entrypoint.py +292 -0
  51. mlops/reporting/kv_utils.py +77 -0
  52. mlops/reporting/registry.py +50 -0
  53. mlops/runtime/__init__.py +9 -0
  54. mlops/runtime/context.py +34 -0
  55. mlops/runtime/env_export.py +113 -0
  56. mlops/storage/__init__.py +12 -0
  57. mlops/storage/adapters/__init__.py +9 -0
  58. mlops/storage/adapters/gcp_kv_store.py +778 -0
  59. mlops/storage/adapters/gcs_object_store.py +96 -0
  60. mlops/storage/adapters/memory_store.py +240 -0
  61. mlops/storage/adapters/redis_store.py +438 -0
  62. mlops/storage/factory.py +199 -0
  63. mlops/storage/interfaces/__init__.py +6 -0
  64. mlops/storage/interfaces/kv_store.py +118 -0
  65. mlops/storage/path_utils.py +38 -0
  66. mlops/templates/premier-league/charts/plot_metrics.js +70 -0
  67. mlops/templates/premier-league/charts/plot_metrics.py +145 -0
  68. mlops/templates/premier-league/charts/requirements.txt +6 -0
  69. mlops/templates/premier-league/configs/cluster_config.yaml +13 -0
  70. mlops/templates/premier-league/configs/project_config.yaml +207 -0
  71. mlops/templates/premier-league/data/England CSV.csv +12154 -0
  72. mlops/templates/premier-league/models/premier_league_model.py +638 -0
  73. mlops/templates/premier-league/requirements.txt +8 -0
  74. mlops/templates/sklearn-basic/README.md +22 -0
  75. mlops/templates/sklearn-basic/charts/plot_metrics.py +85 -0
  76. mlops/templates/sklearn-basic/charts/requirements.txt +3 -0
  77. mlops/templates/sklearn-basic/configs/project_config.yaml +64 -0
  78. mlops/templates/sklearn-basic/data/train.csv +14 -0
  79. mlops/templates/sklearn-basic/models/model.py +62 -0
  80. mlops/templates/sklearn-basic/requirements.txt +10 -0
  81. mlops/web/__init__.py +3 -0
  82. mlops/web/server.py +585 -0
  83. mlops/web/ui/index.html +52 -0
  84. mlops/web/ui/mlops-charts.js +357 -0
  85. mlops/web/ui/script.js +1244 -0
  86. mlops/web/ui/styles.css +248 -0
@@ -0,0 +1,956 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, List, Optional, Callable, Union
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ import json
6
+ import joblib
7
+ import time
8
+ from datetime import datetime
9
+ import logging
10
+ from functools import wraps
11
+ import contextvars
12
+
13
+ from .payload_spill import hydrate_payload_refs
14
+
15
# Scalar value types accepted inside SerializableData (private helper alias).
_Scalar = Union[str, int, float, bool]

# Plain data that needs no custom serialization: a scalar, a flat list of
# scalars, a flat string-keyed dict of scalars, or None.
SerializableData = Union[
    _Scalar,
    List[_Scalar],
    Dict[str, _Scalar],
    None,
]

# Model objects are left untyped by the step system.
ModelData = Any
23
+
24
class StepContext:
    """Run-scoped context handed to steps and processes.

    Bundles the run identity (``project_id`` / ``run_id``), an optional
    experiment tracker, results returned by earlier steps/processes
    (``step_results``), the global configuration and data paths. Results that
    are not in memory are looked up through the active state manager's cache
    (when one is set), and every returned payload is passed through
    ``hydrate_payload_refs`` first.
    """

    def __init__(self,
                 project_id: str,
                 run_id: Optional[str] = None,
                 tracker: Any = None,
                 step_results: Optional[Dict[str, Dict[str, Any]]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 data_paths: Optional[Dict[str, Path]] = None,
                 checkpoint_dir: Optional[Path] = None):
        self.project_id = project_id
        self.run_id = run_id
        self.tracker = tracker
        # Keep caller-supplied containers by identity even when they are empty,
        # so a caller holding the dict observes results written into it
        # (``x or {}`` would silently swap an empty dict for a fresh one).
        self.step_results = step_results if step_results is not None else {}
        self.global_config = global_config if global_config is not None else {}
        self.data_paths = data_paths if data_paths is not None else {}
        # Starts empty; not read or written elsewhere in this class.
        self.shared_state: Dict[str, Any] = {}
        self.checkpoint_dir = checkpoint_dir or Path("artifacts/checkpoints")
        self.iteration = 0
        # Name of the process currently being executed; used for resolving per-process settings
        self.current_process: Optional[str] = None

    def _hydrate_payload(self, payload: Any) -> Any:
        """Return ``payload`` passed through ``hydrate_payload_refs`` (best-effort).

        Uses the active state manager when one is available. On any failure the
        original ``payload`` is returned unchanged.
        """
        try:
            sm = get_state_manager()
        except Exception:
            sm = None
        try:
            return hydrate_payload_refs(payload, sm)
        except Exception:
            return payload

    def _resolve_process_for_step(self, step_name: str) -> Optional[str]:
        """Return the registered process that lists ``step_name``, or None."""
        try:
            pr = get_process_registry()
            return pr.get_process_for_step(step_name) if pr else None
        except Exception:
            return None

    def get_step_result(self, step_name: str) -> Optional[Dict[str, Any]]:
        """Return the hydrated result dict of a previously executed step.

        When the result is not in memory and a ``run_id`` is set, the active
        state manager's cache is consulted. Anything found is memoized in
        ``step_results``.

        Returns:
            The step's result, or None if it cannot be found.
        """
        result = self.step_results.get(step_name)
        if result is None and getattr(self, 'run_id', None):
            try:
                sm = get_state_manager()
            except Exception:
                sm = None
            if sm:
                try:
                    cached = sm.get_cached_step_result(
                        run_id=self.run_id,
                        step_name=step_name,
                        process_name=self._resolve_process_for_step(step_name),
                        input_hash=None,
                        config_hash=None,
                        function_hash=None,
                    )
                except Exception as exc:
                    # The cache lookup is best-effort; a failure is treated as a miss.
                    logging.getLogger(__name__).debug(
                        "Cached lookup for step %r failed: %s", step_name, exc
                    )
                    cached = None
                if cached:
                    self.step_results[step_name] = cached
                    result = cached
        if result is None:
            return None
        hydrated = self._hydrate_payload(result)
        if hydrated is not result:
            self.step_results[step_name] = hydrated
        return hydrated

    def get_step_data(self, step_name: str, data_key: str, process_name: Optional[str] = None) -> Any:
        """Get specific data output from a previous step.

        In distributed mode, if the step result is not present in the in-memory
        context (e.g., from a prior process executed on a different worker), this
        method attempts to load the step result from the cache for the current run
        and hydrate it into the context for subsequent accesses.

        Args:
            step_name: Name of the step whose result is read.
            data_key: Key to read from the step's result dict.
            process_name: Accepted for API compatibility but unused. The owning
                process is resolved from the process registry instead.

        Returns:
            The value under ``data_key``, or None if the result or key is missing.
        """
        step_result = self.get_step_result(step_name)
        if not isinstance(step_result, dict):
            return None
        try:
            return step_result.get(data_key)
        except Exception:
            # e.g. an unhashable data_key
            return None

    def get_process_data(self, process_name: str, data_key: str) -> Any:
        """Get specific data output returned by a previous process.

        Falls back to loading the process result from cache for this run when
        not available in memory (distributed mode).

        Returns:
            The value under ``data_key``, or None if unavailable.
        """
        try:
            proc_result = self.step_results.get(process_name)
        except Exception:
            proc_result = None
        if proc_result and isinstance(proc_result, dict):
            hydrated = self._hydrate_payload(proc_result)
            if hydrated is not proc_result:
                self.step_results[process_name] = hydrated
            try:
                return hydrated.get(data_key)
            except Exception:
                return None

        # Not present in memory, try cache
        try:
            sm = get_state_manager()
        except Exception:
            sm = None
        if sm and getattr(self, 'run_id', None):
            try:
                # NOTE(review): unlike get_step_result, this lookup does not pass
                # run_id; confirm against the state manager API.
                loaded = sm.get_cached_process_result(process_name, input_hash=None, config_hash=None, function_hash=None)
                if loaded:
                    self.step_results[process_name] = loaded
                    hydrated = self._hydrate_payload(loaded)
                    if hydrated is not loaded:
                        self.step_results[process_name] = hydrated
                    try:
                        return hydrated.get(data_key)
                    except Exception:
                        return None
            except Exception:
                return None
        return None

    @staticmethod
    def _assert_logging_enabled(kind: str, name: Optional[str], definition: Any) -> None:
        """Raise ``RuntimeError`` if ``definition`` explicitly sets ``logging=False``.

        Args:
            kind: Label used in the error message (``"step"`` or ``"process"``).
            name: Current step/process name. Nothing is checked when it is falsy.
            definition: Registered definition for ``name``, or None if unknown.
        """
        if name and definition is not None and getattr(definition, 'logging', True) is False:
            raise RuntimeError(f"Metric logging is disabled for {kind} '{name}' (logging=False).")

    def log_metric(self, key: str, value: Any, step: Optional[int] = None) -> None:
        """Log a metric with MLflow-style step tracking.

        Args:
            key: Metric name
            value: Metric value. Numeric values are tracked per step; non-numeric values
                (e.g., lists, dicts) are stored as last-snapshot under a shadow key
                with "__last" suffix in the KV store.
            step: Step number (if None, auto-increments from largest existing step).

        Raises:
            RuntimeError: If the current step or process is registered with
                ``logging=False``.
        """
        # Look up the definitions of the current step/process. Lookup failures
        # are tolerated and we still attempt to log.
        cur_step = cur_proc = None
        step_def = proc_def = None
        try:
            cur_step = get_current_step_context()
            cur_proc = get_current_process_context()
            if cur_step:
                try:
                    step_def = get_step_registry().get_step(cur_step)
                except Exception:
                    step_def = None
            if cur_proc:
                try:
                    proc_def = get_process_registry().get_process(cur_proc)
                except Exception:
                    proc_def = None
        except Exception:
            # On any inspection failure, fall through and attempt to log
            pass
        # These checks run outside the lookup try-blocks so an explicit
        # ``logging=False`` reaches the caller instead of being swallowed.
        self._assert_logging_enabled('step', cur_step, step_def)
        self._assert_logging_enabled('process', cur_proc, proc_def)

        logger = logging.getLogger(__name__)
        try:
            state_manager = get_state_manager()
            if state_manager and self.run_id:
                process_name = get_current_process_context()
                step_name = get_current_step_context()
                try:
                    kv_store = getattr(state_manager, 'kv_store', None)
                    kv_cls = type(kv_store).__name__ if kv_store is not None else 'None'
                    logger.info(
                        "[Metrics] log_metric call -> run_id=%s, process=%s, step=%s, key=%s, step_idx=%s, kv_store=%s",
                        self.run_id, process_name, step_name, key,
                        step if step is not None else 'auto', kv_cls,
                    )
                except Exception:
                    pass
                state_manager.log_metric(
                    run_id=self.run_id,
                    process_name=process_name,
                    step_name=step_name,
                    metric_name=key,
                    value=value,
                    step=step,
                )
        except Exception as e:
            logger.warning("Failed to log metric to KV store: %s", e)

    def log_param(self, key: str, value: Union[int, float, str, bool]) -> None:
        """Log a parameter to the experiment tracker (no-op without a tracker)."""
        tracker = getattr(self, "tracker", None)
        if not tracker:
            return
        fn = getattr(tracker, "log_param", None)
        if callable(fn):
            try:
                fn(key, value)
            except Exception:
                # Tracker failures must not break the step.
                pass

    def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]:
        """Load a checkpoint from the given path.

        NOTE: ``joblib.load`` unpickles the file. Only load checkpoints from
        trusted locations, because a crafted file can execute arbitrary code.

        Raises:
            Exception: Whatever ``joblib.load`` raises, re-raised after printing.
        """
        try:
            checkpoint_data = joblib.load(checkpoint_path)
            print(f"[Checkpoint] Model loaded from: {checkpoint_path}")
            return checkpoint_data
        except Exception as e:
            print(f"[Checkpoint] Failed to load checkpoint {checkpoint_path}: {e}")
            raise

    def get_hyperparameters(self, process_name: Optional[str] = None) -> Dict[str, Any]:
        """Return hyperparameters with per-process overrides taking precedence.

        Resolution order (later overrides earlier):
        1) Global hyperparameters from context (supports both legacy and current layouts)
        2) Process-specific hyperparameters from pipeline.processes[name].hyperparameters

        Args:
            process_name: Explicit process name; if None, uses context.current_process
        Returns:
            Merged hyperparameters dict (empty on any failure)
        """
        try:
            cfg = self.global_config
            # Current layout: top-level 'hyperparameters'. Legacy layout:
            # model.parameters.hyperparameters (only consulted if the former key is absent).
            global_hp: Dict[str, Any] = {}
            try:
                if isinstance(cfg, dict):
                    if 'hyperparameters' in cfg:
                        maybe = cfg.get('hyperparameters')
                    elif 'model' in cfg:
                        maybe = (
                            cfg.get('model', {})
                            .get('parameters', {})
                            .get('hyperparameters', {})
                        )
                    else:
                        maybe = None
                    if isinstance(maybe, dict):
                        global_hp = dict(maybe)
            except Exception:
                global_hp = {}

            proc_hp: Dict[str, Any] = {}
            try:
                proc = process_name or getattr(self, 'current_process', None)
                pipeline_cfg = cfg.get('pipeline') if isinstance(cfg, dict) else None
                processes_list = pipeline_cfg.get('processes') if isinstance(pipeline_cfg, dict) else None
                if isinstance(processes_list, list) and proc:
                    for p in processes_list:
                        try:
                            if isinstance(p, dict) and p.get('name') == proc:
                                maybe = p.get('hyperparameters')
                                if isinstance(maybe, dict):
                                    proc_hp = dict(maybe)
                                # Only the first entry with a matching name is used.
                                break
                        except Exception:
                            continue
            except Exception:
                proc_hp = {}

            return {**global_hp, **proc_hp}
        except Exception:
            return {}
301
+
302
+
303
class StepContextFactory:
    """Singleton cache of StepContext instances keyed by ``(project_id, run_id)``.

    A falsy ``run_id`` is keyed as ``"default"``. ``create_context`` returns the
    cached context for a key, refreshing whichever references the caller
    supplies, and otherwise builds a new one. Cached contexts are kept until
    removed with ``release_context``.
    """

    _instance = None
    _contexts: Dict[tuple[str, str], StepContext] = {}

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @staticmethod
    def _context_key(project_id: str, run_id: Optional[str]) -> tuple[str, str]:
        """Build the cache key; a falsy ``run_id`` maps to ``"default"``."""
        return (str(project_id), str(run_id or "default"))

    def create_context(self,
                       project_id: str,
                       run_id: str,
                       tracker: Any = None,
                       step_results: Dict[str, Dict[str, Any]] = None,
                       global_config: Dict[str, Any] = None,
                       data_paths: Dict[str, Path] = None,
                       checkpoint_dir: Optional[Path] = None) -> StepContext:
        """
        Create or get the context for a ``(project_id, run_id)`` pair.

        Args:
            project_id: Project identifier (first part of the cache key)
            run_id: Run identifier (second part of the cache key; stored on the context)
            tracker: Experiment tracker instance
            step_results: Dictionary of step results (data dictionaries)
            global_config: Global configuration
            data_paths: Paths to data files
            checkpoint_dir: Directory for checkpoints

        Returns:
            The StepContext for this project/run. When reusing a cached context,
            every argument that is not None replaces the stored reference.
        """
        key = self._context_key(project_id, run_id)
        ctx = self._contexts.get(key)

        if ctx is None:
            ctx = StepContext(
                project_id=project_id,
                run_id=run_id,
                tracker=tracker,
                step_results=step_results or {},
                global_config=global_config or {},
                data_paths=data_paths or {},
                checkpoint_dir=checkpoint_dir,
            )
            self._contexts[key] = ctx
            return ctx

        # Reuse the run-scoped context, but allow callers to refresh its references.
        ctx.run_id = run_id
        if tracker is not None:
            ctx.tracker = tracker
        if step_results is not None:
            ctx.step_results = step_results
        if global_config is not None:
            ctx.global_config = global_config
        if data_paths is not None:
            ctx.data_paths = data_paths
        if checkpoint_dir is not None:
            ctx.checkpoint_dir = checkpoint_dir
        return ctx

    def release_context(self, project_id: str, run_id: Optional[str] = None) -> bool:
        """Drop the cached context for ``(project_id, run_id)``.

        Without this, contexts accumulate for every run seen by the process.

        Returns:
            True if a cached context was removed, False if none existed.
        """
        return self._contexts.pop(self._context_key(project_id, run_id), None) is not None
390
+
391
+
392
@dataclass
class StepDefinition:
    """Registry record describing one step.

    Field order is the positional constructor order and must stay as-is.
    """
    # Registry key of the step.
    name: str
    # Callable stored for the step.
    func: Callable
    step_type: str = "general"
    # Process active when the step was declared, if any.
    process_name: Optional[str] = None
    # Not a constructor argument; assigned after construction.
    original_func: Optional[Callable] = field(init=False, default=None)
    inputs: List[str] = field(default_factory=list)
    outputs: List[str] = field(default_factory=list)
    condition: Optional[str] = None
    # When False, metric logging is meant to be disabled for this step.
    logging: bool = True
404
+
405
+
406
@dataclass
class ProcessDefinition:
    """Registry record describing a process and the steps it groups.

    Field order is the positional constructor order and must stay as-is.
    """
    # Registry key of the process.
    name: str
    description: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    # Steps owned by this process (indexed by ProcessRegistry on registration).
    step_names: List[str] = field(default_factory=list)
    # Callable used to run the process.
    runner: Optional[Callable] = None
    # Not a constructor argument; assigned after construction.
    original_func: Optional[Callable] = field(init=False, default=None)
    # When False, metric logging is meant to be disabled for this process.
    logging: bool = True
416
+
417
+
418
class StepRegistry:
    """Name-keyed store of StepDefinition objects."""

    def __init__(self):
        # step name -> definition; a later registration replaces an earlier one
        self._steps: Dict[str, StepDefinition] = {}

    def register_step(self, step_def: StepDefinition) -> None:
        """Store ``step_def`` under its ``name``, replacing any previous entry."""
        steps = self._steps
        steps[step_def.name] = step_def

    def get_step(self, name: str) -> Optional[StepDefinition]:
        """Return the definition registered as ``name``, or None if absent."""
        steps = self._steps
        return steps[name] if name in steps else None

    def list_steps(self) -> List[str]:
        """Return registered step names in registration order."""
        return [step_name for step_name in self._steps]
435
+
436
+
437
class ProcessRegistry:
    """Store of ProcessDefinition objects plus a reverse step -> process index."""

    def __init__(self):
        # process name -> definition
        self._processes: Dict[str, ProcessDefinition] = {}
        # step name -> name of the process whose ``step_names`` lists it
        self._step_to_process: Dict[str, str] = {}

    def register_process(self, process_def: ProcessDefinition) -> None:
        """Store ``process_def`` by name and index each of its ``step_names``.

        Re-registering a name replaces the definition. Step mappings are only
        added or overwritten, never removed.
        """
        owner = process_def.name
        self._processes[owner] = process_def
        self._step_to_process.update((step, owner) for step in process_def.step_names)

    def get_process(self, name: str) -> Optional[ProcessDefinition]:
        """Return the definition registered as ``name``, or None if absent."""
        processes = self._processes
        return processes.get(name, None)

    def get_process_for_step(self, step_name: str) -> Optional[str]:
        """Return the name of the process that lists ``step_name``, or None."""
        index = self._step_to_process
        return index.get(step_name, None)
457
+
458
+
459
# Module-wide registries. @process registers into _process_registry;
# _step_registry is presumably filled by @step (defined later in this module).
_step_registry = StepRegistry()
_process_registry = ProcessRegistry()

# Execution-scoped "current" values. ContextVar keeps a separate value per
# thread / asyncio task, so concurrent tasks (e.g. in Dask workers) do not see
# each other's process, step, context or state manager.
_current_process_context: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    'current_process_context',
    default=None,
)
_current_step_context: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    'current_step_context',
    default=None,
)
_current_context: contextvars.ContextVar[Optional[StepContext]] = contextvars.ContextVar(
    'current_context',
    default=None,
)
_current_state_manager: contextvars.ContextVar[Any] = contextvars.ContextVar(
    'current_state_manager',
    default=None,
)

# StepContextFactory.__new__ always returns the same instance.
_context_factory = StepContextFactory()
469
+
470
+
471
def get_context_factory() -> StepContextFactory:
    """Return the module-level StepContextFactory singleton."""
    factory = _context_factory
    return factory
474
+
475
+
476
def set_current_context(context: StepContext) -> None:
    """Make ``context`` the active StepContext.

    Stored in a ContextVar, so the value is local to the current thread/task.
    """
    _current_context.set(context)
    return None
479
+
480
+
481
def get_current_context() -> Optional[StepContext]:
    """Return the active StepContext for this thread/task, or None if unset."""
    active = _current_context.get()
    return active
484
+
485
+
486
def set_current_process_context(process_name: Optional[str]) -> None:
    """Record ``process_name`` (or None) as the active process for this thread/task.

    Used when registering and caching steps.
    """
    _current_process_context.set(process_name)
    return None
489
+
490
+
491
def get_current_process_context() -> Optional[str]:
    """Return the active process name for this thread/task, or None if unset."""
    active_process = _current_process_context.get()
    return active_process
494
+
495
+
496
def set_current_step_context(step_name: Optional[str]) -> None:
    """Record ``step_name`` (or None) as the active step for this thread/task."""
    _current_step_context.set(step_name)
    return None
499
+
500
+
501
def get_current_step_context() -> Optional[str]:
    """Return the active step name for this thread/task, or None if unset."""
    active_step = _current_step_context.get()
    return active_step
504
+
505
+
506
def log_metric(key: str, value: Any, step: Optional[int] = None) -> None:
    """Convenience function to log metrics from anywhere in your model code.

    This is the main function users should call to log metrics during training.
    It forwards to ``StepContext.log_metric`` on the active context. When no
    context is active, it only emits a warning.

    Args:
        key: Metric name (e.g., 'loss', 'accuracy', 'learning_rate')
        value: Metric value. Numeric values are tracked per step; non-numeric values
            (e.g., lists, dicts) are stored as last-snapshot under a shadow key
            with "__last" suffix in the KV store.
        step: Step/iteration number (optional). If not provided, auto-increments from largest existing step.

    Example:
        >>> from mlops.core import log_metric
        >>> for epoch in range(100):
        ...     loss = train_one_epoch()
        ...     log_metric('loss', loss, step=epoch + 1)
    """
    ctx = get_current_context()
    if ctx:
        ctx.log_metric(key, value, step=step)
        return
    import logging
    logging.getLogger(__name__).warning(
        "log_metric called but no context is active. Metric will not be logged."
    )
532
+
533
+
534
def set_state_manager(state_manager: Any) -> None:
    """Publish ``state_manager`` as the active StepStateManager for this thread/task.

    This lets manual step calls reach step caching without an explicit argument.
    """
    _current_state_manager.set(state_manager)
    return None
537
+
538
+
539
def get_state_manager() -> Any:
    """Return the active StepStateManager for this thread/task, or None if unset."""
    manager = _current_state_manager.get()
    return manager
542
+
543
+
544
def process(description: str = "", parameters: Optional[Dict[str, Any]] = None, logging: bool = True):
    """
    Simplified decorator to define a process and group the steps.

    The decorated function is registered in the module process registry under
    ``func.__name__``. The returned runner keeps the function's metadata (via
    ``functools.wraps``, matching ``step``) and, when called:

    * sets the current process context to this process, but only if no process
      is active, so a nested call keeps the outer process;
    * when a StepContext is active, passes the call's kwargs through
      ``_inject_context_and_hparams`` (defined elsewhere in this module);
    * requires ``func`` to return a dict and stores it in
      ``context.step_results[<process name>]``.

    Args:
        description: Stored on the ProcessDefinition.
        parameters: Stored on the ProcessDefinition (None becomes ``{}``).
        logging: Stored as ``ProcessDefinition.logging``. Note that this
            parameter shadows the ``logging`` module inside this function.

    Raises (from the runner):
        ValueError: If the process returns None or a non-dict value.
    """
    def decorator(func):
        # Use function name as the process name for registry lookup
        name = func.__name__

        @wraps(func)
        def _wrapped_runner(*args, **kwargs):
            prev = get_current_process_context()
            if prev is None:
                set_current_process_context(name)
            try:
                call_kwargs = dict(kwargs)
                try:
                    active_ctx = get_current_context()
                except Exception:
                    active_ctx = None
                if active_ctx is not None:
                    call_kwargs = _inject_context_and_hparams(func, active_ctx, call_kwargs)
                result = func(*args, **call_kwargs)
                if result is None:
                    raise ValueError(f"Process '{name}' must return a dictionary of data.")
                # Validate result is a dictionary
                if not isinstance(result, dict):
                    raise ValueError(f"Process '{name}' must return a dictionary, got {type(result).__name__}.")
                # Re-read the context after the call in case ``func`` changed it.
                try:
                    ctx = get_current_context()
                    if ctx:
                        ctx.step_results[name] = result
                except Exception:
                    # Do not block process execution on context issues
                    pass
                return result
            finally:
                # Only restore if we actually changed it (i.e., if it was None before)
                if prev is None:
                    set_current_process_context(prev)

        process_def = ProcessDefinition(
            name=name,
            description=description,
            parameters=parameters or {},
            step_names=[],
            runner=_wrapped_runner,
            logging=logging,
        )
        process_def.original_func = func
        _process_registry.register_process(process_def)

        return _wrapped_runner

    return decorator
597
+
598
+
599
+ def step(name: str = None,
600
+ step_type: str = "general",
601
+ logging: bool = True):
602
+ """
603
+ Simplified decorator to register a function as a step.
604
+ """
605
+ def decorator(func: Callable) -> Callable:
606
+ step_name = name or func.__name__
607
+
608
+ @wraps(func)
609
+ def wrapper(*args, **kwargs):
610
+ import inspect
611
+ sig = inspect.signature(func)
612
+
613
+ # (removed) has_var_keyword detection - unused
614
+
615
+ ctx = kwargs.get('context')
616
+ if not ctx:
617
+ ctx = get_current_context()
618
+
619
+ if 'context' in sig.parameters and 'context' not in kwargs:
620
+ if ctx:
621
+ kwargs['context'] = ctx
622
+
623
+ # Attempt step-level cache lookup for manual calls
624
+ result: Dict[str, Any]
625
+ try:
626
+ from .step_state_manager import StepExecutionResult # Local import to avoid circular import at module load
627
+ except Exception:
628
+ StepExecutionResult = None # type: ignore
629
+ try:
630
+ state_manager = get_state_manager()
631
+ except Exception:
632
+ state_manager = None
633
+
634
+ # Lazily initialize a state manager on workers when missing so step-level
635
+ # caching works in distributed mode. This is crucial for distributed execution.
636
+ if state_manager is None and ctx is not None:
637
+ try:
638
+ state_manager = _init_worker_state_manager_if_needed(ctx)
639
+ except Exception:
640
+ state_manager = get_state_manager()
641
+
642
+ step_key = getattr(wrapper, '_step_name', name or func.__name__)
643
+ # Resolve process_name at runtime: prefer current process context over decoration-time value
644
+ # This allows steps to be defined outside processes and pick up the process they're called from
645
+ runtime_process = get_current_process_context()
646
+ decoration_process = getattr(getattr(wrapper, '_step_definition', None), 'process_name', None)
647
+ process_name_for_step = runtime_process if runtime_process is not None else decoration_process
648
+
649
+ cached_used = False
650
+ input_hash = None
651
+ config_hash = None
652
+ function_hash = None
653
+
654
+ if state_manager and ctx and step_key:
655
+ try:
656
+ # Compute hashes similar to executor
657
+ try:
658
+ import inspect as _inspect
659
+ # Prefer original user function signature to bind call-time args
660
+ _orig_func = getattr(wrapper, '_original_func', func)
661
+ _sig = _inspect.signature(_orig_func)
662
+ _bound = _sig.bind_partial(*args, **kwargs)
663
+ # Exclude context from hashing
664
+ _call_params = {k: v for k, v in _bound.arguments.items() if k != 'context'}
665
+ input_hash = state_manager._compute_hash(_call_params)
666
+ except Exception:
667
+ input_hash = None
668
+ try:
669
+ _orig_func = getattr(wrapper, '_original_func', func)
670
+ function_hash = state_manager._compute_function_hash(_orig_func)
671
+ except Exception:
672
+ function_hash = None
673
+ try:
674
+ config_hash = state_manager._compute_hash(getattr(ctx, 'global_config', {}) or {})
675
+ except Exception:
676
+ config_hash = None
677
+
678
+ cached_result = state_manager.get_cached_step_result_with_metadata(
679
+ run_id=getattr(ctx, 'run_id', None) or 'default',
680
+ step_name=step_key,
681
+ process_name=process_name_for_step,
682
+ input_hash=input_hash,
683
+ config_hash=config_hash,
684
+ function_hash=function_hash,
685
+ )
686
+ if cached_result is not None:
687
+ cached_used = True
688
+ result, cached_run_id, cached_metadata = cached_result
689
+ # Tag and log cache usage
690
+ try:
691
+ logging.getLogger(__name__).info(f"Using cached result for step: {step_key} (process {process_name_for_step}) from run {cached_run_id}")
692
+ if isinstance(result, dict):
693
+ result.setdefault('__was_cached__', True)
694
+ # Set execution time to 0 for cached results since they're loaded instantly
695
+ result.setdefault('__execution_time__', 0.0)
696
+ except Exception:
697
+ pass
698
+ # Attach to context
699
+ try:
700
+ if ctx:
701
+ ctx.step_results[step_key] = result
702
+ except Exception:
703
+ pass
704
+ else:
705
+ # Execute step (local or distributed)
706
+ # Set step context before execution
707
+ prev_step = get_current_step_context()
708
+ set_current_step_context(step_key)
709
+ try:
710
+ # Record step start for live UI/tooling
711
+ try:
712
+ if state_manager and ctx and process_name_for_step and step_key:
713
+ state_manager.record_step_started(
714
+ getattr(ctx, 'run_id', None) or 'default',
715
+ process_name_for_step,
716
+ step_key,
717
+ )
718
+ except Exception:
719
+ pass
720
+ result = func(*args, **kwargs)
721
+ finally:
722
+ # Restore previous step context
723
+ set_current_step_context(prev_step)
724
+
725
+ # Validate result is a dictionary
726
+ if not isinstance(result, dict):
727
+ raise ValueError(f"Step '{step_key}' must return a dictionary, got {type(result).__name__}.")
728
+
729
+ # Mark as NOT cached since we executed the function
730
+ try:
731
+ if isinstance(result, dict):
732
+ result.setdefault('__was_cached__', False)
733
+ except Exception:
734
+ pass
735
+ except Exception:
736
+ # On any error, fall back to direct execution
737
+ result = func(*args, **kwargs)
738
+ # Validate result is a dictionary
739
+ if not isinstance(result, dict):
740
+ raise ValueError(f"Step '{step_key}' must return a dictionary, got {type(result).__name__}.")
741
+ else:
742
+ # No state manager or context - execute directly (threaded mode fallback)
743
+ result = func(*args, **kwargs)
744
+ # Validate result is a dictionary
745
+ if not isinstance(result, dict):
746
+ raise ValueError(f"Step '{step_key}' must return a dictionary, got {type(result).__name__}.")
747
+
748
+ try:
749
+ if isinstance(result, dict):
750
+ result.setdefault('__was_cached__', False)
751
+ except Exception:
752
+ pass
753
+
754
+ # Attach result to context
755
+ try:
756
+ if ctx and isinstance(result, dict):
757
+ step_key = getattr(wrapper, '_step_name', name or func.__name__)
758
+ if step_key:
759
+ ctx.step_results[step_key] = result
760
+ except Exception:
761
+ pass
762
+
763
+ # Post-execution: record cache entry or cached hit event for manual steps
764
+ try:
765
+ if state_manager and isinstance(result, dict) and StepExecutionResult is not None:
766
+ try:
767
+ # Recompute hashes if missing
768
+ if input_hash is None:
769
+ try:
770
+ import inspect as _inspect
771
+ _orig_func = getattr(wrapper, '_original_func', func)
772
+ _sig = _inspect.signature(_orig_func)
773
+ _bound = _sig.bind_partial(*args, **kwargs)
774
+ _call_params = {k: v for k, v in _bound.arguments.items() if k != 'context'}
775
+ input_hash = state_manager._compute_hash(_call_params)
776
+ except Exception:
777
+ input_hash = None
778
+ if function_hash is None:
779
+ try:
780
+ _orig_func = getattr(wrapper, '_original_func', func)
781
+ function_hash = state_manager._compute_function_hash(_orig_func)
782
+ except Exception:
783
+ function_hash = None
784
+ if config_hash is None:
785
+ try:
786
+ config_hash = state_manager._compute_hash(getattr(ctx, 'global_config', {}) or {})
787
+ except Exception:
788
+ config_hash = None
789
+ except Exception:
790
+ pass
791
+ try:
792
+ # Get logging flag from step definition (default to True if not available)
793
+ _step_def_for_logging = getattr(wrapper, '_step_definition', None)
794
+ enable_logging = getattr(_step_def_for_logging, 'logging', True) if _step_def_for_logging else True
795
+
796
+ step_exec_result = StepExecutionResult(
797
+ step_name=step_key,
798
+ success=True,
799
+ result=result,
800
+ execution_time=0.0,
801
+ timestamp=datetime.now().isoformat(),
802
+ )
803
+ # Pass cached metadata if this was a cache hit
804
+ cached_run_id = None
805
+ cached_started_at = None
806
+ cached_ended_at = None
807
+ cached_execution_time = None
808
+ if cached_used and 'cached_metadata' in locals():
809
+ cached_run_id = cached_metadata.get('run_id')
810
+ cached_started_at = cached_metadata.get('started_at')
811
+ cached_ended_at = cached_metadata.get('ended_at')
812
+ cached_execution_time = cached_metadata.get('execution_time')
813
+
814
+ state_manager.record_step_completion(
815
+ getattr(ctx, 'run_id', None) or 'default',
816
+ step_exec_result,
817
+ input_hash=input_hash,
818
+ config_hash=config_hash,
819
+ function_name=step_key,
820
+ function_hash=function_hash,
821
+ was_cached=bool(cached_used),
822
+ process_name=process_name_for_step,
823
+ enable_logging=enable_logging,
824
+ cached_run_id=cached_run_id,
825
+ cached_started_at=cached_started_at,
826
+ cached_ended_at=cached_ended_at,
827
+ cached_execution_time=cached_execution_time,
828
+ )
829
+ except Exception:
830
+ pass
831
+ except Exception:
832
+ pass
833
+
834
+ return result
835
+
836
+ # Store decoration-time process_name (can be None if step defined outside process)
837
+ # At runtime, the actual process will be resolved from get_current_process_context()
838
+ step_def = StepDefinition(
839
+ name=step_name,
840
+ func=wrapper,
841
+ step_type=step_type,
842
+ process_name=get_current_process_context(), # Can be None
843
+ logging=logging,
844
+ )
845
+ step_def.original_func = func
846
+
847
+ _step_registry.register_step(step_def)
848
+
849
+ wrapper._step_name = step_name
850
+ wrapper._step_definition = step_def
851
+ wrapper._original_func = func
852
+
853
+ return wrapper
854
+
855
+ return decorator
856
+
857
+
858
def get_step_registry() -> StepRegistry:
    """Return the module-level ``StepRegistry`` that ``@step`` registers definitions into."""
    registry = _step_registry
    return registry
861
+
862
+
863
def get_process_registry() -> ProcessRegistry:
    """Return the module-level ``ProcessRegistry`` instance."""
    registry = _process_registry
    return registry
866
+
867
def _init_worker_state_manager_if_needed(ctx: 'StepContext') -> Any:
    """Ensure a StepStateManager exists on the worker when executing steps.

    An already-registered manager (``get_state_manager()``) is reused. Otherwise a
    new one is built from the context's cache configuration, registered with
    ``set_state_manager`` and returned.

    The cache section is read from ``ctx.global_config['cache']`` or, when that is
    missing/empty, from ``ctx.global_config['model']['parameters']['cache']``.
    Within it, ``backend`` is passed to the KV-store factory, the whole section is
    passed to the object-store factory, and ``ttl_hours`` (default 24) sets the
    cache TTL. If the configured stores cannot be created, an in-memory KV store
    (and no object store) is used instead.

    Args:
        ctx: The step context; its ``global_config`` and ``project_id`` are read.

    Returns:
        The active state manager, or the result of ``get_state_manager()``
        (possibly None) if a new manager could not be constructed.
    """
    logger = logging.getLogger(__name__)
    try:
        import os as _os

        # Local import avoids an import cycle at module import time.
        from .step_state_manager import StepStateManager

        existing = get_state_manager()
        if existing is not None:
            return existing

        def _as_dict(value: Any) -> Dict[str, Any]:
            # Treat null / non-mapping config values as empty sections.
            return value if isinstance(value, dict) else {}

        cfg = _as_dict(ctx.global_config)
        # Coerce every nesting level to a dict so that e.g. ``parameters: null``
        # does not raise AttributeError and silently abort manager creation.
        cache_cfg = _as_dict(
            cfg.get('cache')
            or _as_dict(_as_dict(cfg.get('model')).get('parameters')).get('cache')
        )
        backend_cfg = _as_dict(cache_cfg.get('backend'))

        # Resolve the project id once so the primary and fallback stores use the
        # same namespace (the fallback previously ignored MLOPS_PROJECT_ID).
        project_id = str(getattr(ctx, "project_id", None) or _os.getenv("MLOPS_PROJECT_ID") or "default")

        kv_store = None
        obj_store = None
        try:
            from mlops.core.workspace import get_projects_root as _get_projects_root, get_workspace_root as _get_workspace_root
            from mlops.storage.factory import create_kv_store as _create_kv_store, create_object_store as _create_obj_store

            ws_root = _get_workspace_root()
            proj_root = _get_projects_root(ws_root) / project_id
            kv_store = _create_kv_store(
                project_id,
                backend_cfg,
                env=_os.environ,
                workspace_root=ws_root,
                project_root=proj_root,
            )
            # create_object_store expects the full cache cfg (with nested object_store).
            obj_store = _create_obj_store(cache_cfg, env=_os.environ)
        except Exception:
            logger.warning(
                "Could not create configured cache stores for project '%s'; "
                "falling back to an in-memory KV store",
                project_id,
                exc_info=True,
            )
            obj_store = None
            try:
                from mlops.storage.adapters.memory_store import InMemoryStore  # type: ignore

                kv_store = InMemoryStore(project_id)
            except Exception:
                kv_store = None

        # An invalid ttl_hours falls back to the default instead of aborting.
        raw_ttl = cache_cfg.get('ttl_hours')
        try:
            ttl_hours = int(raw_ttl or 24)
        except (TypeError, ValueError):
            logger.warning("Invalid cache ttl_hours %r; using 24", raw_ttl)
            ttl_hours = 24

        manager = StepStateManager(
            cache_dir=Path("step_cache"),
            kv_store=kv_store,
            logger=logger,
            cache_ttl_hours=ttl_hours,
            object_store=obj_store,
            object_prefix=None,
        )
        set_state_manager(manager)
        return manager
    except Exception:
        logger.warning("Failed to initialise worker StepStateManager", exc_info=True)
        return get_state_manager()
928
+
929
+
930
+ def _inject_context_and_hparams(func: Callable, ctx: 'StepContext', kwargs: Dict[str, Any]) -> Dict[str, Any]:
931
+ """Return kwargs with context/hyperparameters injected when requested by signature.
932
+
933
+ Context and hyperparameters are ONLY injected if explicitly declared as parameters,
934
+ NOT into **kwargs.
935
+ """
936
+ import inspect as _inspect
937
+ try:
938
+ sig = _inspect.signature(func)
939
+ except Exception:
940
+ sig = None
941
+ new_kwargs = dict(kwargs)
942
+ if sig and ctx is not None:
943
+ if 'context' in sig.parameters and 'context' not in new_kwargs:
944
+ new_kwargs['context'] = ctx
945
+
946
+ try:
947
+ if ('hyperparameters' in sig.parameters and 'hyperparameters' not in new_kwargs) or \
948
+ ('hparams' in sig.parameters and 'hparams' not in new_kwargs):
949
+ merged = ctx.get_hyperparameters(get_current_process_context()) if hasattr(ctx, 'get_hyperparameters') else {}
950
+ if 'hyperparameters' in sig.parameters and 'hyperparameters' not in new_kwargs:
951
+ new_kwargs['hyperparameters'] = merged
952
+ if 'hparams' in sig.parameters and 'hparams' not in new_kwargs:
953
+ new_kwargs['hparams'] = merged
954
+ except Exception:
955
+ pass
956
+ return new_kwargs