expops 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- expops-0.1.3.dist-info/METADATA +826 -0
- expops-0.1.3.dist-info/RECORD +86 -0
- expops-0.1.3.dist-info/WHEEL +5 -0
- expops-0.1.3.dist-info/entry_points.txt +3 -0
- expops-0.1.3.dist-info/licenses/LICENSE +674 -0
- expops-0.1.3.dist-info/top_level.txt +1 -0
- mlops/__init__.py +0 -0
- mlops/__main__.py +11 -0
- mlops/_version.py +34 -0
- mlops/adapters/__init__.py +12 -0
- mlops/adapters/base.py +86 -0
- mlops/adapters/config_schema.py +89 -0
- mlops/adapters/custom/__init__.py +3 -0
- mlops/adapters/custom/custom_adapter.py +447 -0
- mlops/adapters/plugin_manager.py +113 -0
- mlops/adapters/sklearn/__init__.py +3 -0
- mlops/adapters/sklearn/adapter.py +94 -0
- mlops/cluster/__init__.py +3 -0
- mlops/cluster/controller.py +496 -0
- mlops/cluster/process_runner.py +91 -0
- mlops/cluster/providers.py +258 -0
- mlops/core/__init__.py +95 -0
- mlops/core/custom_model_base.py +38 -0
- mlops/core/dask_networkx_executor.py +1265 -0
- mlops/core/executor_worker.py +1239 -0
- mlops/core/experiment_tracker.py +81 -0
- mlops/core/graph_types.py +64 -0
- mlops/core/networkx_parser.py +135 -0
- mlops/core/payload_spill.py +278 -0
- mlops/core/pipeline_utils.py +162 -0
- mlops/core/process_hashing.py +216 -0
- mlops/core/step_state_manager.py +1298 -0
- mlops/core/step_system.py +956 -0
- mlops/core/workspace.py +99 -0
- mlops/environment/__init__.py +10 -0
- mlops/environment/base.py +43 -0
- mlops/environment/conda_manager.py +307 -0
- mlops/environment/factory.py +70 -0
- mlops/environment/pyenv_manager.py +146 -0
- mlops/environment/setup_env.py +31 -0
- mlops/environment/system_manager.py +66 -0
- mlops/environment/utils.py +105 -0
- mlops/environment/venv_manager.py +134 -0
- mlops/main.py +527 -0
- mlops/managers/project_manager.py +400 -0
- mlops/managers/reproducibility_manager.py +575 -0
- mlops/platform.py +996 -0
- mlops/reporting/__init__.py +16 -0
- mlops/reporting/context.py +187 -0
- mlops/reporting/entrypoint.py +292 -0
- mlops/reporting/kv_utils.py +77 -0
- mlops/reporting/registry.py +50 -0
- mlops/runtime/__init__.py +9 -0
- mlops/runtime/context.py +34 -0
- mlops/runtime/env_export.py +113 -0
- mlops/storage/__init__.py +12 -0
- mlops/storage/adapters/__init__.py +9 -0
- mlops/storage/adapters/gcp_kv_store.py +778 -0
- mlops/storage/adapters/gcs_object_store.py +96 -0
- mlops/storage/adapters/memory_store.py +240 -0
- mlops/storage/adapters/redis_store.py +438 -0
- mlops/storage/factory.py +199 -0
- mlops/storage/interfaces/__init__.py +6 -0
- mlops/storage/interfaces/kv_store.py +118 -0
- mlops/storage/path_utils.py +38 -0
- mlops/templates/premier-league/charts/plot_metrics.js +70 -0
- mlops/templates/premier-league/charts/plot_metrics.py +145 -0
- mlops/templates/premier-league/charts/requirements.txt +6 -0
- mlops/templates/premier-league/configs/cluster_config.yaml +13 -0
- mlops/templates/premier-league/configs/project_config.yaml +207 -0
- mlops/templates/premier-league/data/England CSV.csv +12154 -0
- mlops/templates/premier-league/models/premier_league_model.py +638 -0
- mlops/templates/premier-league/requirements.txt +8 -0
- mlops/templates/sklearn-basic/README.md +22 -0
- mlops/templates/sklearn-basic/charts/plot_metrics.py +85 -0
- mlops/templates/sklearn-basic/charts/requirements.txt +3 -0
- mlops/templates/sklearn-basic/configs/project_config.yaml +64 -0
- mlops/templates/sklearn-basic/data/train.csv +14 -0
- mlops/templates/sklearn-basic/models/model.py +62 -0
- mlops/templates/sklearn-basic/requirements.txt +10 -0
- mlops/web/__init__.py +3 -0
- mlops/web/server.py +585 -0
- mlops/web/ui/index.html +52 -0
- mlops/web/ui/mlops-charts.js +357 -0
- mlops/web/ui/script.js +1244 -0
- mlops/web/ui/styles.css +248 -0
|
@@ -0,0 +1,956 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, Dict, List, Optional, Callable, Union
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import json
|
|
6
|
+
import joblib
|
|
7
|
+
import time
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
import logging
|
|
10
|
+
from functools import wraps
|
|
11
|
+
import contextvars
|
|
12
|
+
|
|
13
|
+
from .payload_spill import hydrate_payload_refs
|
|
14
|
+
|
|
15
|
+
# Scalar leaf types accepted inside serializable payloads.
_Scalar = Union[str, int, float, bool]

# A value that is a scalar, a flat list of scalars, a flat str-keyed dict of
# scalars, or None.
SerializableData = Union[
    _Scalar,
    List[_Scalar],
    Dict[str, _Scalar],
    None
]

