expops 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. expops-0.1.3.dist-info/METADATA +826 -0
  2. expops-0.1.3.dist-info/RECORD +86 -0
  3. expops-0.1.3.dist-info/WHEEL +5 -0
  4. expops-0.1.3.dist-info/entry_points.txt +3 -0
  5. expops-0.1.3.dist-info/licenses/LICENSE +674 -0
  6. expops-0.1.3.dist-info/top_level.txt +1 -0
  7. mlops/__init__.py +0 -0
  8. mlops/__main__.py +11 -0
  9. mlops/_version.py +34 -0
  10. mlops/adapters/__init__.py +12 -0
  11. mlops/adapters/base.py +86 -0
  12. mlops/adapters/config_schema.py +89 -0
  13. mlops/adapters/custom/__init__.py +3 -0
  14. mlops/adapters/custom/custom_adapter.py +447 -0
  15. mlops/adapters/plugin_manager.py +113 -0
  16. mlops/adapters/sklearn/__init__.py +3 -0
  17. mlops/adapters/sklearn/adapter.py +94 -0
  18. mlops/cluster/__init__.py +3 -0
  19. mlops/cluster/controller.py +496 -0
  20. mlops/cluster/process_runner.py +91 -0
  21. mlops/cluster/providers.py +258 -0
  22. mlops/core/__init__.py +95 -0
  23. mlops/core/custom_model_base.py +38 -0
  24. mlops/core/dask_networkx_executor.py +1265 -0
  25. mlops/core/executor_worker.py +1239 -0
  26. mlops/core/experiment_tracker.py +81 -0
  27. mlops/core/graph_types.py +64 -0
  28. mlops/core/networkx_parser.py +135 -0
  29. mlops/core/payload_spill.py +278 -0
  30. mlops/core/pipeline_utils.py +162 -0
  31. mlops/core/process_hashing.py +216 -0
  32. mlops/core/step_state_manager.py +1298 -0
  33. mlops/core/step_system.py +956 -0
  34. mlops/core/workspace.py +99 -0
  35. mlops/environment/__init__.py +10 -0
  36. mlops/environment/base.py +43 -0
  37. mlops/environment/conda_manager.py +307 -0
  38. mlops/environment/factory.py +70 -0
  39. mlops/environment/pyenv_manager.py +146 -0
  40. mlops/environment/setup_env.py +31 -0
  41. mlops/environment/system_manager.py +66 -0
  42. mlops/environment/utils.py +105 -0
  43. mlops/environment/venv_manager.py +134 -0
  44. mlops/main.py +527 -0
  45. mlops/managers/project_manager.py +400 -0
  46. mlops/managers/reproducibility_manager.py +575 -0
  47. mlops/platform.py +996 -0
  48. mlops/reporting/__init__.py +16 -0
  49. mlops/reporting/context.py +187 -0
  50. mlops/reporting/entrypoint.py +292 -0
  51. mlops/reporting/kv_utils.py +77 -0
  52. mlops/reporting/registry.py +50 -0
  53. mlops/runtime/__init__.py +9 -0
  54. mlops/runtime/context.py +34 -0
  55. mlops/runtime/env_export.py +113 -0
  56. mlops/storage/__init__.py +12 -0
  57. mlops/storage/adapters/__init__.py +9 -0
  58. mlops/storage/adapters/gcp_kv_store.py +778 -0
  59. mlops/storage/adapters/gcs_object_store.py +96 -0
  60. mlops/storage/adapters/memory_store.py +240 -0
  61. mlops/storage/adapters/redis_store.py +438 -0
  62. mlops/storage/factory.py +199 -0
  63. mlops/storage/interfaces/__init__.py +6 -0
  64. mlops/storage/interfaces/kv_store.py +118 -0
  65. mlops/storage/path_utils.py +38 -0
  66. mlops/templates/premier-league/charts/plot_metrics.js +70 -0
  67. mlops/templates/premier-league/charts/plot_metrics.py +145 -0
  68. mlops/templates/premier-league/charts/requirements.txt +6 -0
  69. mlops/templates/premier-league/configs/cluster_config.yaml +13 -0
  70. mlops/templates/premier-league/configs/project_config.yaml +207 -0
  71. mlops/templates/premier-league/data/England CSV.csv +12154 -0
  72. mlops/templates/premier-league/models/premier_league_model.py +638 -0
  73. mlops/templates/premier-league/requirements.txt +8 -0
  74. mlops/templates/sklearn-basic/README.md +22 -0
  75. mlops/templates/sklearn-basic/charts/plot_metrics.py +85 -0
  76. mlops/templates/sklearn-basic/charts/requirements.txt +3 -0
  77. mlops/templates/sklearn-basic/configs/project_config.yaml +64 -0
  78. mlops/templates/sklearn-basic/data/train.csv +14 -0
  79. mlops/templates/sklearn-basic/models/model.py +62 -0
  80. mlops/templates/sklearn-basic/requirements.txt +10 -0
  81. mlops/web/__init__.py +3 -0
  82. mlops/web/server.py +585 -0
  83. mlops/web/ui/index.html +52 -0
  84. mlops/web/ui/mlops-charts.js +357 -0
  85. mlops/web/ui/script.js +1244 -0
  86. mlops/web/ui/styles.css +248 -0
