expops-0.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. expops-0.1.3.dist-info/METADATA +826 -0
  2. expops-0.1.3.dist-info/RECORD +86 -0
  3. expops-0.1.3.dist-info/WHEEL +5 -0
  4. expops-0.1.3.dist-info/entry_points.txt +3 -0
  5. expops-0.1.3.dist-info/licenses/LICENSE +674 -0
  6. expops-0.1.3.dist-info/top_level.txt +1 -0
  7. mlops/__init__.py +0 -0
  8. mlops/__main__.py +11 -0
  9. mlops/_version.py +34 -0
  10. mlops/adapters/__init__.py +12 -0
  11. mlops/adapters/base.py +86 -0
  12. mlops/adapters/config_schema.py +89 -0
  13. mlops/adapters/custom/__init__.py +3 -0
  14. mlops/adapters/custom/custom_adapter.py +447 -0
  15. mlops/adapters/plugin_manager.py +113 -0
  16. mlops/adapters/sklearn/__init__.py +3 -0
  17. mlops/adapters/sklearn/adapter.py +94 -0
  18. mlops/cluster/__init__.py +3 -0
  19. mlops/cluster/controller.py +496 -0
  20. mlops/cluster/process_runner.py +91 -0
  21. mlops/cluster/providers.py +258 -0
  22. mlops/core/__init__.py +95 -0
  23. mlops/core/custom_model_base.py +38 -0
  24. mlops/core/dask_networkx_executor.py +1265 -0
  25. mlops/core/executor_worker.py +1239 -0
  26. mlops/core/experiment_tracker.py +81 -0
  27. mlops/core/graph_types.py +64 -0
  28. mlops/core/networkx_parser.py +135 -0
  29. mlops/core/payload_spill.py +278 -0
  30. mlops/core/pipeline_utils.py +162 -0
  31. mlops/core/process_hashing.py +216 -0
  32. mlops/core/step_state_manager.py +1298 -0
  33. mlops/core/step_system.py +956 -0
  34. mlops/core/workspace.py +99 -0
  35. mlops/environment/__init__.py +10 -0
  36. mlops/environment/base.py +43 -0
  37. mlops/environment/conda_manager.py +307 -0
  38. mlops/environment/factory.py +70 -0
  39. mlops/environment/pyenv_manager.py +146 -0
  40. mlops/environment/setup_env.py +31 -0
  41. mlops/environment/system_manager.py +66 -0
  42. mlops/environment/utils.py +105 -0
  43. mlops/environment/venv_manager.py +134 -0
  44. mlops/main.py +527 -0
  45. mlops/managers/project_manager.py +400 -0
  46. mlops/managers/reproducibility_manager.py +575 -0
  47. mlops/platform.py +996 -0
  48. mlops/reporting/__init__.py +16 -0
  49. mlops/reporting/context.py +187 -0
  50. mlops/reporting/entrypoint.py +292 -0
  51. mlops/reporting/kv_utils.py +77 -0
  52. mlops/reporting/registry.py +50 -0
  53. mlops/runtime/__init__.py +9 -0
  54. mlops/runtime/context.py +34 -0
  55. mlops/runtime/env_export.py +113 -0
  56. mlops/storage/__init__.py +12 -0
  57. mlops/storage/adapters/__init__.py +9 -0
  58. mlops/storage/adapters/gcp_kv_store.py +778 -0
  59. mlops/storage/adapters/gcs_object_store.py +96 -0
  60. mlops/storage/adapters/memory_store.py +240 -0
  61. mlops/storage/adapters/redis_store.py +438 -0
  62. mlops/storage/factory.py +199 -0
  63. mlops/storage/interfaces/__init__.py +6 -0
  64. mlops/storage/interfaces/kv_store.py +118 -0
  65. mlops/storage/path_utils.py +38 -0
  66. mlops/templates/premier-league/charts/plot_metrics.js +70 -0
  67. mlops/templates/premier-league/charts/plot_metrics.py +145 -0
  68. mlops/templates/premier-league/charts/requirements.txt +6 -0
  69. mlops/templates/premier-league/configs/cluster_config.yaml +13 -0
  70. mlops/templates/premier-league/configs/project_config.yaml +207 -0
  71. mlops/templates/premier-league/data/England CSV.csv +12154 -0
  72. mlops/templates/premier-league/models/premier_league_model.py +638 -0
  73. mlops/templates/premier-league/requirements.txt +8 -0
  74. mlops/templates/sklearn-basic/README.md +22 -0
  75. mlops/templates/sklearn-basic/charts/plot_metrics.py +85 -0
  76. mlops/templates/sklearn-basic/charts/requirements.txt +3 -0
  77. mlops/templates/sklearn-basic/configs/project_config.yaml +64 -0
  78. mlops/templates/sklearn-basic/data/train.csv +14 -0
  79. mlops/templates/sklearn-basic/models/model.py +62 -0
  80. mlops/templates/sklearn-basic/requirements.txt +10 -0
  81. mlops/web/__init__.py +3 -0
  82. mlops/web/server.py +585 -0
  83. mlops/web/ui/index.html +52 -0
  84. mlops/web/ui/mlops-charts.js +357 -0
  85. mlops/web/ui/script.js +1244 -0
  86. mlops/web/ui/styles.css +248 -0
mlops/core/dask_networkx_executor.py
@@ -0,0 +1,1265 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ import hashlib
5
+ import json
6
+ import logging
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional
10
+
11
+ import networkx as nx
12
+
13
+ from dask import compute
14
+ from dask.delayed import delayed
15
+
16
+ from .graph_types import ExecutionResult, NetworkXGraphConfig, NodeType
17
+ from .step_state_manager import StepStateManager
18
+ from .step_system import StepContext, StepDefinition, StepRegistry
19
+
20
+ from .executor_worker import (
21
+ _return_placeholder_cached_process_execution_result,
22
+ _return_placeholder_cached_process_execution_result_with_deps,
23
+ _worker_execute_process_task,
24
+ _worker_execute_process_with_deps,
25
+ )
26
+
27
+
28
+ class DaskNetworkXExecutor:
29
+ """
30
+ Execute NetworkX DAGs with Dask (threads or distributed) and integrated caching.
31
+ """
32
+
33
+ @staticmethod
34
+ def _flatten_dask_overrides(overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
35
+ flat: Dict[str, Any] = {}
36
+ if not isinstance(overrides, dict):
37
+ return flat
38
+ for key, value in overrides.items():
39
+ if isinstance(value, dict):
40
+ nested = DaskNetworkXExecutor._flatten_dask_overrides(value)
41
+ for sub_key, sub_val in nested.items():
42
+ flat_key = f"{key}.{sub_key}" if key else sub_key
43
+ flat[flat_key] = sub_val
44
+ else:
45
+ flat[str(key)] = value
46
+ return flat
47
+
48
+ def __init__(self, step_registry: StepRegistry, state_manager: Optional[StepStateManager] = None,
49
+ logger: Optional[logging.Logger] = None,
50
+ n_workers: int = 2,
51
+ scheduler_mode: str = "threads",
52
+ scheduler_address: Optional[str] = None,
53
+ client: Any = None,
54
+ extra_files_to_upload: Optional[List[str]] = None,
55
+ strict_cache: bool = True,
56
+ min_workers: Optional[int] = None,
57
+ wait_for_workers_sec: Optional[float] = None,
58
+ dask_config_overrides: Optional[Dict[str, Any]] = None):
59
+ self.step_registry = step_registry
60
+ self.state_manager = state_manager
61
+ self.logger = logger or logging.getLogger(__name__)
62
+ self.cache_enabled = True
63
+ self.n_workers = n_workers
64
+ self.scheduler_mode = scheduler_mode
65
+ self._scheduler_address = scheduler_address
66
+ self._distributed_client = client
67
+ self.strict_cache = bool(strict_cache)
68
+ try:
69
+ self._extra_upload_files: List[str] = list(extra_files_to_upload) if extra_files_to_upload else []
70
+ except Exception:
71
+ self._extra_upload_files = []
72
+ self._min_workers_override = min_workers if isinstance(min_workers, int) and min_workers > 0 else None
73
+ self._wait_for_workers_override = float(wait_for_workers_sec) if isinstance(wait_for_workers_sec, (int, float)) and wait_for_workers_sec > 0 else None
74
+ self._dask_config_overrides = self._flatten_dask_overrides(dask_config_overrides)
75
+
76
+ def _prepare_dask_overrides(self) -> Dict[str, Any]:
77
+ overrides = dict(self._dask_config_overrides)
78
+ # Only apply env defaults when not explicitly configured.
79
+ if "distributed.comm.compression" not in overrides:
80
+ overrides["distributed.comm.compression"] = (
81
+ os.environ.get("DASK_DISTRIBUTED__COMM__COMPRESSION")
82
+ or "zlib"
83
+ )
84
+ return overrides
85
+
86
+ @staticmethod
87
+ def _normalize_bootstrap_mode(value: Optional[str]) -> str:
88
+ """Normalize the bootstrap mode string.
89
+
90
+ Supported values:
91
+ - "auto" (default): only bootstrap when workers/scheduler can't import `mlops`
92
+ - "always": always bootstrap (upload zip, sys.path, cleanup)
93
+ - "never": never bootstrap `mlops` package code (still may upload extra files)
94
+ """
95
+ try:
96
+ v = (value or "").strip().lower()
97
+ except Exception:
98
+ v = ""
99
+ if v in ("1", "true", "yes", "y", "always", "on"):
100
+ return "always"
101
+ if v in ("0", "false", "no", "n", "never", "off"):
102
+ return "never"
103
+ return "auto"
104
+
105
+ @staticmethod
106
+ def _worker_apply_dask_config(overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
107
+ """Apply Dask config overrides in the current process (worker or scheduler)."""
108
+ applied: Dict[str, Any] = {}
109
+ try:
110
+ from dask import config as _dask_config
111
+ except Exception:
112
+ return {"error": "config_import_failed"}
113
+ if isinstance(overrides, dict):
114
+ for key, value in overrides.items():
115
+ try:
116
+ _dask_config.set({key: value})
117
+ applied[key] = _dask_config.get(key)
118
+ except Exception:
119
+ applied[key] = value
120
+ return applied
121
+
122
+ @staticmethod
123
+ def _worker_try_import(module_name: str) -> bool:
124
+ try:
125
+ __import__(str(module_name))
126
+ return True
127
+ except Exception:
128
+ return False
129
+
130
+ @staticmethod
131
+ def _worker_set_env_vars(env_vars: Dict[str, Any]) -> bool:
132
+ import os as _os
133
+ try:
134
+ for key, value in (env_vars or {}).items():
135
+ if value is not None:
136
+ _os.environ[str(key)] = str(value)
137
+ return True
138
+ except Exception:
139
+ return False
140
+
141
+ @staticmethod
142
+ def _worker_ensure_sys_path(paths: List[str]) -> Dict[str, bool]:
143
+ import sys as _sys
144
+ results: Dict[str, bool] = {}
145
+ for p in paths or []:
146
+ try:
147
+ p_str = str(p)
148
+ except Exception:
149
+ continue
150
+ if p_str and p_str not in _sys.path:
151
+ _sys.path.insert(0, p_str)
152
+ results[p_str] = p_str in _sys.path
153
+ return results
154
+
155
+ @staticmethod
156
+ def _worker_cleanup_import_state(patterns: List[str], files_to_clean: List[str], clean_mlops: bool = True) -> Dict[str, Any]:
157
+ """Best-effort cleanup to avoid stale imports between runs on long-lived worker processes."""
158
+ import sys as _sys
159
+ import re as _re
160
+ import importlib as _importlib
161
+ import os as _os
162
+
163
+ removed_paths: List[str] = []
164
+ removed_files: List[str] = []
165
+ removed_modules: List[str] = []
166
+
167
+ # Remove sys.path entries matching known uploaded artifacts (zip names, temp dirs, etc.)
168
+ new_path: List[str] = []
169
+ for p in list(_sys.path):
170
+ try:
171
+ p_str = str(p or "")
172
+ except Exception:
173
+ p_str = ""
174
+ try:
175
+ matched = any(_re.search(ptn, p_str) for ptn in (patterns or []))
176
+ except Exception:
177
+ matched = False
178
+ if matched:
179
+ removed_paths.append(p_str)
180
+ continue
181
+ if p not in new_path:
182
+ new_path.append(p)
183
+ _sys.path[:] = new_path
184
+
185
+ # Optionally evict previously loaded `mlops` modules to force a clean import.
186
+ if clean_mlops:
187
+ for name in list(_sys.modules.keys()):
188
+ if name == "mlops" or name.startswith("mlops."):
189
+ removed_modules.append(name)
190
+ _sys.modules.pop(name, None)
191
+
192
+ # Aggressively remove uploaded custom model modules from sys.modules and common upload locations.
193
+ for fname in (files_to_clean or []):
194
+ try:
195
+ base_fname = _os.path.basename(str(fname))
196
+ if not base_fname.endswith(".py"):
197
+ continue
198
+ mod_name = base_fname[:-3]
199
+ if mod_name in _sys.modules:
200
+ removed_modules.append(mod_name)
201
+ _sys.modules.pop(mod_name, None)
202
+ # Also clear our canonical fallback module name.
203
+ if "custom_model" in _sys.modules:
204
+ removed_modules.append("custom_model")
205
+ _sys.modules.pop("custom_model", None)
206
+
207
+ # Try to find and remove file from common upload locations (best-effort).
208
+ search_paths = [
209
+ _os.path.expanduser("~"),
210
+ _os.getcwd(),
211
+ ]
212
+ for search_dir in search_paths:
213
+ try:
214
+ worker_file_path = _os.path.join(search_dir, base_fname)
215
+ if _os.path.exists(worker_file_path):
216
+ _os.remove(worker_file_path)
217
+ removed_files.append(worker_file_path)
218
+ except Exception:
219
+ pass
220
+ except Exception:
221
+ pass
222
+
223
+ try:
224
+ _importlib.invalidate_caches()
225
+ except Exception:
226
+ pass
227
+ return {
228
+ "removed_paths": removed_paths,
229
+ "removed_modules": removed_modules,
230
+ "removed_files": removed_files,
231
+ }
232
+
233
+ def _ensure_connected_to_scheduler(self, force: bool = False) -> None:
234
+ """Connect to an external Dask scheduler if requested; otherwise use threads.
235
+
236
+ Note: For many real clusters, workers already have `mlops` installed or share a filesystem.
237
+ In that case, this method will skip the heavy "upload source zip / sys.path surgery" work
238
+ by default (bootstrap mode = auto).
239
+ """
240
+ if self.scheduler_mode != 'distributed':
241
+ return
242
+ if self._distributed_client is not None and not force:
243
+ return
244
+ scheduler_addr = self._scheduler_address or os.environ.get('DASK_SCHEDULER_ADDRESS')
245
+ # If a client already exists (e.g., user passed one in, or restart created one),
246
+ # prefer its known scheduler address as a fallback.
247
+ if not scheduler_addr and self._distributed_client is not None:
248
+ try:
249
+ scheduler_addr = getattr(getattr(self._distributed_client, "scheduler", None), "address", None)
250
+ except Exception:
251
+ scheduler_addr = None
252
+ if not scheduler_addr:
253
+ self.logger.warning(
254
+ "Distributed scheduler requested but no Client or address provided. Falling back to threads."
255
+ )
256
+ self.scheduler_mode = 'threads'
257
+ return
258
+ try:
259
+ # Ensure we have a client connection.
260
+ if self._distributed_client is None:
261
+ try:
262
+ from distributed import Client
263
+ except Exception:
264
+ from dask.distributed import Client
265
+ self._distributed_client = Client(scheduler_addr)
266
+ self.logger.info(f"Connected to Dask scheduler at {scheduler_addr}")
267
+
268
+ self._configure_distributed_runtime(scheduler_addr)
269
+ except Exception as e:
270
+ self.logger.warning(
271
+ f"Failed to connect/configure scheduler at {scheduler_addr} ({e}). Falling back to threads."
272
+ )
273
+ self.scheduler_mode = 'threads'
274
+
275
+ def _configure_distributed_runtime(self, scheduler_addr: str) -> None:
276
+ """Best-effort cluster setup after a Client is connected."""
277
+ if not self._distributed_client:
278
+ raise RuntimeError("Distributed client is not initialized")
279
+
280
+ client = self._distributed_client
281
+ config_overrides = self._prepare_dask_overrides()
282
+
283
+ # Push config to scheduler first (so it is applied even before workers join).
284
+ try:
285
+ sched_conf = client.run_on_scheduler(self._worker_apply_dask_config, config_overrides)
286
+ self.logger.info(f"Scheduler Dask config applied: {sched_conf}")
287
+ except Exception as e:
288
+ self.logger.warning(f"Failed to push Dask config to scheduler: {e}")
289
+
290
+ # Resolve worker wait settings.
291
+ min_workers_env = os.environ.get('MLOPS_DASK_MIN_WORKERS', '').strip()
292
+ timeout_env = os.environ.get('MLOPS_DASK_WAIT_FOR_WORKERS_SEC', '').strip()
293
+ if self._min_workers_override:
294
+ min_workers = self._min_workers_override
295
+ elif min_workers_env.isdigit():
296
+ min_workers = int(min_workers_env)
297
+ else:
298
+ min_workers = self.n_workers if isinstance(self.n_workers, int) and self.n_workers > 0 else 1
299
+ if min_workers < 1:
300
+ min_workers = 1
301
+ if self._wait_for_workers_override:
302
+ wait_timeout = self._wait_for_workers_override
303
+ elif timeout_env:
304
+ try:
305
+ wait_timeout = float(timeout_env)
306
+ except Exception:
307
+ wait_timeout = 30.0
308
+ else:
309
+ wait_timeout = 30.0
310
+
311
+ # Wait for workers to connect (best-effort; verify count afterwards).
312
+ try:
313
+ self.logger.info(f"Waiting for at least {min_workers} worker(s) to connect (timeout={wait_timeout}s)")
314
+ client.wait_for_workers(min_workers, timeout=wait_timeout) # type: ignore[attr-defined]
315
+ except Exception:
316
+ pass
317
+
318
+ try:
319
+ info = client.scheduler_info()
320
+ workers_dict = info.get('workers', {}) if isinstance(info, dict) else {}
321
+ worker_count = len(workers_dict)
322
+ except Exception:
323
+ worker_count = 0
324
+
325
+ if worker_count < min_workers:
326
+ raise RuntimeError(
327
+ f"Connected to Dask scheduler at {scheduler_addr} but only {worker_count} worker(s) "
328
+ f"available after waiting {wait_timeout}s. Ensure workers are started and connected."
329
+ )
330
+ self.logger.info(f"Workers connected: {worker_count}")
331
+
332
+ # After workers connect, push config to workers as well.
333
+ try:
334
+ workers_conf = client.run(self._worker_apply_dask_config, config_overrides)
335
+ self.logger.info(f"Workers Dask config applied: {workers_conf}")
336
+ except Exception as e:
337
+ self.logger.warning(f"Failed to push Dask config to workers: {e}")
338
+
339
+ # Bootstrap code/env on long-lived worker processes (auto-skip when not needed).
340
+ self._bootstrap_distributed_imports_and_env()
341
+
342
+ def _bootstrap_distributed_imports_and_env(self) -> None:
343
+ if not self._distributed_client:
344
+ return
345
+
346
+ client = self._distributed_client
347
+ bootstrap_mode = self._normalize_bootstrap_mode(os.environ.get("MLOPS_DASK_BOOTSTRAP", "auto"))
348
+
349
+ # Decide whether we need to ship the `mlops` package code to workers.
350
+ needs_mlops_bootstrap = bootstrap_mode == "always"
351
+ if bootstrap_mode == "auto":
352
+ try:
353
+ import_ok_workers = client.run(self._worker_try_import, "mlops")
354
+ needs_mlops_bootstrap = not all(bool(v) for v in import_ok_workers.values())
355
+ except Exception:
356
+ needs_mlops_bootstrap = True
357
+ try:
358
+ import_ok_sched = client.run_on_scheduler(self._worker_try_import, "mlops")
359
+ needs_mlops_bootstrap = needs_mlops_bootstrap or not bool(import_ok_sched)
360
+ except Exception:
361
+ needs_mlops_bootstrap = True
362
+ elif bootstrap_mode == "never":
363
+ needs_mlops_bootstrap = False
364
+
365
+ # Cleanup import state only when we are about to upload something (mlops zip and/or extra files).
366
+ if needs_mlops_bootstrap or self._extra_upload_files:
367
+ try:
368
+ cleanup_patterns = [r"mlops_src.*\.zip", r"mlops_src_"]
369
+ files_to_clean = self._extra_upload_files if self._extra_upload_files else []
370
+ try:
371
+ cleaned_workers = client.run(
372
+ self._worker_cleanup_import_state,
373
+ cleanup_patterns,
374
+ files_to_clean,
375
+ needs_mlops_bootstrap,
376
+ )
377
+ self.logger.info(f"Cleaned worker import state: {cleaned_workers}")
378
+ except Exception as e:
379
+ self.logger.warning(f"Failed to clean workers import state: {e}")
380
+ try:
381
+ cleaned_sched = client.run_on_scheduler(
382
+ self._worker_cleanup_import_state,
383
+ cleanup_patterns,
384
+ files_to_clean,
385
+ needs_mlops_bootstrap,
386
+ )
387
+ self.logger.info(f"Cleaned scheduler import state: {cleaned_sched}")
388
+ except Exception as e:
389
+ self.logger.warning(f"Failed to clean scheduler import state: {e}")
390
+ except Exception as e:
391
+ self.logger.warning(f"Cleanup of previous uploaded code failed: {e}")
392
+
393
+ # Ship mlops code only when required.
394
+ if needs_mlops_bootstrap:
395
+ try:
396
+ from pathlib import Path as _Path
397
+ import tempfile
398
+ import zipfile
399
+
400
+ # Package just the `mlops/` python package, not an entire repo root or site-packages.
401
+ mlops_pkg_dir = _Path(__file__).resolve().parents[1] # .../mlops
402
+ pkg_parent = mlops_pkg_dir.parent # .../src (source) or .../site-packages (installed)
403
+
404
+ zip_path = _Path(tempfile.gettempdir()) / "mlops_pkg.zip"
405
+ with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
406
+ for path in mlops_pkg_dir.rglob('*.py'):
407
+ try:
408
+ # Keep `mlops/...` as the top-level folder inside the zip.
409
+ zf.write(path, arcname=str(path.relative_to(pkg_parent)))
410
+ except Exception:
411
+ pass
412
+ client.upload_file(str(zip_path))
413
+ self.logger.info(f"Uploaded mlops source package to workers: {zip_path}")
414
+
415
+ # Ensure the zip and repo paths are importable (best-effort; may be a no-op on real clusters).
416
+ paths_to_add: List[str] = []
417
+ try:
418
+ paths_to_add.append(str(zip_path))
419
+ except Exception:
420
+ pass
421
+ try:
422
+ added = client.run(self._worker_ensure_sys_path, paths_to_add)
423
+ self.logger.info(f"Adjusted sys.path on workers for importability: {added}")
424
+ except Exception as e:
425
+ self.logger.warning(f"Failed to adjust worker sys.path: {e}")
426
+ try:
427
+ added_sched = client.run_on_scheduler(self._worker_ensure_sys_path, paths_to_add)
428
+ self.logger.info(f"Adjusted sys.path on scheduler for importability: {added_sched}")
429
+ except Exception as e:
430
+ self.logger.warning(f"Failed to adjust scheduler sys.path: {e}")
431
+
432
+ # Validate import on workers (warn-only).
433
+ try:
434
+ import_ok = client.run(self._worker_try_import, "mlops")
435
+ if isinstance(import_ok, dict) and not all(bool(v) for v in import_ok.values()):
436
+ self.logger.warning(
437
+ "One or more workers cannot import 'mlops'. Ensure shared filesystem or install package on workers."
438
+ )
439
+ except Exception:
440
+ pass
441
+ except Exception as e:
442
+ self.logger.warning(f"Failed to package/upload mlops code to workers: {e}")
443
+
444
+ # Upload any additional files requested (e.g., custom model script, reporting entrypoint).
445
+ if self._extra_upload_files:
446
+ for _f in self._extra_upload_files:
447
+ try:
448
+ if _f and os.path.exists(_f):
449
+ load_flag = True
450
+ try:
451
+ norm = str(_f).replace("\\", "/")
452
+ if norm.endswith(".py") and "/charts/" in norm:
453
+ load_flag = False
454
+ except Exception:
455
+ load_flag = True
456
+ try:
457
+ client.upload_file(str(_f), load=load_flag)
458
+ except TypeError:
459
+ # Backward-compatible fallback (older clients without load kwarg)
460
+ client.upload_file(str(_f))
461
+ if load_flag:
462
+ self.logger.info(f"Uploaded extra file to workers: {_f}")
463
+ else:
464
+ self.logger.info(f"Uploaded extra file to workers (load=False): {_f}")
465
+ except Exception as e:
466
+ self.logger.warning(f"Failed to upload extra file '{_f}' to workers: {e}")
467
+
468
+ # Propagate critical environment variables to workers (best-effort).
469
+ try:
470
+ env_to_propagate: Dict[str, str] = {}
471
+ critical_env_vars = [
472
+ 'MLOPS_PROJECT_ID',
473
+ 'MLOPS_WORKSPACE_DIR',
474
+ 'GOOGLE_APPLICATION_CREDENTIALS',
475
+ 'GOOGLE_CLOUD_PROJECT',
476
+ 'FIRESTORE_EMULATOR_HOST',
477
+ 'DASK_DISTRIBUTED__COMM__COMPRESSION',
478
+ 'MLOPS_RUNTIME_PYTHON',
479
+ 'MLOPS_REPORTING_PYTHON',
480
+ 'MLOPS_REPORTING_CONFIG',
481
+ 'MLOPS_RANDOM_SEED',
482
+ 'MLOPS_TASK_LEVEL_SEEDING',
483
+ ]
484
+ for env_var in critical_env_vars:
485
+ value = os.environ.get(env_var)
486
+ if value is not None:
487
+ env_to_propagate[env_var] = value
488
+ if env_to_propagate:
489
+ client.run(self._worker_set_env_vars, env_to_propagate)
490
+ client.run_on_scheduler(self._worker_set_env_vars, env_to_propagate)
491
+ self.logger.info(f"Propagated environment variables: {list(env_to_propagate.keys())}")
492
+ except Exception as e:
493
+ self.logger.warning(f"Failed to propagate environment variables to workers: {e}")
494
+
495
+ def set_cache_enabled(self, enabled: bool) -> None:
496
+ """Enable or disable step caching."""
497
+ self.cache_enabled = enabled
498
+
499
+ def restart_distributed_client(self) -> None:
500
+ """Restart the distributed client connection to ensure clean state between runs.
501
+
502
+ This is a more aggressive cleanup than the automatic cleanup performed in
503
+ _ensure_connected_to_scheduler. Use this if you're experiencing persistent
504
+ deserialization errors between consecutive runs.
505
+ """
506
+ if self.scheduler_mode != 'distributed' or not self._distributed_client:
507
+ return
508
+
509
+ try:
510
+ scheduler_addr = self._distributed_client.scheduler.address
511
+ self.logger.info(f"Restarting Dask client connection to {scheduler_addr}")
512
+
513
+ # Close the existing client
514
+ try:
515
+ self._distributed_client.close()
516
+ except Exception as e:
517
+ self.logger.warning(f"Error closing existing client: {e}")
518
+
519
+ # Create a new client
520
+ try:
521
+ from distributed import Client # type: ignore
522
+ except Exception:
523
+ from dask.distributed import Client # type: ignore
524
+
525
+ self._distributed_client = Client(scheduler_addr)
526
+ self.logger.info("Successfully restarted Dask client connection")
527
+
528
+ # Re-run the setup (compression, worker checks, etc.)
529
+ self._ensure_connected_to_scheduler(force=True)
530
+ except Exception as e:
531
+ self.logger.error(f"Failed to restart Dask client: {e}")
532
+ raise
533
+
534
+ def _build_process_graph(self, config: NetworkXGraphConfig) -> nx.DiGraph:
535
+ """Build the main process-level DAG."""
536
+ process_graph = nx.DiGraph()
537
+
538
+ for process_config in config.processes:
539
+ process_graph.add_node(
540
+ process_config.name,
541
+ type=NodeType.PROCESS,
542
+ config=process_config
543
+ )
544
+
545
+ for process_config in config.processes:
546
+ depends_on = getattr(process_config, 'depends_on', [])
547
+ for dependency in depends_on:
548
+ process_graph.add_edge(dependency, process_config.name)
549
+
550
+ return process_graph
551
+
552
+ def _validate_process_dag(self, process_graph: nx.DiGraph) -> None:
553
+ """Validate that the process-level graph is a DAG."""
554
+ if not nx.is_directed_acyclic_graph(process_graph):
555
+ cycles = list(nx.simple_cycles(process_graph))
556
+ raise ValueError(f"Process-level graph contains cycles: {cycles}")
557
+
558
+ # Removed unused step-config hashing; step-level caching uses stable context hash instead
559
+
560
+ def _compute_step_input_hash(self, step_def: StepDefinition, context: StepContext) -> str:
561
+ """Compute a stable step input hash without parameter resolution.
562
+ Uses available context surface to approximate variability.
563
+ """
564
+ if not self.state_manager:
565
+ return ""
566
+ try:
567
+ context_data = {
568
+ 'step': getattr(step_def, 'name', None),
569
+ 'process': getattr(context, 'current_process', None),
570
+ 'step_results_keys': sorted(list((getattr(context, 'step_results', {}) or {}).keys())),
571
+ 'iteration': getattr(context, 'iteration', 0),
572
+ }
573
+ return self.state_manager._compute_hash(context_data)
574
+ except Exception as e:
575
+ try:
576
+ self.logger.warning(f"Failed to compute context-based step hash: {e}")
577
+ except Exception:
578
+ pass
579
+ return ""
580
+
581
+ # Removed unused Dask step task; steps execute via the step wrapper inside process runners
582
+
583
+ def _serialize_context_for_worker(self, context: StepContext) -> dict:
584
+ """Serialize minimal context payload with only primitives for safe graph shipping."""
585
+ try:
586
+ data_paths = {k: str(v) for k, v in (context.data_paths or {}).items()}
587
+ except Exception:
588
+ data_paths = {}
589
+ # Sanitize step_results: keep only JSON-serializable fields and avoid heavy objects
590
+ sanitized_results: Dict[str, Dict[str, Any]] = {}
591
+ for process_name, result in (context.step_results or {}).items():
592
+ try:
593
+ # Copy data and drop recursive/heavy fields (nested step maps, logs, large artifact pointers)
594
+ _data = dict(result) if isinstance(result, dict) else {}
595
+ _common_drop_keys = ['__step_results__', '__logs__', 'checkpoint_path', 'cache_path', 'artifacts']
596
+ _heavy_drop_keys = ['model', 'saved_model'] if str(self.scheduler_mode) == 'distributed' else []
597
+ for _k in (_common_drop_keys + _heavy_drop_keys):
598
+ try:
599
+ _data.pop(_k, None)
600
+ except Exception:
601
+ pass
602
+ # Best-effort: attach cache path so workers can rehydrate dependencies locally without shipping heavy objects
603
+ _cache_path = None
604
+ try:
605
+ if self.state_manager:
606
+ ih, ch, fh = self._compute_process_lookup_hashes(context, process_name)
607
+ # Guard against missing kv_store in rare cases
608
+ _kvs = getattr(self.state_manager, 'kv_store', None)
609
+ if _kvs and hasattr(_kvs, 'get_process_cache_path'):
610
+ _cache_path = _kvs.get_process_cache_path(process_name, ih, ch, fh)
611
+ except Exception:
612
+ _cache_path = None
613
+ # Keep the same shape as original result surface for dependency injection.
614
+ # Attach cache path as a top-level meta key to enable rehydration when needed.
615
+ _data['cache_path'] = _cache_path
616
+ sanitized_results[process_name] = _data
617
+ except Exception:
618
+ continue
619
+ return {
620
+ 'project_id': context.project_id,
621
+ 'run_id': context.run_id,
622
+ 'global_config': dict(getattr(context, 'global_config', {}) or {}),
623
+ 'data_paths': data_paths,
624
+ 'checkpoint_dir': str(getattr(context, 'checkpoint_dir', 'artifacts/checkpoints')),
625
+ 'step_results': sanitized_results,
626
+ }
627
+
628
+ def _compute_process_lookup_hashes(self, context: StepContext, process_name: str) -> tuple:
629
+ """Compute (ih, ch, fh) via shared helper to keep driver/worker in lockstep."""
630
+ try:
631
+ from .process_hashing import compute_process_hashes
632
+ except Exception:
633
+ compute_process_hashes = None
634
+
635
+ # Build deterministic dependency_map from the process graph
636
+ dependency_map = {}
637
+ try:
638
+ if hasattr(self, 'process_graph'):
639
+ for n in list(self.process_graph.nodes):
640
+ try:
641
+ preds = list(self.process_graph.predecessors(n))
642
+ preds = sorted(set(preds))
643
+ dependency_map[n] = preds
644
+ except Exception:
645
+ dependency_map[n] = []
646
+ except Exception:
647
+ dependency_map = {}
648
+
649
+ if compute_process_hashes and self.state_manager:
650
+ # Use code_function mapping when available to resolve the correct process definition
651
+ try:
652
+ lookup_name = self._get_lookup_name(process_name) or process_name
653
+ except Exception:
654
+ lookup_name = process_name
655
+ ih, ch, fh = compute_process_hashes(
656
+ self.state_manager,
657
+ context,
658
+ process_name,
659
+ dependency_map,
660
+ lookup_name=lookup_name,
661
+ )
662
+ else:
663
+ ih = ch = fh = None
664
+
665
+ self.logger.debug(f"[HashTrace] side=driver process={process_name} ih={ih} ch={ch} fh={fh}")
666
+
667
+ return (ih, ch, fh)
668
+
669
+ def _get_cache_config_for_worker(self) -> Dict[str, Any]:
670
+ """Extract cache configuration for worker state manager creation."""
671
+ try:
672
+ if self.state_manager is None:
673
+ return {}
674
+
675
+ cfg: Dict[str, Any] = {}
676
+
677
+ kv_store = getattr(self.state_manager, "kv_store", None)
678
+ if kv_store is not None:
679
+ kv_type = type(kv_store).__name__
680
+ cfg["kv_store_type"] = kv_type
681
+ if "GCP" in kv_type or "Firestore" in kv_type:
682
+ cfg["kv_store_config"] = {
683
+ "project_id": getattr(kv_store, "project_id", None),
684
+ "gcp_project": getattr(kv_store, "gcp_project", None),
685
+ "topic_name": getattr(kv_store, "topic_name", None),
686
+ "emulator_host": getattr(kv_store, "_emulator_host", None),
687
+ }
688
+
689
+ obj_store = getattr(self.state_manager, "object_store", None)
690
+ if obj_store is not None:
691
+ obj_type = type(obj_store).__name__
692
+ cfg["object_store_type"] = obj_type
693
+ if "GCS" in obj_type:
694
+ bucket_name = None
695
+ try:
696
+ bucket_obj = getattr(obj_store, "_bucket", None)
697
+ bucket_name = getattr(bucket_obj, "name", None) if bucket_obj is not None else None
698
+ except Exception:
699
+ bucket_name = None
700
+ cfg["object_store_config"] = {
701
+ "bucket": bucket_name,
702
+ "prefix": getattr(obj_store, "_prefix", None),
703
+ }
704
+
705
+ return cfg
706
+ except Exception:
707
+ return {}
708
+
709
+ def _get_lookup_name(self, proc_name: str) -> Optional[str]:
710
+ try:
711
+ if hasattr(self, 'process_graph') and self.process_graph.has_node(proc_name):
712
+ cfg = self.process_graph.nodes[proc_name].get('config')
713
+ return getattr(cfg, 'code_function', None)
714
+ except Exception:
715
+ return None
716
+ return None
717
+
718
+ def _get_logging_flag(self, proc_name: str) -> bool:
719
+ try:
720
+ from .step_system import get_process_registry # type: ignore
721
+ pr = get_process_registry()
722
+ pdef = pr.get_process(self._get_lookup_name(proc_name) or proc_name) if pr else None
723
+ return getattr(pdef, 'logging', True) if pdef else True
724
+ except Exception:
725
+ return True
726
+
727
+ def _repo_root(self) -> Path:
728
+ # Legacy: many code paths historically treated this as the repo root.
729
+ # For installed packages, use the workspace root (where projects/ lives).
730
+ try:
731
+ from .workspace import get_workspace_root
732
+ return get_workspace_root()
733
+ except Exception:
734
+ return Path.cwd()
735
+
736
+ def _resolve_path(self, p: str) -> Path:
737
+ try:
738
+ path = Path(p)
739
+ if path.is_absolute():
740
+ return path
741
+ return self._repo_root() / p
742
+ except Exception:
743
+ return Path(p)
744
+
745
+ def _get_reporting_cfg(self) -> dict:
746
+ try:
747
+ txt = os.environ.get('MLOPS_REPORTING_CONFIG') or ''
748
+ if not txt:
749
+ return {}
750
+ return json.loads(txt)
751
+ except Exception:
752
+ return {}
753
+
754
+ def _get_chart_spec(self, name: str) -> dict:
755
+ rcfg = self._get_reporting_cfg()
756
+ charts = rcfg.get('charts') or []
757
+ for item in charts:
758
+ try:
759
+ if isinstance(item, dict) and str(item.get('name')) == name and str(item.get('type', 'static')).lower() != 'dynamic':
760
+ return item
761
+ except Exception:
762
+ continue
763
+ return {}
764
+
765
+ def _compute_chart_function_hash(self, static_entrypoint: str) -> Optional[str]:
766
+ try:
767
+ p = self._resolve_path(static_entrypoint)
768
+ with open(p, 'rb') as f:
769
+ data = f.read()
770
+ return hashlib.sha256(data).hexdigest()
771
+ except Exception:
772
+ return None
773
+
774
+ def _compute_chart_config_hash(self, name: str) -> Optional[str]:
775
+ try:
776
+ rcfg = self._get_reporting_cfg()
777
+ global_args = rcfg.get('args') or []
778
+ theme = os.environ.get('MLOPS_CHART_THEME')
779
+ spec = self._get_chart_spec(name)
780
+ cfg_payload = {
781
+ 'name': name,
782
+ 'probe_paths': spec.get('probe_paths') or {},
783
+ 'chart_args': spec.get('args') or [],
784
+ 'global_args': global_args,
785
+ 'theme': theme,
786
+ }
787
+ if self.state_manager:
788
+ return self.state_manager._compute_hash(cfg_payload)
789
+ else:
790
+ payload = json.dumps(cfg_payload, sort_keys=True, separators=(",", ":")).encode()
791
+ return hashlib.sha256(payload).hexdigest()
792
+ except Exception:
793
+ return None
794
+
795
+ def _maybe_apply_chart_hash_overrides(
796
+ self,
797
+ process_name: str,
798
+ config_hash: Optional[str],
799
+ function_hash: Optional[str],
800
+ ) -> tuple[Optional[str], Optional[str]]:
801
+ """Override hashes for chart nodes so cache keys track chart config + entrypoint content."""
802
+ try:
803
+ if not hasattr(self, "process_graph") or not self.process_graph.has_node(process_name):
804
+ return config_hash, function_hash
805
+ cfg = self.process_graph.nodes[process_name].get("config")
806
+ if getattr(cfg, "process_type", "process") != "chart":
807
+ return config_hash, function_hash
808
+ rcfg = self._get_reporting_cfg()
809
+ entrypoint = rcfg.get("static_entrypoint") or rcfg.get("entrypoint") or ""
810
+ ch_override = self._compute_chart_config_hash(process_name)
811
+ fh_override = self._compute_chart_function_hash(entrypoint) if entrypoint else None
812
+ return ch_override or config_hash, fh_override or function_hash
813
+ except Exception:
814
+ return config_hash, function_hash
815
+
816
+ def execute_graph(
817
+ self,
818
+ config: NetworkXGraphConfig,
819
+ context: StepContext,
820
+ run_id: Optional[str] = None,
821
+ resume_from_process: Optional[str] = None,
822
+ stop_after_process: bool | str = False,
823
+ ) -> Dict[str, ExecutionResult]:
824
+ """
825
+ Execute the NetworkX-based graph using Dask's advanced scheduler.
826
+
827
+ This is the main entry point that replaces the manual scheduling approach
828
+ with Dask's task scheduler for automatic dependency resolution and parallelization.
829
+ """
830
+
831
+ process_graph = self._build_process_graph(config)
832
+ # Persist the main process graph so hashing can access per-process code_function
833
+ self.process_graph = process_graph
834
+ self._validate_process_dag(process_graph)
835
+
836
+ if self.state_manager and run_id and not stop_after_process:
837
+ self.state_manager.start_pipeline_execution(run_id, config.__dict__, self.cache_enabled)
838
+
839
+ try:
840
+ failure_mode_cfg = None
841
+ try:
842
+ failure_mode_cfg = (config.execution or {}).get("failure_mode")
843
+ except Exception:
844
+ failure_mode_cfg = None
845
+ except Exception:
846
+ pass
847
+
848
+ self.logger.info(f"Executing {len(process_graph.nodes)} processes with Dask scheduler")
849
+ self.logger.info(f"Process execution order will be determined by Dask: {list(nx.topological_sort(process_graph))}")
850
+ if self.scheduler_mode == 'distributed':
851
+ self.logger.info("Using distributed scheduler (external Dask Client)")
852
+ self._ensure_connected_to_scheduler()
853
+ # Wire distributed client into step_system so @step calls submit to workers
854
+ try:
855
+ from .step_system import set_distributed_client as _set_dc
856
+ _set_dc(self._distributed_client)
857
+ except Exception:
858
+ pass
859
+ else:
860
+ self.logger.info(f"Using threaded scheduler with {self.n_workers} workers")
861
+ # Ensure no distributed client is set in threaded mode
862
+ try:
863
+ from .step_system import set_distributed_client as _set_dc # type: ignore
864
+ _set_dc(None)
865
+ except Exception:
866
+ pass
867
+
868
+ execution_results = {}
869
+
870
+ try:
871
+ is_distributed = self.scheduler_mode == 'distributed'
872
+ if is_distributed:
873
+ self._ensure_connected_to_scheduler()
874
+ process_tasks: Dict[str, Any] = {}
875
+
876
+ topo_order = list(nx.topological_sort(process_graph))
877
+
878
+ # Limit scheduling for resume/single-process execution modes.
879
+ stop_target: Optional[str] = None
880
+ if isinstance(stop_after_process, str) and stop_after_process.strip():
881
+ stop_target = stop_after_process.strip()
882
+ elif stop_after_process and resume_from_process:
883
+ stop_target = str(resume_from_process)
884
+
885
+ targets: set[str] = set(topo_order)
886
+ if stop_target and stop_target in process_graph:
887
+ targets = {stop_target}
888
+ elif resume_from_process and resume_from_process in process_graph:
889
+ # Resume from a given process and continue downstream.
890
+ targets = set(nx.descendants(process_graph, resume_from_process))
891
+ targets.add(resume_from_process)
892
+
893
+ required_nodes: set[str] = set()
894
+ for n in targets:
895
+ required_nodes.add(n)
896
+ try:
897
+ required_nodes.update(nx.ancestors(process_graph, n))
898
+ except Exception:
899
+ pass
900
+
901
+ nodes_order = [n for n in topo_order if n in required_nodes]
902
+
903
+ for process_name in nodes_order:
904
+ process_config = process_graph.nodes[process_name]['config']
905
+ dependencies = list(process_graph.predecessors(process_name))
906
+ dep_tasks = [process_tasks[dep] for dep in dependencies if dep in process_tasks]
907
+
908
+ cached_task_created = False
909
+
910
+ if self.cache_enabled and self.state_manager:
911
+ # Unified enhanced hashing for cache lookup
912
+ try:
913
+ process_input_hash, process_config_hash, composite_fhash = self._compute_process_lookup_hashes(context, process_name)
914
+ process_config_hash, composite_fhash = self._maybe_apply_chart_hash_overrides(
915
+ process_name, process_config_hash, composite_fhash
916
+ )
917
+ except Exception:
918
+ process_input_hash = process_config_hash = composite_fhash = None
919
+
920
+ if composite_fhash is None:
921
+ try:
922
+ from .step_system import get_process_registry
923
+ pr = get_process_registry()
924
+ pdef = pr.get_process(self._get_lookup_name(process_name) or process_name) if pr else None
925
+ orig_fn = getattr(pdef, 'original_func', None) if pdef else None
926
+ composite_fhash = self.state_manager._compute_function_hash(orig_fn or getattr(pdef, 'runner', None)) if pdef else None
927
+ except Exception:
928
+ composite_fhash = None
929
+
930
+ try:
931
+ self.logger.info(
932
+ f"[CACHE] Lookup process={process_name} ih={process_input_hash} ch={process_config_hash} fh={composite_fhash}"
933
+ )
934
+ cached_data = self.state_manager.get_cached_process_result_with_metadata(
935
+ process_name,
936
+ input_hash=process_input_hash,
937
+ config_hash=process_config_hash,
938
+ function_hash=composite_fhash,
939
+ )
940
+ except Exception:
941
+ cached_data = None
942
+
943
+ if not is_distributed and cached_data is None:
944
+ try:
945
+ cached_data = self.state_manager.get_cached_process_result_with_metadata(process_name)
946
+ except Exception:
947
+ cached_data = None
948
+
949
+ # Extract result and metadata from cached data
950
+ cached_proc = None
951
+ cached_run_id = None
952
+ cached_metadata = {}
953
+ if cached_data is not None:
954
+ try:
955
+ cached_proc, cached_run_id, cached_metadata = cached_data
956
+ except Exception:
957
+ # Fallback if metadata extraction fails
958
+ cached_proc = cached_data if not isinstance(cached_data, tuple) else None
959
+
960
+ if cached_proc is not None:
961
+ self.logger.info(f"[CACHE] Hit for process={process_name}; scheduling placeholder")
962
+ if dep_tasks:
963
+ task = delayed(_return_placeholder_cached_process_execution_result_with_deps)(process_name, dep_tasks)
964
+ else:
965
+ task = delayed(_return_placeholder_cached_process_execution_result)(process_name)
966
+ process_tasks[process_name] = task
967
+ try:
968
+ context.step_results[process_name] = cached_proc
969
+ except Exception:
970
+ pass
971
+ cached_task_created = True
972
+
973
+ # Pre-record a 'cached' completion so the UI reflects cache hits immediately
974
+ # rather than after the compute() phase returns.
975
+ try:
976
+ if self.state_manager:
977
+ from .step_state_manager import ProcessExecutionResult as _ProcessExec # local import to avoid import cycles
978
+ enable_logging = self._get_logging_flag(process_name)
979
+ # Extract cached metadata for frontend display
980
+ cached_exec_time = cached_metadata.get('execution_time', 0.0) if isinstance(cached_metadata, dict) else 0.0
981
+ cached_started = cached_metadata.get('started_at') if isinstance(cached_metadata, dict) else None
982
+ cached_ended = cached_metadata.get('ended_at') if isinstance(cached_metadata, dict) else None
983
+ self.state_manager.record_process_completion(
984
+ run_id or 'default',
985
+ _ProcessExec(
986
+ process_name=process_name,
987
+ success=True,
988
+ result=cached_proc,
989
+ execution_time=0.0,
990
+ timestamp=datetime.now().isoformat(),
991
+ ),
992
+ input_hash=process_input_hash,
993
+ config_hash=process_config_hash,
994
+ function_hash=composite_fhash,
995
+ was_cached=True,
996
+ enable_logging=enable_logging,
997
+ cached_run_id=cached_run_id,
998
+ cached_started_at=cached_started,
999
+ cached_ended_at=cached_ended,
1000
+ cached_execution_time=cached_exec_time,
1001
+ )
1002
+ except Exception:
1003
+ # Best-effort; if this fails we'll still write completion in the post-process phase
1004
+ pass
1005
+
1006
+ if not cached_task_created:
1007
+ self.logger.info(f"[CACHE] Miss or not loadable for process={process_name}; scheduling execution")
1008
+ ctx_payload = self._serialize_context_for_worker(context)
1009
+ # Provide dependency map for recursive signature hashing on workers
1010
+ try:
1011
+ dependency_map = {n: list(process_graph.predecessors(n)) for n in nodes_order}
1012
+ except Exception:
1013
+ dependency_map = {n: [] for n in nodes_order}
1014
+ proc_payload = {
1015
+ 'name': process_config.name,
1016
+ 'code_function': getattr(process_config, 'code_function', None),
1017
+ 'process_type': getattr(process_config, 'process_type', 'process'),
1018
+ 'cache_config': self._get_cache_config_for_worker() if self.cache_enabled and self.state_manager else None,
1019
+ 'has_state_manager': bool(self.state_manager and self.cache_enabled),
1020
+ 'logging': self._get_logging_flag(process_name),
1021
+ 'dependencies': dependencies,
1022
+ 'dependency_map': dependency_map,
1023
+ }
1024
+
1025
+ # Attach chart metadata and precomputed hashes for chart nodes
1026
+ try:
1027
+ if getattr(process_config, 'process_type', 'process') == 'chart':
1028
+ rcfg = self._get_reporting_cfg()
1029
+ entrypoint = rcfg.get('static_entrypoint') or rcfg.get('entrypoint') or ''
1030
+ function_hash_override = self._compute_chart_function_hash(entrypoint) if entrypoint else None
1031
+ config_hash_override = self._compute_chart_config_hash(process_config.name)
1032
+ spec = self._get_chart_spec(process_config.name)
1033
+ reporting_py = rcfg.get('reporting_python') or os.environ.get('MLOPS_REPORTING_PYTHON') or None
1034
+ chart_spec = {
1035
+ 'name': process_config.name,
1036
+ 'probe_paths': (spec.get('probe_paths') or {}),
1037
+ 'args': list(rcfg.get('args') or []) + list(spec.get('args') or []),
1038
+ 'theme': os.environ.get('MLOPS_CHART_THEME'),
1039
+ 'entrypoint': entrypoint,
1040
+ 'reporting_python': reporting_py,
1041
+ }
1042
+ proc_payload['hash_overrides'] = {
1043
+ 'function_hash': function_hash_override,
1044
+ 'config_hash': config_hash_override,
1045
+ }
1046
+ proc_payload['chart_spec'] = chart_spec
1047
+ except Exception:
1048
+ pass
1049
+ if dep_tasks:
1050
+ task = delayed(_worker_execute_process_with_deps)(proc_payload, ctx_payload, dep_tasks, run_id)
1051
+ else:
1052
+ task = delayed(_worker_execute_process_task)(proc_payload, ctx_payload, run_id)
1053
+ process_tasks[process_name] = task
1054
+
1055
+ if process_tasks:
1056
+ # Execute according to scheduler mode
1057
+ if is_distributed:
1058
+ try:
1059
+ futures = self._distributed_client.compute(list(process_tasks.values()))
1060
+ results_values = self._distributed_client.gather(futures)
1061
+ except Exception:
1062
+ results_values = compute(*process_tasks.values())
1063
+ else:
1064
+ results_values = compute(*process_tasks.values(), scheduler='threads', num_workers=self.n_workers)
1065
+
1066
+ proc_results = dict(zip(process_tasks.keys(), results_values))
1067
+ execution_results.update(proc_results)
1068
+
1069
+ # Surface worker-side logs in both modes
1070
+ try:
1071
+ for _pname, _pres in proc_results.items():
1072
+ try:
1073
+ _r = getattr(_pres, 'result', None)
1074
+ if isinstance(_r, dict) and '__logs__' in _r and _r['__logs__']:
1075
+ _logs = _r['__logs__']
1076
+ try:
1077
+ self.logger.info(f"[WorkerLogs][{_pname}] BEGIN")
1078
+ for _line in str(_logs).splitlines():
1079
+ self.logger.info(f"[{_pname}] {_line}")
1080
+ self.logger.info(f"[WorkerLogs][{_pname}] END")
1081
+ except Exception:
1082
+ pass
1083
+ except Exception:
1084
+ continue
1085
+ except Exception:
1086
+ pass
1087
+
1088
+ # Post-process results: errors, context hydration, cache recording
1089
+ for process_name, result in proc_results.items():
1090
+ # Rehydrate placeholders
1091
+ try:
1092
+ if getattr(result, 'was_cached', False) and not getattr(result, 'result', None):
1093
+ if self.cache_enabled and self.state_manager:
1094
+ try:
1095
+ ih, ch, fh = self._compute_process_lookup_hashes(context, process_name)
1096
+ ch, fh = self._maybe_apply_chart_hash_overrides(process_name, ch, fh)
1097
+ loaded = self.state_manager.get_cached_process_result(process_name, input_hash=ih, config_hash=ch, function_hash=fh)
1098
+ except Exception:
1099
+ loaded = None
1100
+ if loaded is not None:
1101
+ result.result = loaded
1102
+ except Exception:
1103
+ pass
1104
+
1105
+ if result.error is not None:
1106
+ # NOTE: Historically we supported a "stop_and_resume" mode. In practice, users
1107
+ # resume work by re-running and leveraging cache hits (process + step cache).
1108
+ # Treat stop_and_resume as deprecated and equivalent to a simple stop-on-failure.
1109
+ try:
1110
+ failure_mode = str((config.execution or {}).get("failure_mode", "stop") or "stop").strip().lower()
1111
+ except Exception:
1112
+ failure_mode = "stop"
1113
+ if failure_mode == "stop_and_resume":
1114
+ try:
1115
+ self.logger.warning(
1116
+ "Deprecated failure_mode 'stop_and_resume' encountered; treating as 'stop'. "
1117
+ "Re-run to reuse cached results."
1118
+ )
1119
+ except Exception:
1120
+ pass
1121
+ if self.state_manager and run_id and not stop_after_process:
1122
+ self.state_manager.complete_pipeline_execution(run_id, False)
1123
+ raise RuntimeError(f"Process {process_name} failed: {result.error}")
1124
+
1125
+ # Update driver context with clean result
1126
+ if result.error is None and getattr(result, 'result', None) is not None:
1127
+ try:
1128
+ clean_result = {k: v for k, v in result.result.items() if not k.startswith('__')} if isinstance(result.result, dict) else result.result
1129
+ context.step_results[process_name] = clean_result
1130
+ except Exception:
1131
+ pass
1132
+
1133
+ if self.cache_enabled and self.state_manager:
1134
+ try:
1135
+ from .step_state_manager import ProcessExecutionResult as _ProcessExec
1136
+ ih, ch, fh = None, None, None
1137
+ try:
1138
+ ih, ch, fh = self._compute_process_lookup_hashes(context, process_name)
1139
+ ch, fh = self._maybe_apply_chart_hash_overrides(process_name, ch, fh)
1140
+ self.logger.debug(f"[CACHE WRITE] process={process_name} ih={ih} ch={ch} fh={fh}")
1141
+ except Exception as e:
1142
+ self.logger.warning(f"[CACHE WRITE] Failed to compute hashes for {process_name}: {e}")
1143
+ enable_logging = self._get_logging_flag(process_name)
1144
+
1145
+ # Determine success based on error field
1146
+ is_success = result.error is None
1147
+
1148
+ self.state_manager.record_process_completion(
1149
+ run_id or 'default',
1150
+ _ProcessExec(
1151
+ process_name=process_name,
1152
+ success=is_success,
1153
+ result=result.result if is_success else None,
1154
+ execution_time=result.execution_time,
1155
+ timestamp=datetime.now().isoformat(),
1156
+ ),
1157
+ input_hash=ih,
1158
+ config_hash=ch,
1159
+ function_hash=fh,
1160
+ was_cached=bool(getattr(result, 'was_cached', False)),
1161
+ enable_logging=enable_logging,
1162
+ )
1163
+ except Exception as e:
1164
+ self.logger.warning(f"❌ Failed to record process completion for {process_name}: {e}")
1165
+
1166
+ # Step-level cache recording - only if process succeeded
1167
+ if is_success:
1168
+ try:
1169
+ sr_map = result.result.get('__step_results__', {}) if isinstance(result.result, dict) else {}
1170
+ if isinstance(sr_map, dict) and sr_map:
1171
+ from .step_state_manager import StepExecutionResult as _StepExec
1172
+ for _sname, _sres in sr_map.items():
1173
+ try:
1174
+ step_def = self.step_registry.get_step(_sname)
1175
+ if not step_def:
1176
+ continue
1177
+ tmp_ctx = None
1178
+ try:
1179
+ from .step_system import StepContext as _Ctx
1180
+ tmp_ctx = _Ctx(
1181
+ project_id=getattr(context, 'project_id', None),
1182
+ run_id=run_id or getattr(context, 'run_id', None),
1183
+ tracker=None,
1184
+ step_results=sr_map,
1185
+ global_config=getattr(context, 'global_config', {}) or {},
1186
+ data_paths=getattr(context, 'data_paths', {}) or {},
1187
+ checkpoint_dir=getattr(context, 'checkpoint_dir', None),
1188
+ )
1189
+ try:
1190
+ tmp_ctx.current_process = process_name # type: ignore[attr-defined]
1191
+ except Exception:
1192
+ pass
1193
+ except Exception:
1194
+ tmp_ctx = context
1195
+ input_hash = self._compute_step_input_hash(step_def, tmp_ctx)
1196
+ try:
1197
+ function_hash = self.state_manager._compute_function_hash(step_def.original_func)
1198
+ except Exception:
1199
+ function_hash = None
1200
+ try:
1201
+ config_source = getattr(context, 'global_config', None) or {}
1202
+ config_hash = self.state_manager._compute_hash(config_source)
1203
+ except Exception:
1204
+ config_hash = None
1205
+ enable_logging = getattr(step_def, 'logging', True) if step_def else True
1206
+ # Extract execution time from step result metadata
1207
+ step_exec_time = 0.0
1208
+ if isinstance(_sres, dict):
1209
+ step_exec_time = float(_sres.get('__execution_time__', 0.0))
1210
+ self.state_manager.record_step_completion(
1211
+ run_id or 'default',
1212
+ _StepExec(
1213
+ step_name=_sname,
1214
+ success=True,
1215
+ result=_sres,
1216
+ execution_time=step_exec_time,
1217
+ timestamp=datetime.now().isoformat(),
1218
+ ),
1219
+ input_hash=input_hash,
1220
+ config_hash=config_hash,
1221
+ function_name=_sname,
1222
+ function_hash=function_hash,
1223
+ was_cached=bool(_sres.get('__was_cached__')) if isinstance(_sres, dict) else False,
1224
+ process_name=process_name,
1225
+ enable_logging=enable_logging,
1226
+ )
1227
+ except Exception:
1228
+ continue
1229
+ except Exception:
1230
+ pass
1231
+
1232
+ # Log per-step cached hits summary
1233
+ try:
1234
+ sr_map2 = result.result.get('__step_results__', {}) if isinstance(result.result, dict) else {}
1235
+ if isinstance(sr_map2, dict) and sr_map2:
1236
+ total_steps = len(sr_map2)
1237
+ hits = 0
1238
+ for __v in sr_map2.values():
1239
+ try:
1240
+ if isinstance(__v, dict) and __v.get('__was_cached__'):
1241
+ hits += 1
1242
+ except Exception:
1243
+ continue
1244
+ mode_label = 'Dask Distributed' if is_distributed else 'Dask Threads'
1245
+ self.logger.info(f"Process {process_name} cached steps: {hits}/{total_steps} [{mode_label}]")
1246
+ except Exception:
1247
+ pass
1248
+
1249
+ if self.state_manager and run_id and not stop_after_process:
1250
+ self.state_manager.complete_pipeline_execution(run_id, True)
1251
+ stats = self.state_manager.get_pipeline_stats(run_id)
1252
+ try:
1253
+ if stats and 'cache_hit_rate' in stats:
1254
+ self.logger.info(f"Pipeline completed with Dask scheduler. Cache hit rate: {stats['cache_hit_rate']:.1%} "
1255
+ f"({stats.get('cache_hit_count', 0)}/{stats.get('completed_steps', 0)} steps)")
1256
+ else:
1257
+ self.logger.info("Pipeline completed with Dask scheduler.")
1258
+ except Exception:
1259
+ self.logger.info("Pipeline completed with Dask scheduler.")
1260
+ except Exception:
1261
+ if self.state_manager and run_id and not stop_after_process:
1262
+ self.state_manager.complete_pipeline_execution(run_id, False)
1263
+ raise
1264
+
1265
+ return execution_results
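
For orientation, below is a minimal illustrative sketch (not part of the packaged code) exercising the two static helpers whose behavior is fully visible in this diff: `_flatten_dask_overrides`, which flattens nested Dask config overrides into dotted keys, and `_normalize_bootstrap_mode`, which maps the `MLOPS_DASK_BOOTSTRAP` value onto "auto"/"always"/"never". It assumes the wheel is installed so that `mlops.core.dask_networkx_executor` is importable.

# Illustrative only; assumes the expops wheel is installed so that the
# `mlops` package (see top_level.txt above) can be imported.
from mlops.core.dask_networkx_executor import DaskNetworkXExecutor

# Nested overrides are flattened into dotted Dask config keys before being
# pushed to the scheduler and workers via _worker_apply_dask_config.
overrides = {"distributed": {"comm": {"compression": "lz4"}}}
flat = DaskNetworkXExecutor._flatten_dask_overrides(overrides)
print(flat)  # {'distributed.comm.compression': 'lz4'}

# Bootstrap mode normalization: truthy strings mean "always", falsy strings
# mean "never", anything else (including None) falls back to "auto".
print(DaskNetworkXExecutor._normalize_bootstrap_mode("yes"))  # always
print(DaskNetworkXExecutor._normalize_bootstrap_mode("off"))  # never
print(DaskNetworkXExecutor._normalize_bootstrap_mode(None))   # auto

Note that `_prepare_dask_overrides` only falls back to the `DASK_DISTRIBUTED__COMM__COMPRESSION` environment variable (default "zlib") when `distributed.comm.compression` is absent from these flattened overrides, so explicit overrides always win.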