# Model objects are opaque to this module, so no structure is imposed on them.
ModelData = Any
|
|
23
|
+
|
|
24
|
+
class StepContext:
    """Run-scoped state shared by the steps and processes of one pipeline run.

    The context carries the project/run identifiers, the experiment tracker,
    the results produced so far (keyed by step or process name), the global
    configuration, data paths and the checkpoint directory.
    """

    def __init__(self,
                 project_id: str,
                 run_id: Optional[str] = None,
                 tracker: Any = None,
                 step_results: Optional[Dict[str, Dict[str, Any]]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 data_paths: Optional[Dict[str, Path]] = None,
                 checkpoint_dir: Optional[Path] = None):
        """Initialise the context; falsy container arguments become fresh empty dicts.

        Args:
            project_id: Identifier of the project this context belongs to.
            run_id: Identifier of the current run (may be None).
            tracker: Optional experiment tracker; only ``log_param`` is used here.
            step_results: Results keyed by step/process name.
            global_config: Global configuration dictionary.
            data_paths: Named data paths.
            checkpoint_dir: Checkpoint directory; defaults to
                ``artifacts/checkpoints`` when falsy.
        """
        self.project_id = project_id
        self.run_id = run_id
        self.tracker = tracker
        self.step_results = step_results or {}
        self.global_config = global_config or {}
        self.data_paths = data_paths or {}
        self.shared_state = {}
        self.checkpoint_dir = checkpoint_dir or Path("artifacts/checkpoints")
        self.iteration = 0
        # Name of the process currently being executed; used for resolving per-process settings
        self.current_process: Optional[str] = None

    def _hydrate_payload(self, payload: Any) -> Any:
        """Resolve payload references via ``hydrate_payload_refs``.

        Best effort: if the state manager cannot be obtained it is passed as
        None, and if hydration itself fails the payload is returned unchanged.
        """
        try:
            sm = get_state_manager()
        except Exception:
            sm = None
        try:
            return hydrate_payload_refs(payload, sm)
        except Exception:
            return payload

    def _resolve_process_for_step(self, step_name: str) -> Optional[str]:
        """Return the process name registered for ``step_name``, or None on any failure."""
        try:
            pr = get_process_registry()
            return pr.get_process_for_step(step_name) if pr else None
        except Exception:
            return None

    def get_step_result(self, step_name: str) -> Optional[Dict[str, Any]]:
        """Get result from a previously executed step (returns data dictionary).

        When the result is not in memory and the context has a run id, the
        state manager's cache is consulted and a hit is stored back into
        ``step_results``. The returned value is hydrated; if hydration
        produced a new object it replaces the stored one.

        Returns:
            The (hydrated) result, or None when nothing is available.
        """
        result = self.step_results.get(step_name)
        if result is None and getattr(self, 'run_id', None):
            try:
                sm = get_state_manager()
            except Exception:
                sm = None
            if sm:
                try:
                    cached = sm.get_cached_step_result(
                        run_id=self.run_id,
                        step_name=step_name,
                        process_name=self._resolve_process_for_step(step_name),
                        input_hash=None,
                        config_hash=None,
                        function_hash=None,
                    )
                    if cached:
                        self.step_results[step_name] = cached
                        result = cached
                except Exception:
                    # Cache lookup is best effort; treat failures as a miss.
                    pass
        if result is None:
            return None
        hydrated = self._hydrate_payload(result)
        if hydrated is not result:
            self.step_results[step_name] = hydrated
        return hydrated

    def get_step_data(self, step_name: str, data_key: str, process_name: Optional[str] = None) -> Any:
        """Get specific data output from a previous step.

        Delegates to ``get_step_result`` (which may fall back to the cache for
        the current run) and returns ``data_key`` from the result when the
        result is a dictionary.

        Args:
            step_name: Name of the step whose result is wanted.
            data_key: Key to read from the step's result dictionary.
            process_name: Accepted for API compatibility; currently unused.

        Returns:
            The stored value, or None when the step/key is unavailable.
        """
        step_result = self.get_step_result(step_name)
        if isinstance(step_result, dict):
            try:
                return step_result.get(data_key)
            except Exception:
                return None
        return None

    def get_process_data(self, process_name: str, data_key: str) -> Any:
        """Get specific data output returned by a previous process.

        Falls back to loading the process result from the state manager's
        cache when it is not available in memory and the context has a run id.

        Returns:
            The stored value, or None when the process/key is unavailable.
        """
        try:
            proc_result = self.step_results.get(process_name)
        except Exception:
            proc_result = None
        if proc_result and isinstance(proc_result, dict):
            hydrated = self._hydrate_payload(proc_result)
            if hydrated is not proc_result:
                self.step_results[process_name] = hydrated
            try:
                return hydrated.get(data_key)
            except Exception:
                return None

        # Not present in memory, try cache
        try:
            sm = get_state_manager()
        except Exception:
            sm = None
        if sm and getattr(self, 'run_id', None):
            try:
                # Generic cached lookup via known API
                loaded = sm.get_cached_process_result(process_name, input_hash=None, config_hash=None, function_hash=None)
                if loaded:
                    self.step_results[process_name] = loaded
                    hydrated = self._hydrate_payload(loaded)
                    if hydrated is not loaded:
                        self.step_results[process_name] = hydrated
                    try:
                        return hydrated.get(data_key)
                    except Exception:
                        return None
            except Exception:
                return None
        return None

    def _raise_if_logging_disabled(self) -> None:
        """Raise RuntimeError when the current step/process has ``logging=False``.

        Registry inspection is best effort: any failure while looking up the
        current step/process definition is ignored and logging is allowed.
        The RuntimeError is raised outside those try blocks so that it is not
        swallowed together with the inspection errors.
        """
        try:
            from .step_system import (
                get_current_process_context as _get_cproc,
                get_current_step_context as _get_cstep,
                get_step_registry as _get_sreg,
                get_process_registry as _get_preg,
            )
            cur_step = _get_cstep()
            cur_proc = _get_cproc()
        except Exception:
            # On any inspection failure, fall through and attempt to log
            return

        if cur_step:
            try:
                sdef = _get_sreg().get_step(cur_step)
            except Exception:
                sdef = None
            if sdef is not None and (getattr(sdef, 'logging', True) is False):
                raise RuntimeError(f"Metric logging is disabled for step '{cur_step}' (logging=False).")
        if cur_proc:
            try:
                pdef = _get_preg().get_process(cur_proc)
            except Exception:
                pdef = None
            if pdef is not None and (getattr(pdef, 'logging', True) is False):
                raise RuntimeError(f"Metric logging is disabled for process '{cur_proc}' (logging=False).")

    def log_metric(self, key: str, value: Any, step: Optional[int] = None) -> None:
        """Log a metric with MLflow-style step tracking.

        Args:
            key: Metric name
            value: Metric value. Numeric values are tracked per step; non-numeric values
                  (e.g., lists, dicts) are stored as last-snapshot under a shadow key
                  with "__last" suffix in the KV store.
            step: Step number (if None, auto-increments from largest existing step).

        Raises:
            RuntimeError: If the current step or process was declared with
                ``logging=False``.
        """
        # Enforce logging flag from current step/process
        self._raise_if_logging_disabled()

        try:
            from .step_system import get_state_manager, get_current_process_context, get_current_step_context
            state_manager = get_state_manager()
            if state_manager and self.run_id:
                process_name = get_current_process_context()
                step_name = get_current_step_context()
                try:
                    kv_store = getattr(state_manager, 'kv_store', None)
                    kv_cls = type(kv_store).__name__ if kv_store is not None else 'None'
                    logging.getLogger(__name__).info(
                        "[Metrics] log_metric call -> run_id=%s, process=%s, step=%s, key=%s, step_idx=%s, kv_store=%s",
                        self.run_id, process_name, step_name, key,
                        step if step is not None else 'auto', kv_cls,
                    )
                except Exception:
                    # Diagnostics must never prevent the metric from being stored.
                    pass
                state_manager.log_metric(
                    run_id=self.run_id,
                    process_name=process_name,
                    step_name=step_name,
                    metric_name=key,
                    value=value,
                    step=step
                )
        except Exception as e:
            logging.getLogger(__name__).warning(f"Failed to log metric to KV store: {e}")

    def log_param(self, key: str, value: Union[int, float, str, bool]) -> None:
        """Log a parameter to the experiment tracker.

        Does nothing when no tracker is set or the tracker has no callable
        ``log_param``; tracker errors are ignored.
        """
        tracker = getattr(self, "tracker", None)
        if not tracker:
            return
        fn = getattr(tracker, "log_param", None)
        if callable(fn):
            try:
                fn(key, value)
            except Exception:
                pass

    def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]:
        """Load a checkpoint from the given path.

        Raises:
            Exception: Whatever ``joblib.load`` raises is re-raised after
                being reported on stdout.
        """
        try:
            # SECURITY NOTE: joblib.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from trusted locations.
            checkpoint_data = joblib.load(checkpoint_path)
            print(f"[Checkpoint] Model loaded from: {checkpoint_path}")
            return checkpoint_data
        except Exception as e:
            print(f"[Checkpoint] Failed to load checkpoint {checkpoint_path}: {e}")
            raise

    def get_hyperparameters(self, process_name: Optional[str] = None) -> Dict[str, Any]:
        """Return hyperparameters with per-process overrides taking precedence.

        Resolution order (later overrides earlier):
        1) Global hyperparameters from context (supports both legacy and current layouts)
        2) Process-specific hyperparameters from pipeline.processes[name].hyperparameters

        Args:
            process_name: Explicit process name; if None, uses context.current_process

        Returns:
            Merged hyperparameters dict (empty when nothing usable is configured)
        """
        raw_config = getattr(self, 'global_config', None)
        config = raw_config if isinstance(raw_config, dict) else {}

        # Current layout: top-level 'hyperparameters'.
        # Legacy layout: model.parameters.hyperparameters.
        global_hp: Dict[str, Any] = {}
        try:
            if 'hyperparameters' in config:
                candidate = config.get('hyperparameters')
            elif 'model' in config:
                candidate = (
                    config.get('model', {})
                    .get('parameters', {})
                    .get('hyperparameters', {})
                )
            else:
                candidate = None
            if isinstance(candidate, dict):
                global_hp = dict(candidate)
        except (AttributeError, TypeError):
            # Malformed legacy section (e.g. model: null) means no global values.
            global_hp = {}

        proc_hp: Dict[str, Any] = {}
        proc = process_name or getattr(self, 'current_process', None)
        pipeline_cfg = config.get('pipeline')
        processes_list = pipeline_cfg.get('processes') if isinstance(pipeline_cfg, dict) else None
        if proc and isinstance(processes_list, list):
            for entry in processes_list:
                if isinstance(entry, dict) and entry.get('name') == proc:
                    candidate = entry.get('hyperparameters')
                    if isinstance(candidate, dict):
                        proc_hp = dict(candidate)
                    break

        return {**global_hp, **proc_hp}
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class StepContextFactory:
    """
    Process-wide singleton that hands out StepContext objects.

    Contexts are cached per ``(project_id, run_id)`` pair (a missing run id is
    keyed as "default"), so repeated requests for the same run share one
    context instance.
    """

    _instance = None
    _contexts: Dict[tuple[str, str], StepContext] = {}

    def __new__(cls):
        singleton = cls._instance
        if singleton is None:
            singleton = cls._instance = super().__new__(cls)
        return singleton

    def create_context(self,
                      project_id: str,
                      run_id: str,
                      tracker: Any = None,
                      step_results: Dict[str, Dict[str, Any]] = None,
                      global_config: Dict[str, Any] = None,
                      data_paths: Dict[str, Path] = None,
                      checkpoint_dir: Optional[Path] = None) -> StepContext:
        """
        Return the context for ``(project_id, run_id)``, creating it on first use.

        On reuse, the cached context's ``run_id`` is always refreshed and any
        other argument that is not None replaces the corresponding attribute.

        Args:
            project_id: Unique project identifier (first half of the cache key)
            run_id: Unique run identifier (second half of the cache key)
            tracker: Experiment tracker instance
            step_results: Dictionary of step results (data dictionaries)
            global_config: Global configuration
            data_paths: Paths to data files
            checkpoint_dir: Directory for checkpoints

        Returns:
            StepContext instance for the project/run
        """
        cache_key = (str(project_id), str(run_id or "default"))

        if cache_key not in self._contexts:
            fresh = StepContext(
                project_id=project_id,
                run_id=run_id,
                tracker=tracker,
                step_results=step_results or {},
                global_config=global_config or {},
                data_paths=data_paths or {},
                checkpoint_dir=checkpoint_dir,
            )
            self._contexts[cache_key] = fresh
            return fresh

        # Reuse the run-scoped context, but allow callers to refresh its references.
        cached = self._contexts[cache_key]
        try:
            cached.run_id = run_id
        except Exception:
            pass

        refreshable = (
            ("tracker", tracker),
            ("step_results", step_results),
            ("global_config", global_config),
            ("data_paths", data_paths),
            ("checkpoint_dir", checkpoint_dir),
        )
        for attr_name, new_value in refreshable:
            if new_value is None:
                continue
            try:
                setattr(cached, attr_name, new_value)
            except Exception:
                pass
        return cached
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
@dataclass
class StepDefinition:
    """Registry record describing one registered step function."""
    # Registry key of the step.
    name: str
    # Callable stored for the step.
    func: Callable
    # Free-form category label.
    step_type: str = "general"
    # Owning process, if known when the definition is created.
    process_name: Optional[str] = None
    # Not an __init__ argument; assigned after construction when needed.
    original_func: Optional[Callable] = field(init=False, default=None)
    inputs: List[str] = field(default_factory=list)
    outputs: List[str] = field(default_factory=list)
    condition: Optional[str] = None
    # When False, StepContext.log_metric refuses to log for this step.
    logging: bool = True
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
@dataclass
class ProcessDefinition:
    """Registry record describing a process, i.e. a named group of steps."""
    # Registry key of the process.
    name: str
    description: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    # Names of the steps that belong to this process.
    step_names: List[str] = field(default_factory=list)
    # Callable that executes the process.
    runner: Optional[Callable] = None
    # Not an __init__ argument; assigned after construction when needed.
    original_func: Optional[Callable] = field(init=False, default=None)
    # When False, StepContext.log_metric refuses to log for this process.
    logging: bool = True
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
class StepRegistry:
    """In-memory mapping from step name to its StepDefinition."""

    def __init__(self):
        self._steps: Dict[str, StepDefinition] = {}

    def register_step(self, step_def: StepDefinition) -> None:
        """Store ``step_def`` under its name, replacing any earlier entry."""
        key = step_def.name
        self._steps[key] = step_def

    def get_step(self, name: str) -> Optional[StepDefinition]:
        """Return the definition registered as ``name``, or None if absent."""
        return self._steps.get(name)

    def list_steps(self) -> List[str]:
        """Return the registered step names as a new list."""
        return [step_name for step_name in self._steps]
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
class ProcessRegistry:
    """In-memory registry of processes plus a step-name -> process-name index."""

    def __init__(self):
        self._processes: Dict[str, ProcessDefinition] = {}
        self._step_to_process: Dict[str, str] = {}

    def register_process(self, process_def: ProcessDefinition) -> None:
        """Store ``process_def`` under its name and index each of its steps."""
        owner = process_def.name
        self._processes[owner] = process_def
        # Later registrations win when two processes list the same step.
        self._step_to_process.update(
            {member: owner for member in process_def.step_names}
        )

    def get_process(self, name: str) -> Optional[ProcessDefinition]:
        """Return the definition registered as ``name``, or None if absent."""
        return self._processes.get(name)

    def get_process_for_step(self, step_name: str) -> Optional[str]:
        """Return the name of the process indexed for ``step_name``, or None."""
        return self._step_to_process.get(step_name)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