@@ -0,0 +1,1239 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from datetime import datetime
6
+ import json
7
+ import logging
8
+ import os
9
+ from pathlib import Path
10
+ import time
11
+
12
+ from .graph_types import ExecutionResult
13
+ from .payload_spill import spill_large_payloads
14
+ from .workspace import get_projects_root, get_workspace_root, infer_source_root, resolve_relative_path
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def _apply_hash_overrides(proc_payload: Dict[str, Any], config_hash: Optional[str], function_hash: Optional[str]) -> tuple[Optional[str], Optional[str]]:
20
+ """Apply optional hash overrides supplied by the driver (used for chart nodes)."""
21
+ overrides = proc_payload.get("hash_overrides") if isinstance(proc_payload, dict) else None
22
+ if isinstance(overrides, dict):
23
+ config_hash = overrides.get("config_hash") or config_hash
24
+ function_hash = overrides.get("function_hash") or function_hash
25
+ return config_hash, function_hash
26
+
27
+
28
+ def _strip_internal_keys(value: Any) -> Any:
29
+ if isinstance(value, dict):
30
+ return {k: v for k, v in value.items() if not str(k).startswith("__")}
31
+ return value
32
+
33
+
34
+ def _record_chart_artifacts(
35
+ state_manager: Any,
36
+ project_id: str,
37
+ run_id: str,
38
+ chart_name: str,
39
+ out_dir: Path,
40
+ chart_type: str = "static",
41
+ ) -> None:
42
+ """Best-effort: upload PNG artifacts and record them in the KV store for UI listing."""
43
+ try:
44
+ obj_store = getattr(state_manager, "object_store", None)
45
+ kv = getattr(state_manager, "kv_store", None)
46
+ except Exception:
47
+ return
48
+
49
+ try:
50
+ pngs = list(out_dir.rglob("*.png"))
51
+ except Exception:
52
+ pngs = []
53
+
54
+ abs_charts_root = None
55
+ try:
56
+ if obj_store and hasattr(obj_store, "_bucket") and getattr(obj_store, "_bucket") is not None:
57
+ bname = getattr(getattr(obj_store, "_bucket"), "name", None)
58
+ if bname:
59
+ abs_charts_root = f"gs://{bname}/projects/{project_id}/charts/{run_id}"
60
+ except Exception:
61
+ abs_charts_root = None
62
+
63
+ artifacts: list[dict] = []
64
+ for p in pngs:
65
+ try:
66
+ local_path = str(p.resolve())
67
+ except Exception:
68
+ local_path = str(p)
69
+
70
+ obj_path = None
71
+ if obj_store:
72
+ try:
73
+ base = f"projects/{project_id}/charts/{run_id}/{chart_name}"
74
+ if abs_charts_root:
75
+ base = f"{abs_charts_root}/{chart_name}"
76
+ remote = obj_store.build_uri(base, p.name)
77
+ with open(p, "rb") as f:
78
+ obj_store.put_bytes(remote, f.read(), content_type="image/png")
79
+ obj_path = remote
80
+ except Exception as upload_err:
81
+ logger.warning(f"[Charts] Upload failed for {p.name}: {upload_err}")
82
+ obj_path = None
83
+
84
+ if not obj_path:
85
+ obj_path = local_path
86
+
87
+ try:
88
+ artifacts.append(
89
+ {
90
+ "title": p.name,
91
+ "object_path": obj_path,
92
+ "cache_path": local_path,
93
+ "mime_type": "image/png",
94
+ "size_bytes": p.stat().st_size,
95
+ "created_at": time.time(),
96
+ "chart_type": chart_type,
97
+ }
98
+ )
99
+ except Exception:
100
+ continue
101
+
102
+ try:
103
+ if kv and hasattr(kv, "record_run_chart_artifacts"):
104
+ kv.record_run_chart_artifacts(str(run_id), str(chart_name), artifacts)
105
+ except Exception as kv_err:
106
+ logger.warning(f"[Charts] Failed to record artifacts in KV: {kv_err}")
107
+
108
+
109
def _build_step_context_from_payload(context_payload: Dict[str, Any]) -> Any:
    """Rebuild a ``StepContext`` on the worker from the driver's serialized payload.

    Handling of ``step_results``:
    - a dict entry with a dict ``data`` field becomes a shallow copy of that ``data``;
    - any other dict entry has its ``__``-prefixed keys stripped;
    - non-dict entries are kept as-is.

    Other fields:
    - ``data_paths`` values are converted to ``Path``; entries that fail conversion
      are skipped.
    - A truthy ``checkpoint_dir`` becomes a ``Path``.
    - ``tracker`` is always ``None``.

    If anything goes wrong, a minimal context carrying only ``project_id`` is
    returned (``"default"`` when the payload is not a dict).
    """
    from .step_system import StepContext as _Ctx

    def _unwrap_result(entry: Any) -> Any:
        if not isinstance(entry, dict):
            return entry
        inner = entry.get("data")
        if isinstance(inner, dict):
            return dict(inner)
        return _strip_internal_keys(entry)

    try:
        payload = context_payload if isinstance(context_payload, dict) else {}

        raw_results = payload.get("step_results") or {}
        if isinstance(raw_results, dict):
            step_results: Dict[str, Any] = {key: _unwrap_result(entry) for key, entry in raw_results.items()}
        else:
            step_results = {}

        data_paths: Dict[str, Path] = {}
        raw_paths = payload.get("data_paths") or {}
        if isinstance(raw_paths, dict):
            for path_key, location in raw_paths.items():
                try:
                    data_paths[str(path_key)] = Path(location)
                except Exception:
                    continue

        checkpoint_raw = payload.get("checkpoint_dir")
        checkpoint_dir = Path(checkpoint_raw) if checkpoint_raw else None

        return _Ctx(
            project_id=payload.get("project_id"),
            run_id=payload.get("run_id"),
            tracker=None,
            step_results=step_results,
            global_config=payload.get("global_config") or {},
            data_paths=data_paths,
            checkpoint_dir=checkpoint_dir,
        )
    except Exception:
        # Malformed payload: degrade to a context that only knows its project.
        try:
            fallback_pid = context_payload.get("project_id") if isinstance(context_payload, dict) else "default"
        except Exception:
            fallback_pid = "default"
        return _Ctx(project_id=fallback_pid)
156
+
157
+ def _derive_task_seed(base_seed: int, parts: List[str]) -> int:
158
+ import hashlib as _hashlib
159
+ payload = f"{base_seed}|" + "|".join(parts)
160
+ digest = _hashlib.sha256(payload.encode()).digest()
161
+ val = int.from_bytes(digest[:4], "big") & 0x7FFFFFFF
162
+ return val or (base_seed & 0x7FFFFFFF) or 1
163
+
164
+
165
+ def _seed_all(seed: int) -> None:
166
+ import random as _random
167
+ try:
168
+ _random.seed(seed)
169
+ except Exception:
170
+ pass
171
+ try:
172
+ import numpy as _np # type: ignore
173
+ _np.random.seed(seed)
174
+ except Exception:
175
+ pass
176
+ # Best-effort deep learning libs
177
+ try:
178
+ import torch as _torch # type: ignore
179
+ try:
180
+ _torch.manual_seed(seed)
181
+ except Exception:
182
+ pass
183
+ try:
184
+ if _torch.cuda.is_available():
185
+ _torch.cuda.manual_seed_all(seed)
186
+ except Exception:
187
+ pass
188
+ try:
189
+ _torch.use_deterministic_algorithms(True) # type: ignore[attr-defined]
190
+ except Exception:
191
+ pass
192
+ except Exception:
193
+ pass
194
+ try:
195
+ import tensorflow as _tf # type: ignore
196
+ try:
197
+ _tf.random.set_seed(seed)
198
+ except Exception:
199
+ pass
200
+ except Exception:
201
+ pass
202
+
203
+
204
def _seed_rng_for_task(run_id: Optional[str], process_name: Optional[str], step_name: Optional[str], iteration: Optional[int]) -> None:
    """Deterministically seed all RNGs for one task (process / step / iteration).

    Seeding is on by default. Setting ``MLOPS_TASK_LEVEL_SEEDING`` to ``0``,
    ``false`` or ``no`` (case-insensitive) disables it.

    The base seed comes from ``MLOPS_RANDOM_SEED``. It is 42 when the variable is
    unset, empty, or not an integer.

    The task seed is derived from the base seed plus the truthy ``process_name``
    and ``step_name`` and the ``iteration`` (as an int string when convertible).
    ``run_id`` is accepted but does not influence the seed.
    """
    try:
        flag = str(os.environ.get("MLOPS_TASK_LEVEL_SEEDING", "1")).lower()
        seeding_enabled = flag not in ("0", "false", "no")
    except Exception:
        seeding_enabled = True
    if not seeding_enabled:
        return

    try:
        base = int(os.environ.get("MLOPS_RANDOM_SEED", "42") or 42)
    except Exception:
        base = 42

    parts: List[str] = [str(label) for label in (process_name, step_name) if label]
    if iteration is not None:
        try:
            iteration_label = str(int(iteration))
        except Exception:
            iteration_label = str(iteration)
        parts.append(iteration_label)

    seed_val = _derive_task_seed(base, parts)
    _seed_all(seed_val)
    logger.debug(f"[Seed] base={base} parts={parts} -> seed={seed_val}")