# Module-wide registries populated by the decorators in this module.
_step_registry = StepRegistry()
_process_registry = ProcessRegistry()

# Module-wide factory (StepContextFactory is a singleton).
_context_factory = StepContextFactory()

# Use contextvars for thread-safe context propagation in Dask workers.
# Each variable defaults to None until one of the setters below is called.
_current_context: contextvars.ContextVar[Optional[StepContext]] = contextvars.ContextVar('current_context', default=None)
_current_state_manager: contextvars.ContextVar[Any] = contextvars.ContextVar('current_state_manager', default=None)
_current_process_context: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar('current_process_context', default=None)
_current_step_context: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar('current_step_context', default=None)
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def get_context_factory() -> StepContextFactory:
    """Return the module-wide StepContextFactory instance."""
    factory = _context_factory
    return factory
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def set_current_context(context: StepContext) -> None:
    """Make ``context`` the active StepContext.

    The value is stored in a ContextVar, so it is local to the current
    execution context.
    """
    _current_context.set(context)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def get_current_context() -> Optional[StepContext]:
    """Return the active StepContext, or None when none has been set."""
    active = _current_context.get()
    return active
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def set_current_process_context(process_name: Optional[str]) -> None:
    """Record the name of the process being executed (None clears it).

    The value is stored in a ContextVar, so it is local to the current
    execution context.
    """
    _current_process_context.set(process_name)
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def get_current_process_context() -> Optional[str]:
    """Return the name of the process being executed, or None if unset."""
    current = _current_process_context.get()
    return current
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def set_current_step_context(step_name: Optional[str]) -> None:
    """Record the name of the step being executed (None clears it).

    The value is stored in a ContextVar, so it is local to the current
    execution context.
    """
    _current_step_context.set(step_name)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def get_current_step_context() -> Optional[str]:
    """Return the name of the step being executed, or None if unset."""
    current = _current_step_context.get()
    return current
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def log_metric(key: str, value: Any, step: Optional[int] = None) -> None:
    """Log a metric through the active StepContext.

    This is the entry point user model code is expected to call while
    training. When no context is active the metric is dropped and a warning
    is emitted instead.

    Args:
        key: Metric name (e.g., 'loss', 'accuracy', 'learning_rate')
        value: Metric value. Numeric values are tracked per step; non-numeric values
              (e.g., lists, dicts) are stored as last-snapshot under a shadow key
              with "__last" suffix in the KV store.
        step: Step/iteration number (optional). If not provided, auto-increments
              from largest existing step.

    Example:
        >>> from mlops.core import log_metric
        >>> for epoch in range(100):
        >>>     loss = train_one_epoch()
        >>>     log_metric('loss', loss, step=epoch+1)
    """
    active = get_current_context()
    if not active:
        import logging
        logging.getLogger(__name__).warning(
            "log_metric called but no context is active. Metric will not be logged."
        )
        return
    active.log_metric(key, value, step=step)
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def set_state_manager(state_manager: Any) -> None:
    """Publish ``state_manager`` as the current StepStateManager.

    The value is stored in a ContextVar, so it is local to the current
    execution context; it is what ``get_state_manager`` returns afterwards.
    """
    _current_state_manager.set(state_manager)
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def get_state_manager() -> Any:
    """Return the current StepStateManager, or None when none has been set."""
    manager = _current_state_manager.get()
    return manager
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def process(description: str = "", parameters: Dict[str, Any] = None, logging: bool = True):
    """
    Decorator that registers a function as a process.

    The process is registered under the function's ``__name__`` with a
    ProcessDefinition whose runner is the wrapper returned here. The wrapper:

    * sets the current process context to the process name for the duration
      of the call, but only when no process context is already active;
    * when a StepContext is active, passes the keyword arguments through
      ``_inject_context_and_hparams`` before calling the function;
    * requires the function to return a dict (ValueError otherwise);
    * stores the returned dict in the active context's ``step_results``
      under the process name (best effort).

    Args:
        description: Human-readable description stored on the definition.
        parameters: Parameters stored on the definition (None becomes {}).
        logging: Stored on the definition as its ``logging`` flag.
    """
    def decorator(func):
        # Use function name as the process name for registry lookup
        proc_name = func.__name__

        def _wrapped_runner(*args, **kwargs):
            enclosing = get_current_process_context()
            owns_process_context = enclosing is None
            if owns_process_context:
                set_current_process_context(proc_name)
            try:
                call_kwargs = dict(kwargs)
                try:
                    active_ctx = get_current_context()
                except Exception:
                    active_ctx = None
                if active_ctx is not None:
                    call_kwargs = _inject_context_and_hparams(func, active_ctx, call_kwargs)

                outcome = func(*args, **call_kwargs)
                if outcome is None:
                    raise ValueError(f"Process '{proc_name}' must return a dictionary of data.")
                # Validate result is a dictionary
                if not isinstance(outcome, dict):
                    raise ValueError(f"Process '{proc_name}' must return a dictionary, got {type(outcome).__name__}.")

                # Attach to context
                try:
                    target_ctx = get_current_context()
                    if target_ctx:
                        target_ctx.step_results[proc_name] = outcome
                except Exception:
                    # Do not block process execution on context issues
                    pass
                return outcome
            finally:
                # Only restore if we actually changed it (i.e., if it was None before)
                if owns_process_context:
                    set_current_process_context(enclosing)

        definition = ProcessDefinition(
            name=proc_name,
            description=description,
            parameters=parameters or {},
            step_names=[],
            runner=_wrapped_runner,
            logging=logging,
        )
        definition.original_func = func
        _process_registry.register_process(definition)

        return _wrapped_runner

    return decorator
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
def step(name: str = None,
|
|
600
|
+
step_type: str = "general",
|
|
601
|
+
logging: bool = True):
|
|
602
|
+
"""
|
|
603
|
+
Simplified decorator to register a function as a step.
|
|
604
|
+
"""
|
|
605
|
+
def decorator(func: Callable) -> Callable:
|
|
606
|
+
step_name = name or func.__name__
|
|
607
|
+
|
|
608
|
+
@wraps(func)
|
|
609
|
+
def wrapper(*args, **kwargs):
|
|
610
|
+
import inspect
|
|
611
|
+
sig = inspect.signature(func)
|
|
612
|
+
|
|
613
|
+
# (removed) has_var_keyword detection - unused
|
|
614
|
+
|
|
615
|
+
ctx = kwargs.get('context')
|
|
616
|
+
if not ctx:
|
|
617
|
+
ctx = get_current_context()
|
|
618
|
+
|
|
619
|
+
if 'context' in sig.parameters and 'context' not in kwargs:
|
|
620
|
+
if ctx:
|
|
621
|
+
kwargs['context'] = ctx
|
|
622
|
+
|
|
623
|
+
# Attempt step-level cache lookup for manual calls
|
|
624
|
+
result: Dict[str, Any]
|
|
625
|
+
try:
|
|
626
|
+
from .step_state_manager import StepExecutionResult # Local import to avoid circular import at module load
|
|
627
|
+
except Exception:
|
|
628
|
+
StepExecutionResult = None # type: ignore
|
|
629
|
+
try:
|
|
630
|
+
state_manager = get_state_manager()
|
|
631
|
+
except Exception:
|
|
632
|
+
state_manager = None
|
|
633
|
+
|
|
634
|
+
# Lazily initialize a state manager on workers when missing so step-level
|
|
635
|
+
# caching works in distributed mode. This is crucial for distributed execution.
|
|
636
|
+
if state_manager is None and ctx is not None:
|
|
637
|
+
try:
|
|
638
|
+
state_manager = _init_worker_state_manager_if_needed(ctx)
|
|
639
|
+
except Exception:
|
|
640
|
+
state_manager = get_state_manager()
|
|
641
|
+
|
|
642
|
+
step_key = getattr(wrapper, '_step_name', name or func.__name__)
|
|
643
|
+
# Resolve process_name at runtime: prefer current process context over decoration-time value
|
|
644
|
+
# This allows steps to be defined outside processes and pick up the process they're called from
|
|
645
|
+
runtime_process = get_current_process_context()
|
|
646
|
+
decoration_process = getattr(getattr(wrapper, '_step_definition', None), 'process_name', None)
|
|
647
|
+
process_name_for_step = runtime_process if runtime_process is not None else decoration_process
|
|
648
|
+
|
|
649
|
+
cached_used = False
|
|
650
|
+
input_hash = None
|
|
651
|
+
config_hash = None
|
|
652
|
+
function_hash = None
|
|
653
|
+
|
|
654
|
+
if state_manager and ctx and step_key:
|
|
655
|
+
try:
|
|
656
|
+
# Compute hashes similar to executor
|
|
657
|
+
try:
|
|
658
|
+
import inspect as _inspect
|
|
659
|
+
# Prefer original user function signature to bind call-time args
|
|
660
|
+
_orig_func = getattr(wrapper, '_original_func', func)
|
|
661
|
+
_sig = _inspect.signature(_orig_func)
|
|
662
|
+
_bound = _sig.bind_partial(*args, **kwargs)
|
|
663
|
+
# Exclude context from hashing
|
|
664
|
+
_call_params = {k: v for k, v in _bound.arguments.items() if k != 'context'}
|
|
665
|
+
input_hash = state_manager._compute_hash(_call_params)
|
|
666
|
+
except Exception:
|
|
667
|
+
input_hash = None
|
|
668
|
+
try:
|
|
669
|
+
_orig_func = getattr(wrapper, '_original_func', func)
|
|
670
|
+
function_hash = state_manager._compute_function_hash(_orig_func)
|
|
671
|
+
except Exception:
|
|
672
|
+
function_hash = None
|
|
673
|
+
try:
|
|
674
|
+
config_hash = state_manager._compute_hash(getattr(ctx, 'global_config', {}) or {})
|
|
675
|
+
except Exception:
|
|
676
|
+
config_hash = None
|
|
677
|
+
|
|
678
|
+
cached_result = state_manager.get_cached_step_result_with_metadata(
|
|
679
|
+
run_id=getattr(ctx, 'run_id', None) or 'default',
|
|
680
|
+
step_name=step_key,
|
|
681
|
+
process_name=process_name_for_step,
|
|
682
|
+
input_hash=input_hash,
|
|
683
|
+
config_hash=config_hash,
|
|
684
|
+
function_hash=function_hash,
|
|
685
|
+
)
|
|
686
|
+
if cached_result is not None:
|
|
687
|
+
cached_used = True
|
|
688
|
+
result, cached_run_id, cached_metadata = cached_result
|
|
689
|
+
# Tag and log cache usage
|
|
690
|
+
try:
|
|
691
|
+
logging.getLogger(__name__).info(f"Using cached result for step: {step_key} (process {process_name_for_step}) from run {cached_run_id}")
|
|
692
|
+
if isinstance(result, dict):
|
|
693
|
+
result.setdefault('__was_cached__', True)
|
|
694
|
+
# Set execution time to 0 for cached results since they're loaded instantly
|
|
695
|
+
result.setdefault('__execution_time__', 0.0)
|
|
696
|
+
except Exception:
|
|
697
|
+
pass
|
|
698
|
+
# Attach to context
|
|
699
|
+
try:
|
|
700
|
+
if ctx:
|
|
701
|
+
ctx.step_results[step_key] = result
|
|
702
|
+
except Exception:
|
|
703
|
+
pass
|
|
704
|
+
else:
|
|
705
|
+
# Execute step (local or distributed)
|
|
706
|
+
# Set step context before execution
|
|
707
|
+
prev_step = get_current_step_context()
|
|
708
|
+
set_current_step_context(step_key)
|
|
709
|
+
try:
|
|
710
|
+
# Record step start for live UI/tooling
|
|
711
|
+
try:
|
|
712
|
+
if state_manager and ctx and process_name_for_step and step_key:
|
|
713
|
+
state_manager.record_step_started(
|
|
714
|
+
getattr(ctx, 'run_id', None) or 'default',
|
|
715
|
+
process_name_for_step,
|
|
716
|
+
step_key,
|
|
717
|
+
)
|
|
718
|
+
except Exception:
|
|
719
|
+
pass
|
|
720
|
+
result = func(*args, **kwargs)
|
|
721
|
+
finally:
|
|
722
|
+
# Restore previous step context
|
|
723
|
+
set_current_step_context(prev_step)
|
|
724
|
+
|
|
725
|
+
# Validate result is a dictionary
|
|
726
|
+
if not isinstance(result, dict):
|
|
727
|
+
raise ValueError(f"Step '{step_key}' must return a dictionary, got {type(result).__name__}.")
|
|
728
|
+
|
|
729
|
+
# Mark as NOT cached since we executed the function
|
|
730
|
+
try:
|
|
731
|
+
if isinstance(result, dict):
|
|
732
|
+
result.setdefault('__was_cached__', False)
|
|
733
|
+
except Exception:
|
|
734
|
+
pass
|
|
735
|
+
except Exception:
|
|
736
|
+
# On any error, fall back to direct execution
|
|
737
|
+
result = func(*args, **kwargs)
|
|
738
|
+
# Validate result is a dictionary
|
|
739
|
+
if not isinstance(result, dict):
|
|
740
|
+
raise ValueError(f"Step '{step_key}' must return a dictionary, got {type(result).__name__}.")
|
|
741
|
+
else:
|
|
742
|
+
# No state manager or context - execute directly (threaded mode fallback)
|
|
743
|
+
result = func(*args, **kwargs)
|
|
744
|
+
# Validate result is a dictionary
|
|
745
|
+
if not isinstance(result, dict):
|
|
746
|
+
raise ValueError(f"Step '{step_key}' must return a dictionary, got {type(result).__name__}.")
|
|
747
|
+
|
|
748
|
+
try:
|
|
749
|
+
if isinstance(result, dict):
|
|
750
|
+
result.setdefault('__was_cached__', False)
|
|
751
|
+
except Exception:
|
|
752
|
+
pass
|
|
753
|
+
|
|
754
|
+
# Attach result to context
|
|
755
|
+
try:
|
|
756
|
+
if ctx and isinstance(result, dict):
|
|
757
|
+
step_key = getattr(wrapper, '_step_name', name or func.__name__)
|
|
758
|
+
if step_key:
|
|
759
|
+
ctx.step_results[step_key] = result
|
|
760
|
+
except Exception:
|
|
761
|
+
pass
|
|
762
|
+
|
|
763
|
+
# Post-execution: record cache entry or cached hit event for manual steps
|
|
764
|
+
try:
|
|
765
|
+
if state_manager and isinstance(result, dict) and StepExecutionResult is not None:
|
|
766
|
+
try:
|
|
767
|
+
# Recompute hashes if missing
|
|
768
|
+
if input_hash is None:
|
|
769
|
+
try:
|
|
770
|
+
import inspect as _inspect
|
|
771
|
+
_orig_func = getattr(wrapper, '_original_func', func)
|
|
772
|
+
_sig = _inspect.signature(_orig_func)
|
|
773
|
+
_bound = _sig.bind_partial(*args, **kwargs)
|
|
774
|
+
_call_params = {k: v for k, v in _bound.arguments.items() if k != 'context'}
|
|
775
|
+
input_hash = state_manager._compute_hash(_call_params)
|
|
776
|
+
except Exception:
|
|
777
|
+
input_hash = None
|
|
778
|
+
if function_hash is None:
|
|
779
|
+
try:
|
|
780
|
+
_orig_func = getattr(wrapper, '_original_func', func)
|
|
781
|
+
function_hash = state_manager._compute_function_hash(_orig_func)
|
|
782
|
+
except Exception:
|
|
783
|
+
function_hash = None
|
|
784
|
+
if config_hash is None:
|
|
785
|
+
try:
|
|
786
|
+
config_hash = state_manager._compute_hash(getattr(ctx, 'global_config', {}) or {})
|
|
787
|
+
except Exception:
|
|
788
|
+
config_hash = None
|
|
789
|
+
except Exception:
|
|
790
|
+
pass
|
|
791
|
+
try:
|
|
792
|
+
# Get logging flag from step definition (default to True if not available)
|
|
793
|
+
_step_def_for_logging = getattr(wrapper, '_step_definition', None)
|
|
794
|
+
enable_logging = getattr(_step_def_for_logging, 'logging', True) if _step_def_for_logging else True
|
|
795
|
+
|
|
796
|
+
step_exec_result = StepExecutionResult(
|
|
797
|
+
step_name=step_key,
|
|
798
|
+
success=True,
|
|
799
|
+
result=result,
|
|
800
|
+
execution_time=0.0,
|
|
801
|
+
timestamp=datetime.now().isoformat(),
|
|
802
|
+
)
|
|
803
|
+
# Pass cached metadata if this was a cache hit
|
|
804
|
+
cached_run_id = None
|
|
805
|
+
cached_started_at = None
|
|
806
|
+
cached_ended_at = None
|
|
807
|
+
cached_execution_time = None
|
|
808
|
+
if cached_used and 'cached_metadata' in locals():
|
|
809
|
+
cached_run_id = cached_metadata.get('run_id')
|
|
810
|
+
cached_started_at = cached_metadata.get('started_at')
|
|
811
|
+
cached_ended_at = cached_metadata.get('ended_at')
|
|
812
|
+
cached_execution_time = cached_metadata.get('execution_time')
|
|
813
|
+
|
|
814
|
+
state_manager.record_step_completion(
|
|
815
|
+
getattr(ctx, 'run_id', None) or 'default',
|
|
816
|
+
step_exec_result,
|
|
817
|
+
input_hash=input_hash,
|
|
818
|
+
config_hash=config_hash,
|
|
819
|
+
function_name=step_key,
|
|
820
|
+
function_hash=function_hash,
|
|
821
|
+
was_cached=bool(cached_used),
|
|
822
|
+
process_name=process_name_for_step,
|
|
823
|
+
enable_logging=enable_logging,
|
|
824
|
+
cached_run_id=cached_run_id,
|
|
825
|
+
cached_started_at=cached_started_at,
|
|
826
|
+
cached_ended_at=cached_ended_at,
|
|
827
|
+
cached_execution_time=cached_execution_time,
|
|
828
|
+
)
|
|
829
|
+
except Exception:
|
|
830
|
+
pass
|
|
831
|
+
except Exception:
|
|
832
|
+
pass
|
|
833
|
+
|
|
834
|
+
return result
|
|
835
|
+
|
|
836
|
+
# Store decoration-time process_name (can be None if step defined outside process)
|
|
837
|
+
# At runtime, the actual process will be resolved from get_current_process_context()
|
|
838
|
+
step_def = StepDefinition(
|
|
839
|
+
name=step_name,
|
|
840
|
+
func=wrapper,
|
|
841
|
+
step_type=step_type,
|
|
842
|
+
process_name=get_current_process_context(), # Can be None
|
|
843
|
+
logging=logging,
|
|
844
|
+
)
|
|
845
|
+
step_def.original_func = func
|
|
846
|
+
|
|
847
|
+
_step_registry.register_step(step_def)
|
|
848
|
+
|
|
849
|
+
wrapper._step_name = step_name
|
|
850
|
+
wrapper._step_definition = step_def
|
|
851
|
+
wrapper._original_func = func
|
|
852
|
+
|
|
853
|
+
return wrapper
|
|
854
|
+
|
|
855
|
+
return decorator
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def get_step_registry() -> StepRegistry:
    """Return the module-level step registry (`_step_registry`).

    This is the same registry object that the step decorator in this module
    calls `register_step` on.
    """
    return _step_registry
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
def get_process_registry() -> ProcessRegistry:
    """Return the module-level process registry (`_process_registry`)."""
    return _process_registry