232
+
233
+
234
+ def _maybe_import_custom_model_from_global_config(global_params: Dict[str, Any]) -> None:
235
+ if not isinstance(global_params, dict):
236
+ return
237
+
238
+ script_path = global_params.get("custom_script_path")
239
+ if not script_path:
240
+ try:
241
+ script_path = (global_params.get("model", {}) or {}).get("parameters", {}).get("custom_script_path")
242
+ except Exception:
243
+ script_path = None
244
+ if not script_path:
245
+ return
246
+
247
+ import importlib
248
+ import importlib.util
249
+ import sys as _sys
250
+
251
+ script_path_str = str(script_path)
252
+ try:
253
+ spec = importlib.util.spec_from_file_location("custom_model", script_path_str)
254
+ if spec and spec.loader:
255
+ mod = importlib.util.module_from_spec(spec)
256
+ _sys.modules["custom_model"] = mod
257
+ spec.loader.exec_module(mod) # type: ignore[attr-defined]
258
+ return
259
+ except Exception:
260
+ pass
261
+
262
+ # Fall back to importing by module name.
263
+ try:
264
+ stem = Path(script_path_str).stem
265
+ importlib.import_module(stem)
266
+ return
267
+ except Exception:
268
+ pass
269
+
270
+ # Fall back to a previously-loaded module.
271
+ try:
272
+ importlib.import_module("custom_model")
273
+ except Exception:
274
+ return
275
+
276
+
277
+ def _prepare_runner_kwargs(sig: Any, ctx: Any, process_name: Optional[str], dependencies: List[str]) -> Dict[str, Any]:
278
+ kwargs: Dict[str, Any] = {}
279
+ params = getattr(sig, "parameters", {}) or {}
280
+
281
+ if "data" in params:
282
+ data_payload: Dict[str, Any] = {}
283
+ for dep_name in dependencies or []:
284
+ dep_result = ctx.get_step_result(dep_name) if (ctx and hasattr(ctx, "get_step_result")) else None
285
+ data_payload[dep_name] = dep_result if dep_result else {}
286
+ kwargs["data"] = data_payload
287
+
288
+ if "hyperparameters" in params:
289
+ try:
290
+ kwargs["hyperparameters"] = ctx.get_hyperparameters(process_name) if ctx else {}
291
+ except Exception:
292
+ kwargs["hyperparameters"] = {}
293
+ return kwargs
294
+
295
+
296
def _compute_process_lookup_hashes_worker(
    state_manager: Any,
    ctx: Any,
    process_name: str,
    dependencies: List[str],
    dependency_map: Optional[Dict[str, List[str]]] = None,
    lookup_name: Optional[str] = None,
) -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Compute ``(input_hash, config_hash, function_hash)`` on the worker.

    Delegates to ``process_hashing.compute_process_hashes`` (the same helper the
    driver uses) with a normalized dependency map:
    - every entry's dependencies are de-duplicated and sorted;
    - ``process_name`` receives ``dependencies`` when the map has no entry for it;
    - a malformed map is discarded.

    Returns ``(None, None, None)`` when the hashing helper cannot be imported.
    Errors raised by the helper itself propagate to the caller.
    """
    try:
        from .process_hashing import compute_process_hashes as _hasher
    except Exception:
        _hasher = None  # type: ignore[assignment]

    try:
        normalized: Dict[str, List[str]] = {
            str(name): sorted(set(deps or [])) for name, deps in (dependency_map or {}).items()
        }
    except Exception:
        normalized = {}
    normalized.setdefault(process_name, sorted(set(dependencies or [])))

    if _hasher:
        ih, ch, fh = _hasher(state_manager, ctx, process_name, normalized, lookup_name=lookup_name)
    else:
        ih, ch, fh = None, None, None

    logger.debug(f"[HashTrace] side=worker process={process_name} ih={ih} ch={ch} fh={fh}")
    return (ih, ch, fh)
325
+
326
+
327
def _resolve_worker_state_manager(ctx: Any, proc_payload: Dict[str, Any]) -> Any:
    """Find a state manager for this worker, trying each source in priority order.

    Sources, in order:
    1. The payload's ``cache_config``.
    2. The step system's current state manager.
    3. The step-system instance's ``state_manager``.
    4. A lazy init from the context's ``global_config`` / ``project_id``.

    Returns ``None`` when every source fails.
    """
    state_manager = None
    cache_config = proc_payload.get('cache_config', {}) if isinstance(proc_payload, dict) else {}
    if cache_config:
        try:
            # ``or {}`` so a null object_store_config cannot break this log line
            # (which previously aborted state-manager creation altogether).
            logger.info(
                f"[Worker] Attempting state manager creation from proc_payload cache_config: "
                f"kv_type={cache_config.get('kv_store_type')}, "
                f"obj_type={cache_config.get('object_store_type')}, "
                f"bucket={(cache_config.get('object_store_config') or {}).get('bucket')}"
            )
            state_manager = _create_worker_state_manager(cache_config)
            if state_manager:
                logger.info(
                    f"[Worker] State manager created from cache_config -> "
                    f"has_object_store={hasattr(state_manager, 'object_store') and state_manager.object_store is not None}"
                )
        except Exception as e:
            logger.warning(f"[Worker] Failed to create state manager from cache_config: {e}")
            state_manager = None
    if not state_manager:
        try:
            from .step_system import get_current_state_manager
            state_manager = get_current_state_manager()
        except Exception:
            state_manager = None
    if not state_manager:
        try:
            from .step_system import _get_step_system
            ss = _get_step_system()
            if ss and hasattr(ss, 'state_manager'):
                state_manager = ss.state_manager
        except Exception:
            state_manager = None
    # Final fallback: initialize worker state manager from context if still missing
    if not state_manager:
        try:
            _gc_ctx = getattr(ctx, 'global_config', {}) if ctx else {}
            _pid_ctx = getattr(ctx, 'project_id', None)
            _maybe_init_worker_state_manager(_gc_ctx, _pid_ctx)
            from .step_system import get_state_manager as _get_sm_fallback
            state_manager = _get_sm_fallback()
            logger.debug("[Worker] State manager lazily initialized in _execute_process_on_worker")
        except Exception:
            state_manager = None
    return state_manager


def _chart_error_from_result(result: Dict[str, Any]) -> Optional[str]:
    """Classify a chart run summary: return a failure message, or ``None`` on success.

    A chart fails when its subprocess exited non-zero, or when it produced no PNG
    artifacts. A missing or unparseable ``returncode`` / ``artifact_count`` is
    treated as 0.
    """
    def _as_int(value: Any) -> int:
        try:
            return int(value) if value is not None else 0
        except Exception:
            return 0

    rc_int = _as_int(result.get("returncode"))
    art_int = _as_int(result.get("artifact_count"))
    if rc_int != 0:
        return (
            f"Chart subprocess failed (exit_code={rc_int}). "
            f"See logs: stdout={result.get('stdout_log')}, stderr={result.get('stderr_log')} "
            f"and runner error file: {result.get('error_txt')}."
        )
    if art_int <= 0:
        return (
            "Chart produced no PNG artifacts (0 files). "
            "This is treated as a failure. Ensure the chart function calls ctx.savefig(...) "
            f"and probe_paths are correct. Output dir: {result.get('output_dir')}. "
            f"Logs: stdout={result.get('stdout_log')}, stderr={result.get('stderr_log')}."
        )
    return None


def _record_worker_completion(
    state_manager: Any,
    ctx: Any,
    proc_payload: Dict[str, Any],
    run_id: Optional[str],
    process_name: Optional[str],
    lookup_name: Optional[str],
    *,
    success: bool,
    result: Any,
    error: Optional[str],
    execution_time: float,
) -> None:
    """Best-effort: store a process completion (success or failure) through ``state_manager``.

    Hashes are computed with the same helper and ``hash_overrides`` handling as the
    start record, so start and completion records carry matching hashes. Does
    nothing without a state manager. Errors are logged at debug level and
    swallowed.
    """
    if not state_manager:
        return
    try:
        from .step_state_manager import ProcessExecutionResult as _ProcessExec

        ih, ch, fh = _compute_process_lookup_hashes_worker(
            state_manager,
            ctx,
            process_name,
            proc_payload.get('dependencies', []) or [],
            proc_payload.get('dependency_map') or {},
            lookup_name=lookup_name,
        )
        ch, fh = _apply_hash_overrides(proc_payload, ch, fh)
        enable_logging = proc_payload.get('logging', True) if isinstance(proc_payload, dict) else True
        state_manager.record_process_completion(
            run_id or 'default',
            _ProcessExec(
                process_name=process_name,
                success=success,
                result=result,
                error=error,
                execution_time=execution_time,
                timestamp=datetime.now().isoformat(),
            ),
            input_hash=ih,
            config_hash=ch,
            function_hash=fh,
            was_cached=False,
            enable_logging=enable_logging,
        )
    except Exception as record_err:
        logger.debug("[Worker] Failed to record completion for %s: %s", process_name, record_err)


def _execute_process_on_worker(ctx: Any, proc_payload: Dict[str, Any], run_id: Optional[str]) -> ExecutionResult:
    """Run one process (regular or chart) on a worker and report its outcome.

    Steps:
    1. Seed RNGs for the task, then capture root-logger output, stdout and stderr.
    2. Resolve a state manager and the process runner. If the runner is not
       registered yet, the user's custom model script is imported first.
    3. Record the process as started, then run it:
       - chart processes go through ``_run_chart_process_on_worker``;
       - other processes call the registered runner with kwargs derived from its
         signature.
    4. Record the completion — success, or failure for runner errors, a missing
       runner, a non-dict return, or a spill failure — so those outcomes do not
       stay at "started". Return an ``ExecutionResult`` whose result carries the
       captured output under ``__logs__``.

    Regular processes must return a dict. When a state manager is available,
    large values are spilled via ``spill_large_payloads``; a spill failure is
    recorded and then re-raised. With ``lightweight_result`` set, only the logs
    are returned to the driver.
    """
    process_name = proc_payload.get('name')
    start_time = time.time()
    # Deterministic task-level seeding (process scope)
    _seed_rng_for_task(run_id, process_name, None, 0)

    from .step_system import get_process_registry as _get_pr, set_current_context as _set_ctx, set_current_process_context as _set_proc
    import contextlib as _ctxlib
    import inspect
    import io as _io

    # Capture everything the process emits so it can be returned as ``__logs__``.
    log_stream = _io.StringIO()
    stdout_stream = _io.StringIO()
    stderr_stream = _io.StringIO()
    root_logger = logging.getLogger()
    prev_root_level = getattr(root_logger, "level", logging.INFO)
    capture_handler = logging.StreamHandler(log_stream)
    try:
        capture_handler.setLevel(logging.DEBUG)
        root_logger.addHandler(capture_handler)
        root_logger.setLevel(logging.DEBUG)
    except Exception:
        pass

    def _cleanup_capture() -> None:
        # Idempotent: called as soon as capture is no longer needed, and again from
        # the outer ``finally`` so an unexpected exception cannot leak the handler.
        try:
            root_logger.removeHandler(capture_handler)
        except Exception:
            pass
        try:
            root_logger.setLevel(prev_root_level)
        except Exception:
            pass

    def _captured_logs() -> str:
        return (log_stream.getvalue() or '') + (stdout_stream.getvalue() or '') + (stderr_stream.getvalue() or '')

    def _elapsed() -> float:
        return time.time() - start_time

    try:
        state_manager = _resolve_worker_state_manager(ctx, proc_payload)

        # Resolve the definition early: canonical hashes use the code_function mapping.
        lookup_name = proc_payload.get('code_function') or process_name
        pr = _get_pr()
        pdef = pr.get_process(lookup_name) if pr else None
        runner = getattr(pdef, 'runner', None) if pdef else None
        if not callable(runner):
            # The user's script may not have been imported on this worker yet.
            _maybe_import_custom_model_from_global_config(getattr(ctx, 'global_config', {}) or {})
            pr = _get_pr()
            pdef = pr.get_process(lookup_name) if pr else None
            runner = getattr(pdef, 'runner', None) if pdef else None

        def _record_completion(success: bool, result: Any, error: Optional[str], execution_time: float) -> None:
            _record_worker_completion(
                state_manager,
                ctx,
                proc_payload,
                run_id,
                process_name,
                lookup_name,
                success=success,
                result=result,
                error=error,
                execution_time=execution_time,
            )

        # Record the process as started, timestamped at the start of timing.
        try:
            if state_manager and run_id and process_name:
                ih, ch, fh = _compute_process_lookup_hashes_worker(
                    state_manager,
                    ctx,
                    process_name,
                    proc_payload.get('dependencies', []) or [],
                    proc_payload.get('dependency_map') or {},
                    lookup_name=lookup_name,
                )
                ch, fh = _apply_hash_overrides(proc_payload, ch, fh)
                state_manager.record_process_started(
                    run_id,
                    process_name,
                    input_hash=ih,
                    config_hash=ch,
                    function_hash=fh,
                    started_at=start_time,
                )
        except Exception as start_err:
            logger.debug("[Worker] Failed to record start for %s: %s", process_name, start_err)

        # Chart processes run as a subprocess and are judged by exit code + PNG output.
        if str(proc_payload.get('process_type', 'process')) == 'chart':
            try:
                result = _run_chart_process_on_worker(ctx, proc_payload, run_id)
            except Exception as e:
                exec_time = _elapsed()
                _record_completion(False, None, str(e), exec_time)
                _cleanup_capture()
                return ExecutionResult(name=process_name, result={'__error_context__': str(e)}, execution_time=exec_time, was_cached=False, error=str(e))

            exec_time = _elapsed()
            if not isinstance(result, dict):
                _cleanup_capture()
                return ExecutionResult(name=process_name, result={'__logs__': 'Invalid chart result'}, execution_time=exec_time, was_cached=False, error=None)

            chart_error = _chart_error_from_result(result)
            is_success = chart_error is None
            # Record completion now so UI/KV reflect chart status immediately.
            _record_completion(is_success, result if is_success else None, chart_error, exec_time)
            _cleanup_capture()
            return ExecutionResult(name=process_name, result=result, execution_time=exec_time, was_cached=False, error=chart_error)

        if not callable(runner):
            exec_time = _elapsed()
            missing_error = f"No runner defined for process '{process_name}'."
            _record_completion(False, None, missing_error, exec_time)
            _cleanup_capture()
            return ExecutionResult(name=process_name, result=None, execution_time=exec_time, was_cached=False, error=missing_error)

        ret: Any = None
        runner_error: Optional[Exception] = None
        _set_ctx(ctx)
        try:
            try:
                _set_proc(process_name)
            except Exception:
                pass
            try:
                original_func = getattr(pdef, 'original_func', None) if pdef else None
                sig = inspect.signature(original_func) if original_func else inspect.signature(runner)
            except Exception:
                sig = inspect.signature(runner)
            kwargs = _prepare_runner_kwargs(sig, ctx, process_name, proc_payload.get('dependencies', []))
            with _ctxlib.redirect_stdout(stdout_stream), _ctxlib.redirect_stderr(stderr_stream):
                ret = runner(**kwargs) if kwargs else runner()
        except Exception as exc:
            runner_error = exc
        finally:
            try:
                _set_proc(None)
            except Exception:
                pass
            _set_ctx(None)
            _cleanup_capture()

        if runner_error is not None:
            exec_time = _elapsed()
            error_text = str(runner_error)
            error_result = {'__logs__': _captured_logs()}
            # Previously unrecorded: the stored status stayed at "started".
            _record_completion(False, None, error_text, exec_time)
            return ExecutionResult(name=process_name, result=error_result, execution_time=exec_time, was_cached=False, error=error_text)

        if not isinstance(ret, dict):
            exec_time = _elapsed()
            type_error = f"Process '{process_name}' must return a dictionary, got {type(ret).__name__}."
            captured = {'__logs__': _captured_logs()}
            _record_completion(False, None, type_error, exec_time)
            return ExecutionResult(name=process_name, result=captured, execution_time=exec_time, was_cached=False, error=type_error)

        if state_manager:
            try:
                ret = spill_large_payloads(ret, state_manager, run_id, process_name)
            except Exception as spill_err:
                logger.error("[PayloadSpill] Failed to spill payload for process %s: %s", process_name, spill_err)
                _record_completion(False, None, f"Payload spill failed: {spill_err}", _elapsed())
                raise

        exec_time = _elapsed()
        ret["__logs__"] = _captured_logs()
        _record_completion(True, ret, None, exec_time)

        # Lightweight mode: return only the logs to avoid large deserialization on the driver.
        if isinstance(proc_payload, dict) and bool(proc_payload.get('lightweight_result')):
            return ExecutionResult(name=process_name, result={'__logs__': ret.get('__logs__')}, execution_time=exec_time, was_cached=False, error=None)

        return ExecutionResult(name=process_name, result=ret, execution_time=exec_time, was_cached=False, error=None)
    finally:
        _cleanup_capture()
631
+
632
+
633
+ def _run_chart_process_on_worker(ctx: Any, proc_payload: Dict[str, Any], run_id: Optional[str]) -> Dict[str, Any]:
634
+ name = proc_payload.get('name')
635
+ if not name:
636
+ logger.error("[Charts] No chart name in proc_payload")
637
+ return {'output_dir': '', 'artifact_count': 0}
638
+
639
+ chart_spec = proc_payload.get('chart_spec') or {}
640
+ entrypoint = chart_spec.get('entrypoint') or ''
641
+ reporting_python = chart_spec.get('reporting_python') or os.environ.get('MLOPS_REPORTING_PYTHON') or None
642
+ project_id = getattr(ctx, 'project_id', None) or os.environ.get('MLOPS_PROJECT_ID') or 'default'
643
+ rid = run_id or getattr(ctx, 'run_id', None) or os.environ.get('MLOPS_RUN_ID') or 'default'
644
+
645
+ logger.info(f"[Charts] Starting chart '{name}' for run {rid}, project {project_id}")
646
+
647
+ # Safeguard: mark process as running when the chart actually begins execution
648
+ try:
649
+ from .step_system import get_state_manager as _get_sm
650
+ sm = _get_sm()
651
+ except Exception:
652
+ sm = None
653
+ try:
654
+ if sm and rid and name:
655
+ already_running = False
656
+ try:
657
+ prev = sm.kv_store.list_run_steps(rid) if hasattr(sm, 'kv_store') else {}
658
+ rec = (prev or {}).get(f"{name}.__process__") or {}
659
+ status = str(rec.get('status') or '').lower()
660
+ already_running = status in ('running','completed','cached','failed')
661
+ except Exception:
662
+ already_running = False
663
+ if not already_running:
664
+ deps = proc_payload.get('dependencies', []) or []
665
+ dep_map = proc_payload.get('dependency_map') or {}
666
+ ih, ch, fh = _compute_process_lookup_hashes_worker(sm, ctx, name, deps, dep_map)
667
+ ch, fh = _apply_hash_overrides(proc_payload, ch, fh)
668
+ sm.record_process_started(
669
+ rid,
670
+ name,
671
+ input_hash=ih,
672
+ config_hash=ch,
673
+ function_hash=fh,
674
+ started_at=time.time(),
675
+ )
676
+ logger.debug(f"[Charts] Marked '{name}' running for run {rid}")
677
+ except Exception:
678
+ pass
679
+
680
+ # Build output dir under project artifacts (workspace-root based)
681
+ workspace_root = get_workspace_root()
682
+ projects_root = get_projects_root(workspace_root)
683
+ out_dir = projects_root / project_id / 'artifacts' / 'charts' / rid / name / time.strftime('%Y%m%d_%H%M%S')
684
+ out_dir.mkdir(parents=True, exist_ok=True)
685
+
686
+ # Prepare environment
687
+ env = os.environ.copy()
688
+ # Centralized env export for KV backend (best-effort).
689
+ try:
690
+ from mlops.runtime.env_export import export_kv_env
691
+
692
+ gc = getattr(ctx, "global_config", {}) if ctx else {}
693
+ cache_cfg = (gc.get("cache") or {}) if isinstance(gc, dict) else {}
694
+ if not cache_cfg and isinstance(gc, dict):
695
+ try:
696
+ cache_cfg = ((gc.get("model") or {}).get("parameters") or {}).get("cache") or {}
697
+ except Exception:
698
+ cache_cfg = {}
699
+ backend_cfg = (cache_cfg.get("backend") or {}) if isinstance(cache_cfg, dict) else {}
700
+ project_root = get_projects_root(get_workspace_root()) / str(project_id)
701
+ env.update(export_kv_env(backend_cfg if isinstance(backend_cfg, dict) else {}, workspace_root=workspace_root, project_root=project_root))
702
+ except Exception:
703
+ pass
704
+ env['MLOPS_PROJECT_ID'] = str(project_id)
705
+ env['MLOPS_RUN_ID'] = str(rid)
706
+ env['MLOPS_OUTPUT_DIR'] = str(out_dir)
707
+ env['MLOPS_CHART_NAME'] = str(name)
708
+ env['MLOPS_CHART_TYPE'] = 'static'
709
+
710
+ # Include probe_paths (chart-level overrides merged with global spec at driver)
711
+ try:
712
+ if chart_spec.get('probe_paths'):
713
+ env['MLOPS_PROBE_PATHS'] = json.dumps(chart_spec['probe_paths'])
714
+ except Exception:
715
+ pass
716
+
717
+ # Ensure PYTHONPATH contains src dir only for source checkouts (installed packages don't need it)
718
+ try:
719
+ src_root = infer_source_root()
720
+ if src_root and (src_root / "src").exists():
721
+ env['PYTHONPATH'] = f"{src_root / 'src'}:{env.get('PYTHONPATH', '')}".rstrip(":")
722
+ except Exception:
723
+ pass
724
+
725
+ # Always use framework entrypoint as a module, and import user script if provided
726
+ # This ensures chart functions are properly discovered and executed
727
+ if entrypoint:
728
+ project_root = projects_root / project_id
729
+ ep = resolve_relative_path(entrypoint, project_root=project_root, workspace_root=workspace_root)
730
+ # Set import path for user's chart script
731
+ env['MLOPS_CHART_IMPORT_FILES'] = str(ep)
732
+ logger.info(f"[Charts] Will import user script: {ep}")
733
+
734
+ # Build command - always use framework entrypoint as module
735
+ py = reporting_python or os.environ.get('MLOPS_RUNTIME_PYTHON') or 'python'
736
+ try:
737
+ if reporting_python:
738
+ env['MLOPS_REPORTING_PYTHON'] = str(reporting_python)
739
+ except Exception:
740
+ pass
741
+
742
+ # Check if Python interpreter exists and is executable
743
+ py_path = Path(py) if not py.startswith('python') else None
744
+ if py_path and not py_path.exists():
745
+ logger.error(f"[Charts] Python interpreter not found: {py}")
746
+ logger.warning(f"[Charts] Falling back to system python3")
747
+ py = 'python3'
748
+
749
+ # Always run as module to ensure proper initialization
750
+ cmd = [py, '-u', '-m', 'mlops.reporting.entrypoint', '--oneshot'] + list(chart_spec.get('args') or [])
751
+
752
+ # Run chart
753
+ import subprocess as _subprocess
754
+ stdout_log = out_dir / 'stdout.log'
755
+ stderr_log = out_dir / 'stderr.log'
756
+
757
+ logger.info(f"[Charts] Executing chart '{name}': {' '.join(cmd)}")
758
+ logger.info(f"[Charts] Output directory: {out_dir}")
759
+ logger.info(f"[Charts] Python interpreter: {py}")
760
+ logger.info(f"[Charts] Entrypoint: {entrypoint}")
761
+ logger.info(f"[Charts] MLOPS_CHART_NAME env: {env.get('MLOPS_CHART_NAME')}")
762
+ logger.info(f"[Charts] MLOPS_OUTPUT_DIR env: {env.get('MLOPS_OUTPUT_DIR')}")
763
+
764
+ with open(stdout_log, 'w', buffering=1) as out_f, open(stderr_log, 'w', buffering=1) as err_f:
765
+ # Write diagnostic info
766
+ err_f.write(f"=== Chart Execution Diagnostics ===\n")
767
+ err_f.write(f"Chart: {name}\n")
768
+ err_f.write(f"Python: {py}\n")
769
+ err_f.write(f"Command: {' '.join(cmd)}\n")
770
+ err_f.write(f"CWD: {workspace_root}\n")
771
+ err_f.write(f"Output dir: {out_dir}\n")
772
+ err_f.write(f"===================================\n\n")
773
+ err_f.flush()
774
+
775
+ result = _subprocess.run(cmd, env=env, check=False, stdout=out_f, stderr=err_f, cwd=str(workspace_root))
776
+
777
+ try:
778
+ returncode = int(getattr(result, "returncode", 0) or 0)
779
+ except Exception:
780
+ returncode = 0
781
+ logger.info(f"[Charts] Chart '{name}' execution completed with return code: {returncode}")
782
+
783
+ # Check if any PNGs were created
784
+ png_count = len(list(out_dir.rglob('*.png')))
785
+ logger.info(f"[Charts] Found {png_count} PNG file(s) in {out_dir}")
786
+
787
+ # Upload/record artifacts (best-effort).
788
+ try:
789
+ from .step_system import get_state_manager as _get_sm
790
+ sm = _get_sm()
791
+ except Exception:
792
+ sm = None
793
+
794
+ if sm is not None:
795
+ try:
796
+ _record_chart_artifacts(sm, str(project_id), str(rid), str(name), out_dir, chart_type="static")
797
+ except Exception as upload_exc:
798
+ logger.warning(f"[Charts] Artifact recording failed: {upload_exc}")
799
+ else:
800
+ logger.warning("[Charts] No state manager available - artifacts not recorded")
801
+
802
+ final_count = len(list(out_dir.rglob('*.png')))
803
+ logger.info(f"[Charts] Chart '{name}' complete. Output dir: {out_dir}, PNG count: {final_count}")
804
+
805
+ return {
806
+ 'output_dir': str(out_dir),
807
+ 'artifact_count': final_count,
808
+ 'returncode': returncode,
809
+ 'stdout_log': str(stdout_log),
810
+ 'stderr_log': str(stderr_log),
811
+ 'error_txt': str((out_dir / 'error.txt')),
812
+ }
813
+
814
+
815
+ def _return_placeholder_cached_process_execution_result(process_name: str) -> ExecutionResult:
816
+ return ExecutionResult(name=process_name, result=None, execution_time=0.0, was_cached=True, error=None)
817
+
818
+ def _return_placeholder_cached_process_execution_result_with_deps(process_name: str, dep_results: List[ExecutionResult]) -> ExecutionResult:
819
+ for dep in dep_results or []:
820
+ if dep and dep.error is not None:
821
+ return ExecutionResult(name=process_name, result=None, execution_time=0.0, was_cached=False, error=f"Dependency {dep.name} failed: {dep.error}")
822
+ return ExecutionResult(name=process_name, result=None, execution_time=0.0, was_cached=True, error=None)
823
+
824
+
825
+ def _worker_execute_step_task(step_name: str, process_name: Optional[str], context_arg: Any,
826
+ iteration: int = 0, run_id: Optional[str] = None) -> ExecutionResult:
827
+ from .step_system import get_step_registry, set_current_context
828
+ registry = get_step_registry()
829
+ step_def = registry.get_step(step_name)
830
+ if not step_def:
831
+ try:
832
+ if isinstance(context_arg, dict):
833
+ global_params = context_arg.get('global_config') or {}
834
+ else:
835
+ global_params = getattr(context_arg, 'global_config', {}) or {}
836
+ _maybe_import_custom_model_from_global_config(global_params)
837
+ step_def = registry.get_step(step_name)
838
+ except Exception:
839
+ step_def = registry.get_step(step_name)
840
+ if not step_def:
841
+ raise ValueError(f"Step '{step_name}' not found in registry (worker). Ensure the model module defines and registers it.")
842
+
843
+ try:
844
+ from dask.distributed import get_worker # type: ignore
845
+ _worker = get_worker()
846
+ _worker_addr = getattr(_worker, "address", "unknown")
847
+ except Exception:
848
+ _worker_addr = None
849
+ try:
850
+ import socket as _socket
851
+ _host = _socket.gethostname()
852
+ except Exception:
853
+ _host = "unknown"
854
+ logger.info(
855
+ f"[Distributed] Executing step '{step_name}' (process {process_name}, iter {iteration}) "
856
+ f"on worker={_worker_addr or 'n/a'} host={_host}"
857
+ )
858
+
859
+ start_time = time.time()
860
+
861
+ from .step_system import StepContext as _Ctx
862
+ if isinstance(context_arg, _Ctx):
863
+ ctx = context_arg
864
+ else:
865
+ if isinstance(context_arg, dict):
866
+ ctx = _build_step_context_from_payload(context_arg)
867
+ else:
868
+ try:
869
+ ctx = _Ctx(project_id=getattr(context_arg, 'project_id', 'default'))
870
+ except Exception:
871
+ ctx = _Ctx(project_id='default')
872
+
873
+ try:
874
+ _gc = getattr(ctx, 'global_config', {}) if ctx else {}
875
+ _pid = getattr(ctx, 'project_id', None)
876
+ _maybe_init_worker_state_manager(_gc, _pid)
877
+ except Exception as e:
878
+ logger.warning(f"[Distributed] Worker state manager init failed for step {step_name}: {e}")
879
+
880
+ set_current_context(ctx)
881
+ # Deterministic task-level seeding (step scope)
882
+ _seed_rng_for_task(run_id, process_name, step_name, iteration)
883
+ try:
884
+ from .step_system import get_current_state_manager as _get_sm
885
+ _sm = _get_sm()
886
+ except Exception:
887
+ _sm = None
888
+ try:
889
+ _proc_name = process_name or getattr(ctx, 'current_process', None)
890
+ if _sm and run_id and _proc_name and step_name:
891
+ _sm.record_step_started(run_id, _proc_name, step_name)
892
+ except Exception:
893
+ pass
894
+ try:
895
+ # Execute step without auto-parameter resolution; context is available via current context
896
+ # The step wrapper will inject context automatically if declared in the signature.
897
+ result = step_def.func()
898
+ finally:
899
+ set_current_context(None)
900
+
901
+ if not isinstance(result, dict):
902
+ raise ValueError(f"Step '{step_name}' must return a dictionary, got {type(result).__name__}.")
903
+ try:
904
+ def _json_safe(v: Any) -> Any:
905
+ import json as _json
906
+ from collections.abc import Mapping, Sequence
907
+ primitives = (str, int, float, bool, type(None))
908
+ if isinstance(v, primitives):
909
+ return v
910
+ if isinstance(v, Mapping):
911
+ return {str(k): _json_safe(val) for k, val in v.items()}
912
+ if isinstance(v, Sequence) and not isinstance(v, (str, bytes, bytearray)):
913
+ return [_json_safe(x) for x in v]
914
+ try:
915
+ _json.dumps(v)
916
+ return v
917
+ except Exception:
918
+ return str(v)
919
+ result = {k: _json_safe(v) for k, v in result.items()}
920
+ except Exception:
921
+ pass
922
+ exec_time = time.time() - start_time
923
+ if isinstance(result, dict):
924
+ result['__execution_time__'] = exec_time
925
+ ctx.step_results[step_name] = result
926
+ return ExecutionResult(name=step_name, result=result, execution_time=exec_time, was_cached=False, error=None)
927
+
928
+
929
+ def _worker_execute_step_with_deps(step_name: str, process_name: Optional[str], context_payload: dict,
930
+ dep_results: List[ExecutionResult], iteration: int = 0,
931
+ run_id: Optional[str] = None) -> ExecutionResult:
932
+ from .step_system import StepContext as _Ctx
933
+ ctx = _build_step_context_from_payload(context_payload) if isinstance(context_payload, dict) else _Ctx(project_id='default')
934
+ try:
935
+ setattr(ctx, 'current_process', process_name)
936
+ except Exception:
937
+ pass
938
+ for dep in dep_results:
939
+ if dep.error is None and dep.result:
940
+ ctx.step_results[dep.name] = dep.result
941
+ else:
942
+ raise RuntimeError(f"Dependency step {dep.name} failed: {dep.error}")
943
+ return _worker_execute_step_task(step_name, process_name, ctx, iteration, run_id)
944
+
945
+
946
+ def _create_worker_state_manager(cache_config: Dict[str, Any]):
947
+ from .step_state_manager import StepStateManager
948
+ from pathlib import Path
949
+ import os
950
+ import logging
951
+
952
+ try:
953
+ from mlops.storage.factory import create_kv_store as _create_kv_store, create_object_store as _create_obj_store
954
+ except Exception:
955
+ _create_kv_store = None # type: ignore[assignment]
956
+ _create_obj_store = None # type: ignore[assignment]
957
+
958
+ kv_cfg = cache_config.get("kv_store_config", {}) if isinstance(cache_config, dict) else {}
959
+ kv_store_type = str(cache_config.get("kv_store_type", "") or "")
960
+ project_id = None
961
+ try:
962
+ project_id = kv_cfg.get("project_id")
963
+ except Exception:
964
+ project_id = None
965
+ project_id = str(project_id or os.getenv("MLOPS_PROJECT_ID") or "default")
966
+
967
+ backend_cfg: Dict[str, Any] = {}
968
+ if "GCP" in kv_store_type or "Firestore" in kv_store_type:
969
+ backend_cfg = {
970
+ "type": "gcp",
971
+ "gcp_project": kv_cfg.get("gcp_project"),
972
+ "topic_name": kv_cfg.get("topic_name"),
973
+ "emulator_host": kv_cfg.get("emulator_host"),
974
+ }
975
+ elif "Redis" in kv_store_type:
976
+ backend_cfg = {
977
+ "type": "redis",
978
+ "host": kv_cfg.get("host"),
979
+ "port": kv_cfg.get("port"),
980
+ "db": kv_cfg.get("db"),
981
+ "password": kv_cfg.get("password"),
982
+ }
983
+ else:
984
+ backend_cfg = {"type": "memory"}
985
+
986
+ kv_store = None
987
+ if _create_kv_store:
988
+ try:
989
+ kv_store = _create_kv_store(project_id, backend_cfg, env=os.environ)
990
+ except Exception:
991
+ kv_store = None
992
+
993
+ if kv_store is None:
994
+ try:
995
+ from mlops.storage.adapters.memory_store import InMemoryStore # type: ignore
996
+ kv_store = InMemoryStore(project_id)
997
+ except Exception:
998
+ kv_store = None
999
+
1000
+ object_store = None
1001
+ if _create_obj_store:
1002
+ try:
1003
+ obj_cfg = cache_config.get("object_store_config", {}) if isinstance(cache_config, dict) else {}
1004
+ obj_type = str(cache_config.get("object_store_type", "") or "")
1005
+ cache_cfg = {}
1006
+ if "GCS" in obj_type:
1007
+ cache_cfg = {"object_store": {"type": "gcs", "bucket": obj_cfg.get("bucket"), "prefix": obj_cfg.get("prefix")}}
1008
+ object_store = _create_obj_store(cache_cfg, env=os.environ) if cache_cfg else None
1009
+ except Exception:
1010
+ object_store = None
1011
+
1012
+ if kv_store:
1013
+ cache_dir = Path(os.getenv('MLOPS_STEP_CACHE_DIR') or '/tmp/mlops-step-cache')
1014
+ return StepStateManager(
1015
+ cache_dir=cache_dir,
1016
+ kv_store=kv_store,
1017
+ logger=logging.getLogger(__name__),
1018
+ object_store=object_store
1019
+ )
1020
+ return None
1021
+
1022
+
1023
+ def _worker_execute_process_task(proc_payload: Dict[str, Any], context_payload: Dict[str, Any],
1024
+ run_id: Optional[str] = None) -> ExecutionResult:
1025
+ process_name = proc_payload.get('name')
1026
+ start_time = time.time()
1027
+
1028
+ from .step_system import StepContext as _Ctx
1029
+ try:
1030
+ if isinstance(context_payload, _Ctx):
1031
+ ctx = context_payload
1032
+ else:
1033
+ ctx = _build_step_context_from_payload(context_payload) if isinstance(context_payload, dict) else _Ctx(project_id='default')
1034
+ except Exception:
1035
+ ctx = _Ctx(project_id='default')
1036
+
1037
+ try:
1038
+ try:
1039
+ _gc2 = context_payload.get('global_config') if isinstance(context_payload, dict) else {}
1040
+ _pid2 = context_payload.get('project_id') if isinstance(context_payload, dict) else None
1041
+ _maybe_init_worker_state_manager(_gc2, _pid2)
1042
+ except Exception:
1043
+ pass
1044
+
1045
+ # Deterministic task-level seeding (process scope)
1046
+ _seed_rng_for_task(run_id, process_name, None, 0)
1047
+
1048
+ # Process start is now recorded inside _execute_process_on_worker at the exact timing start
1049
+
1050
+ return _execute_process_on_worker(ctx, proc_payload, run_id)
1051
+ except Exception as e:
1052
+ exec_time = time.time() - start_time
1053
+ error_result = {'__error_context__': str(e)}
1054
+ return ExecutionResult(name=process_name, result=error_result, execution_time=exec_time, was_cached=False, error=str(e))
1055
+
1056
+
1057
+ def _worker_execute_process_with_deps(proc_payload: Dict[str, Any], context_payload: Dict[str, Any],
1058
+ dep_results: List[ExecutionResult], run_id: Optional[str] = None) -> ExecutionResult:
1059
+ from .step_system import StepContext as _Ctx
1060
+ ctx = _build_step_context_from_payload(context_payload) if isinstance(context_payload, dict) else _Ctx(project_id='default')
1061
+ try:
1062
+ setattr(ctx, 'current_process', proc_payload.get('name'))
1063
+ except Exception:
1064
+ pass
1065
+ # Ensure a worker state manager exists (with object store when available) and custom model is imported
1066
+ try:
1067
+ from .step_system import get_state_manager as _get_sm, set_state_manager as _set_sm
1068
+ sm_existing = _get_sm()
1069
+ # Prefer cache_config-provisioned state manager when missing or when object_store is absent
1070
+ try:
1071
+ cfg = proc_payload.get('cache_config') if isinstance(proc_payload, dict) else None
1072
+ except Exception:
1073
+ cfg = None
1074
+ needs_obj_store = False
1075
+ if sm_existing is None:
1076
+ needs_obj_store = True
1077
+ else:
1078
+ try:
1079
+ needs_obj_store = getattr(sm_existing, 'object_store', None) is None
1080
+ except Exception:
1081
+ needs_obj_store = True
1082
+ if cfg and needs_obj_store:
1083
+ try:
1084
+ sm_new = _create_worker_state_manager(cfg)
1085
+ if sm_new is not None:
1086
+ _set_sm(sm_new)
1087
+ except Exception:
1088
+ pass
1089
+ # Import custom model on the worker so process/step registries are populated for hashing
1090
+ try:
1091
+ _maybe_import_custom_model_from_global_config(getattr(ctx, 'global_config', {}) or {})
1092
+ except Exception:
1093
+ pass
1094
+ except Exception:
1095
+ pass
1096
+ for dep in dep_results:
1097
+ if dep.error is not None:
1098
+ return ExecutionResult(name=proc_payload.get('name'), result=None, execution_time=0.0, was_cached=False, error=f"Dependency {dep.name} failed: {dep.error}")
1099
+ if not getattr(dep, 'was_cached', False) and not dep.result:
1100
+ return ExecutionResult(name=proc_payload.get('name'), result=None, execution_time=0.0, was_cached=False, error=f"Dependency {dep.name} failed: {dep.error}")
1101
+ try:
1102
+ if dep.result is not None:
1103
+ ctx.step_results[dep.name] = _strip_internal_keys(dep.result)
1104
+ else:
1105
+ # Hydrate cached dependency placeholder on worker using state manager
1106
+ try:
1107
+ from .step_system import get_state_manager as _get_sm
1108
+ sm = _get_sm()
1109
+ except Exception:
1110
+ sm = None
1111
+ if sm is not None:
1112
+ try:
1113
+ # Compute hashes for the dependency using the same worker helper
1114
+ deps_for_dep = []
1115
+ try:
1116
+ dep_map = proc_payload.get('dependency_map') or {}
1117
+ deps_for_dep = (dep_map or {}).get(dep.name, [])
1118
+ except Exception:
1119
+ deps_for_dep = []
1120
+ ih, ch, fh = _compute_process_lookup_hashes_worker(sm, ctx, dep.name, deps_for_dep, proc_payload.get('dependency_map') or {})
1121
+ except Exception:
1122
+ ih = ch = fh = None
1123
+ loaded = None
1124
+ # Try hash-based lookup first
1125
+ try:
1126
+ if hasattr(sm, 'get_cached_process_result_with_metadata'):
1127
+ data = sm.get_cached_process_result_with_metadata(dep.name, input_hash=ih, config_hash=ch, function_hash=fh)
1128
+ if data is not None:
1129
+ loaded, _, _ = data
1130
+ except Exception:
1131
+ loaded = None
1132
+ # Fallback: if context_payload carried a cache_path alias
1133
+ if loaded is None:
1134
+ try:
1135
+ cache_hint = None
1136
+ try:
1137
+ s = (context_payload.get('step_results') or {}).get(dep.name, {}) if isinstance(context_payload, dict) else {}
1138
+ cache_hint = s.get('cache_path') if isinstance(s, dict) else None
1139
+ except Exception:
1140
+ cache_hint = None
1141
+ if cache_hint and hasattr(sm, 'load_process_result_from_path'):
1142
+ loaded = sm.load_process_result_from_path(cache_hint)
1143
+ except Exception:
1144
+ loaded = None
1145
+ if isinstance(loaded, dict):
1146
+ try:
1147
+ ctx.step_results[dep.name] = _strip_internal_keys(loaded)
1148
+ except Exception:
1149
+ ctx.step_results[dep.name] = loaded
1150
+ inner = {}
1151
+ try:
1152
+ inner = dep.result.get('__step_results__', {}) if isinstance(dep.result, dict) else {}
1153
+ except Exception:
1154
+ inner = {}
1155
+ if isinstance(inner, dict):
1156
+ ctx.step_results.update(inner)
1157
+ except Exception:
1158
+ continue
1159
+ return _worker_execute_process_task(proc_payload, ctx, run_id)
1160
+
1161
+
1162
+ def _maybe_init_worker_state_manager(global_config: Any, project_id: Optional[str]) -> None:
1163
+ try:
1164
+ from .step_system import get_state_manager as _get_sm, set_state_manager as _set_sm
1165
+ sm_existing = _get_sm()
1166
+ if sm_existing is not None:
1167
+ return
1168
+ import os as _os
1169
+ import logging as _logging
1170
+ logger = _logging.getLogger(__name__)
1171
+ gcp_creds = _os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
1172
+ gcp_project = _os.getenv('GOOGLE_CLOUD_PROJECT')
1173
+ logger.info(
1174
+ f"[Worker] Initializing state manager with GOOGLE_APPLICATION_CREDENTIALS={'SET' if gcp_creds else 'UNSET'}, "
1175
+ f"GOOGLE_CLOUD_PROJECT={gcp_project or 'UNSET'}"
1176
+ )
1177
+ # Prefer top-level cache; fallback to nested model.parameters.cache for backward compatibility
1178
+ cache_cfg = (global_config.get('cache') or {}) if isinstance(global_config, dict) else {}
1179
+ if not cache_cfg and isinstance(global_config, dict):
1180
+ try:
1181
+ cache_cfg = (global_config.get('model') or {}).get('parameters', {}).get('cache') or {}
1182
+ except Exception:
1183
+ cache_cfg = {}
1184
+ backend_cfg = cache_cfg.get('backend') if isinstance(cache_cfg, dict) else {}
1185
+ # Derive missing GCP env from backend config and project layout
1186
+ try:
1187
+ from pathlib import Path as _Path
1188
+ creds_rel = (backend_cfg or {}).get('credentials_json')
1189
+ if creds_rel and not _os.getenv('GOOGLE_APPLICATION_CREDENTIALS'):
1190
+ repo_root = get_workspace_root()
1191
+ pid_effective = project_id or _os.getenv('MLOPS_PROJECT_ID') or 'default'
1192
+ cred_path = (get_projects_root(repo_root) / str(pid_effective) / str(creds_rel)).resolve()
1193
+ if cred_path.exists():
1194
+ _os.environ.setdefault('GOOGLE_APPLICATION_CREDENTIALS', str(cred_path))
1195
+ if (backend_cfg or {}).get('gcp_project') and not _os.getenv('GOOGLE_CLOUD_PROJECT'):
1196
+ _os.environ.setdefault('GOOGLE_CLOUD_PROJECT', str(backend_cfg.get('gcp_project')))
1197
+ except Exception:
1198
+ pass
1199
+ pid_effective = project_id or _os.getenv('MLOPS_PROJECT_ID') or 'default'
1200
+ backend_type = (backend_cfg.get('type') if isinstance(backend_cfg, dict) else None) or _os.getenv('MLOPS_KV_BACKEND') or 'memory'
1201
+ logger.info(
1202
+ f"[Worker] KV backend selection -> MLOPS_KV_BACKEND={_os.getenv('MLOPS_KV_BACKEND') or 'unset'}, "
1203
+ f"resolved={backend_type}, project_ns={pid_effective}"
1204
+ )
1205
+ try:
1206
+ from mlops.storage.factory import create_kv_store as _create_kv_store, create_object_store as _create_obj_store
1207
+ ws_root = get_workspace_root()
1208
+ proj_root = get_projects_root(ws_root) / str(pid_effective)
1209
+ kv_store = _create_kv_store(
1210
+ str(pid_effective),
1211
+ backend_cfg if isinstance(backend_cfg, dict) else {},
1212
+ env=_os.environ,
1213
+ workspace_root=ws_root,
1214
+ project_root=proj_root,
1215
+ )
1216
+ obj_store = _create_obj_store(cache_cfg if isinstance(cache_cfg, dict) else {}, env=_os.environ)
1217
+ obj_prefix = None
1218
+ except Exception:
1219
+ from mlops.storage.adapters.memory_store import InMemoryStore
1220
+ kv_store = InMemoryStore(str(pid_effective))
1221
+ obj_store = None
1222
+ obj_prefix = None
1223
+ from pathlib import Path as _Path
1224
+ cache_dir = _Path(_os.getenv('MLOPS_STEP_CACHE_DIR') or '/tmp/mlops-step-cache')
1225
+ from .step_state_manager import StepStateManager as _SSM
1226
+ try:
1227
+ ttl_val = int(((cache_cfg or {}).get('ttl_hours') if isinstance(cache_cfg, dict) else 24) or 24)
1228
+ except Exception:
1229
+ ttl_val = 24
1230
+ sm_new = _SSM(cache_dir=cache_dir, kv_store=kv_store, logger=logging.getLogger(__name__), cache_ttl_hours=ttl_val, object_store=obj_store, object_prefix=obj_prefix)
1231
+ logger.info(
1232
+ f"[Worker] StateManager created -> kv_store={type(kv_store).__name__ if kv_store else 'None'}, "
1233
+ f"object_store={type(obj_store).__name__ if obj_store else 'None'}"
1234
+ )
1235
+ _set_sm(sm_new)
1236
+ except Exception:
1237
+ return
1238
+
1239
+