|
|
866
|
+
|
|
867
|
+
def _init_worker_state_manager_if_needed(ctx: 'StepContext') -> Any:
    """Ensure a StepStateManager exists on the worker when executing steps.

    If a state manager is already registered (``get_state_manager()``), it is
    returned unchanged. Otherwise a new ``StepStateManager`` is built from the
    cache configuration found in ``ctx.global_config`` (``cache`` at the top
    level, falling back to ``model.parameters.cache``), registered through
    ``set_state_manager`` and returned.

    Args:
        ctx: The active step context. ``global_config`` is read for the cache
            settings and ``project_id`` (when present) selects the project.

    Returns:
        The existing or newly created state manager. If construction fails for
        any reason, whatever ``get_state_manager()`` currently returns (which
        may be ``None``).
    """
    log = logging.getLogger(__name__)
    try:
        # Local import to avoid import cycles at module import time.
        from .step_state_manager import StepStateManager

        existing = get_state_manager()
        if existing is not None:
            return existing

        cfg = ctx.global_config if isinstance(ctx.global_config, dict) else {}

        # Resolve the cache section. Every nesting level is guarded so that an
        # explicit ``null`` (or a non-dict value) for ``model`` or
        # ``model.parameters`` means "no cache config" instead of raising and
        # silently disabling worker-side caching.
        cache_cfg = cfg.get('cache')
        if not cache_cfg:
            model_cfg = cfg.get('model') or {}
            params_cfg = (model_cfg.get('parameters') if isinstance(model_cfg, dict) else None) or {}
            cache_cfg = params_cfg.get('cache') if isinstance(params_cfg, dict) else None
        cache_cfg_dict = cache_cfg if isinstance(cache_cfg, dict) else {}

        backend_cfg = cache_cfg_dict.get('backend')
        if not isinstance(backend_cfg, dict):
            backend_cfg = {}

        cache_dir = Path("step_cache")
        kv_store = None
        obj_store = None
        obj_prefix = None

        # Centralized KV/object-store creation.
        try:
            import os as _os
            from mlops.core.workspace import get_projects_root as _get_projects_root, get_workspace_root as _get_workspace_root
            from mlops.storage.factory import create_kv_store as _create_kv_store, create_object_store as _create_obj_store

            pid_effective = str(getattr(ctx, "project_id", None) or _os.getenv("MLOPS_PROJECT_ID") or "default")
            ws_root = _get_workspace_root()
            proj_root = _get_projects_root(ws_root) / pid_effective

            kv_store = _create_kv_store(
                pid_effective,
                backend_cfg,
                env=_os.environ,
                workspace_root=ws_root,
                project_root=proj_root,
            )
            # create_object_store expects the full cache cfg (with nested object_store)
            obj_store = _create_obj_store(cache_cfg_dict, env=_os.environ)
        except Exception:
            log.debug("Store creation from cache config failed; falling back to in-memory KV store", exc_info=True)
            # Best-effort fallback: in-memory KV store and no object store.
            try:
                from mlops.storage.adapters.memory_store import InMemoryStore  # type: ignore

                pid_effective = str(getattr(ctx, "project_id", None) or "default")
                kv_store = InMemoryStore(pid_effective)
            except Exception:
                kv_store = None
            obj_store = None
            obj_prefix = None

        sm_new = StepStateManager(
            cache_dir=cache_dir,
            kv_store=kv_store,
            logger=log,
            # A missing or falsy ttl_hours falls back to 24.
            cache_ttl_hours=int(cache_cfg_dict.get('ttl_hours') or 24),
            object_store=obj_store,
            object_prefix=obj_prefix,
        )
        set_state_manager(sm_new)
        return sm_new
    except Exception:
        # Deliberately best-effort: never let state-manager setup break a step.
        log.debug("Could not initialize a worker state manager", exc_info=True)
        return get_state_manager()
|
|
928
|
+
|
|
929
|
+
|
|
930
|
+
def _inject_context_and_hparams(func: Callable, ctx: 'StepContext', kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
|
931
|
+
"""Return kwargs with context/hyperparameters injected when requested by signature.
|
|
932
|
+
|
|
933
|
+
Context and hyperparameters are ONLY injected if explicitly declared as parameters,
|
|
934
|
+
NOT into **kwargs.
|
|
935
|
+
"""
|
|
936
|
+
import inspect as _inspect
|
|
937
|
+
try:
|
|
938
|
+
sig = _inspect.signature(func)
|
|
939
|
+
except Exception:
|
|
940
|
+
sig = None
|
|
941
|
+
new_kwargs = dict(kwargs)
|
|
942
|
+
if sig and ctx is not None:
|
|
943
|
+
if 'context' in sig.parameters and 'context' not in new_kwargs:
|
|
944
|
+
new_kwargs['context'] = ctx
|
|
945
|
+
|
|
946
|
+
try:
|
|
947
|
+
if ('hyperparameters' in sig.parameters and 'hyperparameters' not in new_kwargs) or \
|
|
948
|
+
('hparams' in sig.parameters and 'hparams' not in new_kwargs):
|
|
949
|
+
merged = ctx.get_hyperparameters(get_current_process_context()) if hasattr(ctx, 'get_hyperparameters') else {}
|
|
950
|
+
if 'hyperparameters' in sig.parameters and 'hyperparameters' not in new_kwargs:
|
|
951
|
+
new_kwargs['hyperparameters'] = merged
|
|
952
|
+
if 'hparams' in sig.parameters and 'hparams' not in new_kwargs:
|
|
953
|
+
new_kwargs['hparams'] = merged
|
|
954
|
+
except Exception:
|
|
955
|
+
pass
|
|
956
|
+
return new_kwargs
|