expops 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- expops-0.1.3.dist-info/METADATA +826 -0
- expops-0.1.3.dist-info/RECORD +86 -0
- expops-0.1.3.dist-info/WHEEL +5 -0
- expops-0.1.3.dist-info/entry_points.txt +3 -0
- expops-0.1.3.dist-info/licenses/LICENSE +674 -0
- expops-0.1.3.dist-info/top_level.txt +1 -0
- mlops/__init__.py +0 -0
- mlops/__main__.py +11 -0
- mlops/_version.py +34 -0
- mlops/adapters/__init__.py +12 -0
- mlops/adapters/base.py +86 -0
- mlops/adapters/config_schema.py +89 -0
- mlops/adapters/custom/__init__.py +3 -0
- mlops/adapters/custom/custom_adapter.py +447 -0
- mlops/adapters/plugin_manager.py +113 -0
- mlops/adapters/sklearn/__init__.py +3 -0
- mlops/adapters/sklearn/adapter.py +94 -0
- mlops/cluster/__init__.py +3 -0
- mlops/cluster/controller.py +496 -0
- mlops/cluster/process_runner.py +91 -0
- mlops/cluster/providers.py +258 -0
- mlops/core/__init__.py +95 -0
- mlops/core/custom_model_base.py +38 -0
- mlops/core/dask_networkx_executor.py +1265 -0
- mlops/core/executor_worker.py +1239 -0
- mlops/core/experiment_tracker.py +81 -0
- mlops/core/graph_types.py +64 -0
- mlops/core/networkx_parser.py +135 -0
- mlops/core/payload_spill.py +278 -0
- mlops/core/pipeline_utils.py +162 -0
- mlops/core/process_hashing.py +216 -0
- mlops/core/step_state_manager.py +1298 -0
- mlops/core/step_system.py +956 -0
- mlops/core/workspace.py +99 -0
- mlops/environment/__init__.py +10 -0
- mlops/environment/base.py +43 -0
- mlops/environment/conda_manager.py +307 -0
- mlops/environment/factory.py +70 -0
- mlops/environment/pyenv_manager.py +146 -0
- mlops/environment/setup_env.py +31 -0
- mlops/environment/system_manager.py +66 -0
- mlops/environment/utils.py +105 -0
- mlops/environment/venv_manager.py +134 -0
- mlops/main.py +527 -0
- mlops/managers/project_manager.py +400 -0
- mlops/managers/reproducibility_manager.py +575 -0
- mlops/platform.py +996 -0
- mlops/reporting/__init__.py +16 -0
- mlops/reporting/context.py +187 -0
- mlops/reporting/entrypoint.py +292 -0
- mlops/reporting/kv_utils.py +77 -0
- mlops/reporting/registry.py +50 -0
- mlops/runtime/__init__.py +9 -0
- mlops/runtime/context.py +34 -0
- mlops/runtime/env_export.py +113 -0
- mlops/storage/__init__.py +12 -0
- mlops/storage/adapters/__init__.py +9 -0
- mlops/storage/adapters/gcp_kv_store.py +778 -0
- mlops/storage/adapters/gcs_object_store.py +96 -0
- mlops/storage/adapters/memory_store.py +240 -0
- mlops/storage/adapters/redis_store.py +438 -0
- mlops/storage/factory.py +199 -0
- mlops/storage/interfaces/__init__.py +6 -0
- mlops/storage/interfaces/kv_store.py +118 -0
- mlops/storage/path_utils.py +38 -0
- mlops/templates/premier-league/charts/plot_metrics.js +70 -0
- mlops/templates/premier-league/charts/plot_metrics.py +145 -0
- mlops/templates/premier-league/charts/requirements.txt +6 -0
- mlops/templates/premier-league/configs/cluster_config.yaml +13 -0
- mlops/templates/premier-league/configs/project_config.yaml +207 -0
- mlops/templates/premier-league/data/England CSV.csv +12154 -0
- mlops/templates/premier-league/models/premier_league_model.py +638 -0
- mlops/templates/premier-league/requirements.txt +8 -0
- mlops/templates/sklearn-basic/README.md +22 -0
- mlops/templates/sklearn-basic/charts/plot_metrics.py +85 -0
- mlops/templates/sklearn-basic/charts/requirements.txt +3 -0
- mlops/templates/sklearn-basic/configs/project_config.yaml +64 -0
- mlops/templates/sklearn-basic/data/train.csv +14 -0
- mlops/templates/sklearn-basic/models/model.py +62 -0
- mlops/templates/sklearn-basic/requirements.txt +10 -0
- mlops/web/__init__.py +3 -0
- mlops/web/server.py +585 -0
- mlops/web/ui/index.html +52 -0
- mlops/web/ui/mlops-charts.js +357 -0
- mlops/web/ui/script.js +1244 -0
- mlops/web/ui/styles.css +248 -0
|
@@ -0,0 +1,1265 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
import hashlib
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Dict, List, Optional
|
|
10
|
+
|
|
11
|
+
import networkx as nx
|
|
12
|
+
|
|
13
|
+
from dask import compute
|
|
14
|
+
from dask.delayed import delayed
|
|
15
|
+
|
|
16
|
+
from .graph_types import ExecutionResult, NetworkXGraphConfig, NodeType
|
|
17
|
+
from .step_state_manager import StepStateManager
|
|
18
|
+
from .step_system import StepContext, StepDefinition, StepRegistry
|
|
19
|
+
|
|
20
|
+
from .executor_worker import (
|
|
21
|
+
_return_placeholder_cached_process_execution_result,
|
|
22
|
+
_return_placeholder_cached_process_execution_result_with_deps,
|
|
23
|
+
_worker_execute_process_task,
|
|
24
|
+
_worker_execute_process_with_deps,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DaskNetworkXExecutor:
|
|
29
|
+
"""
|
|
30
|
+
Execute NetworkX DAGs with Dask (threads or distributed) and integrated caching.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
@staticmethod
|
|
34
|
+
def _flatten_dask_overrides(overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
|
35
|
+
flat: Dict[str, Any] = {}
|
|
36
|
+
if not isinstance(overrides, dict):
|
|
37
|
+
return flat
|
|
38
|
+
for key, value in overrides.items():
|
|
39
|
+
if isinstance(value, dict):
|
|
40
|
+
nested = DaskNetworkXExecutor._flatten_dask_overrides(value)
|
|
41
|
+
for sub_key, sub_val in nested.items():
|
|
42
|
+
flat_key = f"{key}.{sub_key}" if key else sub_key
|
|
43
|
+
flat[flat_key] = sub_val
|
|
44
|
+
else:
|
|
45
|
+
flat[str(key)] = value
|
|
46
|
+
return flat
|
|
47
|
+
|
|
48
|
+
    def __init__(self, step_registry: StepRegistry, state_manager: Optional[StepStateManager] = None,
                 logger: Optional[logging.Logger] = None,
                 n_workers: int = 2,
                 scheduler_mode: str = "threads",
                 scheduler_address: Optional[str] = None,
                 client: Any = None,
                 extra_files_to_upload: Optional[List[str]] = None,
                 strict_cache: bool = True,
                 min_workers: Optional[int] = None,
                 wait_for_workers_sec: Optional[float] = None,
                 dask_config_overrides: Optional[Dict[str, Any]] = None):
        """Create a Dask-backed NetworkX DAG executor.

        Args:
            step_registry: Registry used to resolve step definitions.
            state_manager: Optional state manager providing hashing / cache
                lookups; paths that need it degrade gracefully when None.
            logger: Logger to use; defaults to this module's logger.
            n_workers: Desired worker count; also the default minimum when
                waiting for distributed workers to join.
            scheduler_mode: "threads" (local) or "distributed" (external scheduler).
            scheduler_address: Address of an external Dask scheduler, if any.
            client: Pre-connected distributed Client to reuse instead of creating one.
            extra_files_to_upload: Extra files to ship to workers (e.g. custom model scripts).
            strict_cache: Coerced to bool and stored; consumed by cache logic elsewhere in this class.
            min_workers: Positive int override for the minimum workers to wait for.
            wait_for_workers_sec: Positive number override for the worker-wait timeout.
            dask_config_overrides: Possibly nested Dask config mapping; flattened
                into dotted-key form at construction time.
        """
        self.step_registry = step_registry
        self.state_manager = state_manager
        self.logger = logger or logging.getLogger(__name__)
        # Caching is on by default; toggled via set_cache_enabled().
        self.cache_enabled = True
        self.n_workers = n_workers
        self.scheduler_mode = scheduler_mode
        self._scheduler_address = scheduler_address
        self._distributed_client = client
        self.strict_cache = bool(strict_cache)
        try:
            self._extra_upload_files: List[str] = list(extra_files_to_upload) if extra_files_to_upload else []
        except Exception:
            # Non-iterable input: fall back to no extra uploads.
            self._extra_upload_files = []
        # Overrides are honoured only when they are sensible positive values;
        # otherwise env vars / defaults apply at connect time.
        self._min_workers_override = min_workers if isinstance(min_workers, int) and min_workers > 0 else None
        self._wait_for_workers_override = float(wait_for_workers_sec) if isinstance(wait_for_workers_sec, (int, float)) and wait_for_workers_sec > 0 else None
        # Normalize nested override dicts into dotted-key form once, up front.
        self._dask_config_overrides = self._flatten_dask_overrides(dask_config_overrides)
|
|
75
|
+
|
|
76
|
+
def _prepare_dask_overrides(self) -> Dict[str, Any]:
|
|
77
|
+
overrides = dict(self._dask_config_overrides)
|
|
78
|
+
# Only apply env defaults when not explicitly configured.
|
|
79
|
+
if "distributed.comm.compression" not in overrides:
|
|
80
|
+
overrides["distributed.comm.compression"] = (
|
|
81
|
+
os.environ.get("DASK_DISTRIBUTED__COMM__COMPRESSION")
|
|
82
|
+
or "zlib"
|
|
83
|
+
)
|
|
84
|
+
return overrides
|
|
85
|
+
|
|
86
|
+
@staticmethod
|
|
87
|
+
def _normalize_bootstrap_mode(value: Optional[str]) -> str:
|
|
88
|
+
"""Normalize the bootstrap mode string.
|
|
89
|
+
|
|
90
|
+
Supported values:
|
|
91
|
+
- "auto" (default): only bootstrap when workers/scheduler can't import `mlops`
|
|
92
|
+
- "always": always bootstrap (upload zip, sys.path, cleanup)
|
|
93
|
+
- "never": never bootstrap `mlops` package code (still may upload extra files)
|
|
94
|
+
"""
|
|
95
|
+
try:
|
|
96
|
+
v = (value or "").strip().lower()
|
|
97
|
+
except Exception:
|
|
98
|
+
v = ""
|
|
99
|
+
if v in ("1", "true", "yes", "y", "always", "on"):
|
|
100
|
+
return "always"
|
|
101
|
+
if v in ("0", "false", "no", "n", "never", "off"):
|
|
102
|
+
return "never"
|
|
103
|
+
return "auto"
|
|
104
|
+
|
|
105
|
+
    @staticmethod
    def _worker_apply_dask_config(overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Apply Dask config overrides in the current process (worker or scheduler).

        Executed remotely via ``client.run`` / ``client.run_on_scheduler``, so
        it imports lazily and never raises.

        Returns a mapping of each override key to the value Dask reports after
        setting it (falling back to the requested value when set/read-back
        fails), or ``{"error": "config_import_failed"}`` when ``dask.config``
        cannot be imported in this process.
        """
        applied: Dict[str, Any] = {}
        try:
            from dask import config as _dask_config
        except Exception:
            return {"error": "config_import_failed"}
        if isinstance(overrides, dict):
            for key, value in overrides.items():
                try:
                    _dask_config.set({key: value})
                    # Read back so the caller sees what Dask actually stored.
                    applied[key] = _dask_config.get(key)
                except Exception:
                    # Best-effort: report the requested value even if set/get failed.
                    applied[key] = value
        return applied
|
|
121
|
+
|
|
122
|
+
@staticmethod
|
|
123
|
+
def _worker_try_import(module_name: str) -> bool:
|
|
124
|
+
try:
|
|
125
|
+
__import__(str(module_name))
|
|
126
|
+
return True
|
|
127
|
+
except Exception:
|
|
128
|
+
return False
|
|
129
|
+
|
|
130
|
+
@staticmethod
|
|
131
|
+
def _worker_set_env_vars(env_vars: Dict[str, Any]) -> bool:
|
|
132
|
+
import os as _os
|
|
133
|
+
try:
|
|
134
|
+
for key, value in (env_vars or {}).items():
|
|
135
|
+
if value is not None:
|
|
136
|
+
_os.environ[str(key)] = str(value)
|
|
137
|
+
return True
|
|
138
|
+
except Exception:
|
|
139
|
+
return False
|
|
140
|
+
|
|
141
|
+
@staticmethod
|
|
142
|
+
def _worker_ensure_sys_path(paths: List[str]) -> Dict[str, bool]:
|
|
143
|
+
import sys as _sys
|
|
144
|
+
results: Dict[str, bool] = {}
|
|
145
|
+
for p in paths or []:
|
|
146
|
+
try:
|
|
147
|
+
p_str = str(p)
|
|
148
|
+
except Exception:
|
|
149
|
+
continue
|
|
150
|
+
if p_str and p_str not in _sys.path:
|
|
151
|
+
_sys.path.insert(0, p_str)
|
|
152
|
+
results[p_str] = p_str in _sys.path
|
|
153
|
+
return results
|
|
154
|
+
|
|
155
|
+
    @staticmethod
    def _worker_cleanup_import_state(patterns: List[str], files_to_clean: List[str], clean_mlops: bool = True) -> Dict[str, Any]:
        """Best-effort cleanup to avoid stale imports between runs on long-lived worker processes.

        Args:
            patterns: Regex patterns; sys.path entries matching any of them
                (via re.search) are removed.
            files_to_clean: File names whose .py basenames are evicted from
                sys.modules and deleted from the home dir / cwd if present.
            clean_mlops: When True, also evict `mlops` and all `mlops.*`
                modules from sys.modules to force a clean re-import.

        Returns a dict with "removed_paths", "removed_modules" and
        "removed_files" lists describing what was cleaned.  Never raises.
        """
        import sys as _sys
        import re as _re
        import importlib as _importlib
        import os as _os

        removed_paths: List[str] = []
        removed_files: List[str] = []
        removed_modules: List[str] = []

        # Remove sys.path entries matching known uploaded artifacts (zip names, temp dirs, etc.)
        new_path: List[str] = []
        for p in list(_sys.path):
            try:
                p_str = str(p or "")
            except Exception:
                p_str = ""
            try:
                matched = any(_re.search(ptn, p_str) for ptn in (patterns or []))
            except Exception:
                # A bad pattern keeps the entry rather than dropping it.
                matched = False
            if matched:
                removed_paths.append(p_str)
                continue
            # Deduplicate while preserving the original ordering.
            if p not in new_path:
                new_path.append(p)
        # In-place slice assignment so other references to sys.path see the change.
        _sys.path[:] = new_path

        # Optionally evict previously loaded `mlops` modules to force a clean import.
        if clean_mlops:
            for name in list(_sys.modules.keys()):
                if name == "mlops" or name.startswith("mlops."):
                    removed_modules.append(name)
                    _sys.modules.pop(name, None)

        # Aggressively remove uploaded custom model modules from sys.modules and common upload locations.
        for fname in (files_to_clean or []):
            try:
                base_fname = _os.path.basename(str(fname))
                if not base_fname.endswith(".py"):
                    continue
                mod_name = base_fname[:-3]
                if mod_name in _sys.modules:
                    removed_modules.append(mod_name)
                    _sys.modules.pop(mod_name, None)
                # Also clear our canonical fallback module name.
                if "custom_model" in _sys.modules:
                    removed_modules.append("custom_model")
                    _sys.modules.pop("custom_model", None)

                # Try to find and remove file from common upload locations (best-effort).
                search_paths = [
                    _os.path.expanduser("~"),
                    _os.getcwd(),
                ]
                for search_dir in search_paths:
                    try:
                        worker_file_path = _os.path.join(search_dir, base_fname)
                        if _os.path.exists(worker_file_path):
                            _os.remove(worker_file_path)
                            removed_files.append(worker_file_path)
                    except Exception:
                        # Deletion is opportunistic; permission errors are ignored.
                        pass
            except Exception:
                pass

        try:
            # Drop cached import finder state so re-uploaded files are seen.
            _importlib.invalidate_caches()
        except Exception:
            pass
        return {
            "removed_paths": removed_paths,
            "removed_modules": removed_modules,
            "removed_files": removed_files,
        }
|
|
232
|
+
|
|
233
|
+
    def _ensure_connected_to_scheduler(self, force: bool = False) -> None:
        """Connect to an external Dask scheduler if requested; otherwise use threads.

        Note: For many real clusters, workers already have `mlops` installed or share a filesystem.
        In that case, this method will skip the heavy "upload source zip / sys.path surgery" work
        by default (bootstrap mode = auto).

        Args:
            force: When True, re-run the post-connection setup even if a
                client already exists (used by restart_distributed_client).

        Side effects: may downgrade ``self.scheduler_mode`` to 'threads' when
        no scheduler address can be resolved or connection/setup fails.
        """
        if self.scheduler_mode != 'distributed':
            return
        if self._distributed_client is not None and not force:
            # Already connected and no forced refresh requested.
            return
        # Address resolution order: constructor arg, then environment.
        scheduler_addr = self._scheduler_address or os.environ.get('DASK_SCHEDULER_ADDRESS')
        # If a client already exists (e.g., user passed one in, or restart created one),
        # prefer its known scheduler address as a fallback.
        if not scheduler_addr and self._distributed_client is not None:
            try:
                scheduler_addr = getattr(getattr(self._distributed_client, "scheduler", None), "address", None)
            except Exception:
                scheduler_addr = None
        if not scheduler_addr:
            self.logger.warning(
                "Distributed scheduler requested but no Client or address provided. Falling back to threads."
            )
            self.scheduler_mode = 'threads'
            return
        try:
            # Ensure we have a client connection.
            if self._distributed_client is None:
                try:
                    from distributed import Client
                except Exception:
                    # Older installs only expose Client via dask.distributed.
                    from dask.distributed import Client
                self._distributed_client = Client(scheduler_addr)
                self.logger.info(f"Connected to Dask scheduler at {scheduler_addr}")

            self._configure_distributed_runtime(scheduler_addr)
        except Exception as e:
            self.logger.warning(
                f"Failed to connect/configure scheduler at {scheduler_addr} ({e}). Falling back to threads."
            )
            self.scheduler_mode = 'threads'
|
|
274
|
+
|
|
275
|
+
    def _configure_distributed_runtime(self, scheduler_addr: str) -> None:
        """Best-effort cluster setup after a Client is connected.

        Pushes Dask config overrides to the scheduler, waits for a minimum
        number of workers, pushes the config to those workers, then runs the
        code/env bootstrap.

        Raises:
            RuntimeError: when no client is set, or fewer than the required
                workers joined within the wait timeout.
        """
        if not self._distributed_client:
            raise RuntimeError("Distributed client is not initialized")

        client = self._distributed_client
        config_overrides = self._prepare_dask_overrides()

        # Push config to scheduler first (so it is applied even before workers join).
        try:
            sched_conf = client.run_on_scheduler(self._worker_apply_dask_config, config_overrides)
            self.logger.info(f"Scheduler Dask config applied: {sched_conf}")
        except Exception as e:
            self.logger.warning(f"Failed to push Dask config to scheduler: {e}")

        # Resolve worker wait settings.
        # Precedence: constructor override > env var > n_workers (clamped to >= 1).
        min_workers_env = os.environ.get('MLOPS_DASK_MIN_WORKERS', '').strip()
        timeout_env = os.environ.get('MLOPS_DASK_WAIT_FOR_WORKERS_SEC', '').strip()
        if self._min_workers_override:
            min_workers = self._min_workers_override
        elif min_workers_env.isdigit():
            min_workers = int(min_workers_env)
        else:
            min_workers = self.n_workers if isinstance(self.n_workers, int) and self.n_workers > 0 else 1
        if min_workers < 1:
            min_workers = 1
        # Timeout precedence mirrors min_workers; default is 30s.
        if self._wait_for_workers_override:
            wait_timeout = self._wait_for_workers_override
        elif timeout_env:
            try:
                wait_timeout = float(timeout_env)
            except Exception:
                wait_timeout = 30.0
        else:
            wait_timeout = 30.0

        # Wait for workers to connect (best-effort; verify count afterwards).
        try:
            self.logger.info(f"Waiting for at least {min_workers} worker(s) to connect (timeout={wait_timeout}s)")
            client.wait_for_workers(min_workers, timeout=wait_timeout)  # type: ignore[attr-defined]
        except Exception:
            # Timeout or missing API — the explicit count check below decides.
            pass

        try:
            info = client.scheduler_info()
            workers_dict = info.get('workers', {}) if isinstance(info, dict) else {}
            worker_count = len(workers_dict)
        except Exception:
            worker_count = 0

        if worker_count < min_workers:
            raise RuntimeError(
                f"Connected to Dask scheduler at {scheduler_addr} but only {worker_count} worker(s) "
                f"available after waiting {wait_timeout}s. Ensure workers are started and connected."
            )
        self.logger.info(f"Workers connected: {worker_count}")

        # After workers connect, push config to workers as well.
        try:
            workers_conf = client.run(self._worker_apply_dask_config, config_overrides)
            self.logger.info(f"Workers Dask config applied: {workers_conf}")
        except Exception as e:
            self.logger.warning(f"Failed to push Dask config to workers: {e}")

        # Bootstrap code/env on long-lived worker processes (auto-skip when not needed).
        self._bootstrap_distributed_imports_and_env()
|
|
341
|
+
|
|
342
|
+
def _bootstrap_distributed_imports_and_env(self) -> None:
|
|
343
|
+
if not self._distributed_client:
|
|
344
|
+
return
|
|
345
|
+
|
|
346
|
+
client = self._distributed_client
|
|
347
|
+
bootstrap_mode = self._normalize_bootstrap_mode(os.environ.get("MLOPS_DASK_BOOTSTRAP", "auto"))
|
|
348
|
+
|
|
349
|
+
# Decide whether we need to ship the `mlops` package code to workers.
|
|
350
|
+
needs_mlops_bootstrap = bootstrap_mode == "always"
|
|
351
|
+
if bootstrap_mode == "auto":
|
|
352
|
+
try:
|
|
353
|
+
import_ok_workers = client.run(self._worker_try_import, "mlops")
|
|
354
|
+
needs_mlops_bootstrap = not all(bool(v) for v in import_ok_workers.values())
|
|
355
|
+
except Exception:
|
|
356
|
+
needs_mlops_bootstrap = True
|
|
357
|
+
try:
|
|
358
|
+
import_ok_sched = client.run_on_scheduler(self._worker_try_import, "mlops")
|
|
359
|
+
needs_mlops_bootstrap = needs_mlops_bootstrap or not bool(import_ok_sched)
|
|
360
|
+
except Exception:
|
|
361
|
+
needs_mlops_bootstrap = True
|
|
362
|
+
elif bootstrap_mode == "never":
|
|
363
|
+
needs_mlops_bootstrap = False
|
|
364
|
+
|
|
365
|
+
# Cleanup import state only when we are about to upload something (mlops zip and/or extra files).
|
|
366
|
+
if needs_mlops_bootstrap or self._extra_upload_files:
|
|
367
|
+
try:
|
|
368
|
+
cleanup_patterns = [r"mlops_src.*\\.zip", r"mlops_src_"]
|
|
369
|
+
files_to_clean = self._extra_upload_files if self._extra_upload_files else []
|
|
370
|
+
try:
|
|
371
|
+
cleaned_workers = client.run(
|
|
372
|
+
self._worker_cleanup_import_state,
|
|
373
|
+
cleanup_patterns,
|
|
374
|
+
files_to_clean,
|
|
375
|
+
needs_mlops_bootstrap,
|
|
376
|
+
)
|
|
377
|
+
self.logger.info(f"Cleaned worker import state: {cleaned_workers}")
|
|
378
|
+
except Exception as e:
|
|
379
|
+
self.logger.warning(f"Failed to clean workers import state: {e}")
|
|
380
|
+
try:
|
|
381
|
+
cleaned_sched = client.run_on_scheduler(
|
|
382
|
+
self._worker_cleanup_import_state,
|
|
383
|
+
cleanup_patterns,
|
|
384
|
+
files_to_clean,
|
|
385
|
+
needs_mlops_bootstrap,
|
|
386
|
+
)
|
|
387
|
+
self.logger.info(f"Cleaned scheduler import state: {cleaned_sched}")
|
|
388
|
+
except Exception as e:
|
|
389
|
+
self.logger.warning(f"Failed to clean scheduler import state: {e}")
|
|
390
|
+
except Exception as e:
|
|
391
|
+
self.logger.warning(f"Cleanup of previous uploaded code failed: {e}")
|
|
392
|
+
|
|
393
|
+
# Ship mlops code only when required.
|
|
394
|
+
if needs_mlops_bootstrap:
|
|
395
|
+
try:
|
|
396
|
+
from pathlib import Path as _Path
|
|
397
|
+
import tempfile
|
|
398
|
+
import zipfile
|
|
399
|
+
|
|
400
|
+
# Package just the `mlops/` python package, not an entire repo root or site-packages.
|
|
401
|
+
mlops_pkg_dir = _Path(__file__).resolve().parents[1] # .../mlops
|
|
402
|
+
pkg_parent = mlops_pkg_dir.parent # .../src (source) or .../site-packages (installed)
|
|
403
|
+
|
|
404
|
+
zip_path = _Path(tempfile.gettempdir()) / "mlops_pkg.zip"
|
|
405
|
+
with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
|
|
406
|
+
for path in mlops_pkg_dir.rglob('*.py'):
|
|
407
|
+
try:
|
|
408
|
+
# Keep `mlops/...` as the top-level folder inside the zip.
|
|
409
|
+
zf.write(path, arcname=str(path.relative_to(pkg_parent)))
|
|
410
|
+
except Exception:
|
|
411
|
+
pass
|
|
412
|
+
client.upload_file(str(zip_path))
|
|
413
|
+
self.logger.info(f"Uploaded mlops source package to workers: {zip_path}")
|
|
414
|
+
|
|
415
|
+
# Ensure the zip and repo paths are importable (best-effort; may be a no-op on real clusters).
|
|
416
|
+
paths_to_add: List[str] = []
|
|
417
|
+
try:
|
|
418
|
+
paths_to_add.append(str(zip_path))
|
|
419
|
+
except Exception:
|
|
420
|
+
pass
|
|
421
|
+
try:
|
|
422
|
+
added = client.run(self._worker_ensure_sys_path, paths_to_add)
|
|
423
|
+
self.logger.info(f"Adjusted sys.path on workers for importability: {added}")
|
|
424
|
+
except Exception as e:
|
|
425
|
+
self.logger.warning(f"Failed to adjust worker sys.path: {e}")
|
|
426
|
+
try:
|
|
427
|
+
added_sched = client.run_on_scheduler(self._worker_ensure_sys_path, paths_to_add)
|
|
428
|
+
self.logger.info(f"Adjusted sys.path on scheduler for importability: {added_sched}")
|
|
429
|
+
except Exception as e:
|
|
430
|
+
self.logger.warning(f"Failed to adjust scheduler sys.path: {e}")
|
|
431
|
+
|
|
432
|
+
# Validate import on workers (warn-only).
|
|
433
|
+
try:
|
|
434
|
+
import_ok = client.run(self._worker_try_import, "mlops")
|
|
435
|
+
if isinstance(import_ok, dict) and not all(bool(v) for v in import_ok.values()):
|
|
436
|
+
self.logger.warning(
|
|
437
|
+
"One or more workers cannot import 'mlops'. Ensure shared filesystem or install package on workers."
|
|
438
|
+
)
|
|
439
|
+
except Exception:
|
|
440
|
+
pass
|
|
441
|
+
except Exception as e:
|
|
442
|
+
self.logger.warning(f"Failed to package/upload mlops code to workers: {e}")
|
|
443
|
+
|
|
444
|
+
# Upload any additional files requested (e.g., custom model script, reporting entrypoint).
|
|
445
|
+
if self._extra_upload_files:
|
|
446
|
+
for _f in self._extra_upload_files:
|
|
447
|
+
try:
|
|
448
|
+
if _f and os.path.exists(_f):
|
|
449
|
+
load_flag = True
|
|
450
|
+
try:
|
|
451
|
+
norm = str(_f).replace("\\", "/")
|
|
452
|
+
if norm.endswith(".py") and "/charts/" in norm:
|
|
453
|
+
load_flag = False
|
|
454
|
+
except Exception:
|
|
455
|
+
load_flag = True
|
|
456
|
+
try:
|
|
457
|
+
client.upload_file(str(_f), load=load_flag)
|
|
458
|
+
except TypeError:
|
|
459
|
+
# Backward-compatible fallback (older clients without load kwarg)
|
|
460
|
+
client.upload_file(str(_f))
|
|
461
|
+
if load_flag:
|
|
462
|
+
self.logger.info(f"Uploaded extra file to workers: {_f}")
|
|
463
|
+
else:
|
|
464
|
+
self.logger.info(f"Uploaded extra file to workers (load=False): {_f}")
|
|
465
|
+
except Exception as e:
|
|
466
|
+
self.logger.warning(f"Failed to upload extra file '{_f}' to workers: {e}")
|
|
467
|
+
|
|
468
|
+
# Propagate critical environment variables to workers (best-effort).
|
|
469
|
+
try:
|
|
470
|
+
env_to_propagate: Dict[str, str] = {}
|
|
471
|
+
critical_env_vars = [
|
|
472
|
+
'MLOPS_PROJECT_ID',
|
|
473
|
+
'MLOPS_WORKSPACE_DIR',
|
|
474
|
+
'GOOGLE_APPLICATION_CREDENTIALS',
|
|
475
|
+
'GOOGLE_CLOUD_PROJECT',
|
|
476
|
+
'FIRESTORE_EMULATOR_HOST',
|
|
477
|
+
'DASK_DISTRIBUTED__COMM__COMPRESSION',
|
|
478
|
+
'MLOPS_RUNTIME_PYTHON',
|
|
479
|
+
'MLOPS_REPORTING_PYTHON',
|
|
480
|
+
'MLOPS_REPORTING_CONFIG',
|
|
481
|
+
'MLOPS_RANDOM_SEED',
|
|
482
|
+
'MLOPS_TASK_LEVEL_SEEDING',
|
|
483
|
+
]
|
|
484
|
+
for env_var in critical_env_vars:
|
|
485
|
+
value = os.environ.get(env_var)
|
|
486
|
+
if value is not None:
|
|
487
|
+
env_to_propagate[env_var] = value
|
|
488
|
+
if env_to_propagate:
|
|
489
|
+
client.run(self._worker_set_env_vars, env_to_propagate)
|
|
490
|
+
client.run_on_scheduler(self._worker_set_env_vars, env_to_propagate)
|
|
491
|
+
self.logger.info(f"Propagated environment variables: {list(env_to_propagate.keys())}")
|
|
492
|
+
except Exception as e:
|
|
493
|
+
self.logger.warning(f"Failed to propagate environment variables to workers: {e}")
|
|
494
|
+
|
|
495
|
+
def set_cache_enabled(self, enabled: bool) -> None:
|
|
496
|
+
"""Enable or disable step caching."""
|
|
497
|
+
self.cache_enabled = enabled
|
|
498
|
+
|
|
499
|
+
    def restart_distributed_client(self) -> None:
        """Restart the distributed client connection to ensure clean state between runs.

        This is a more aggressive cleanup than the automatic cleanup performed in
        _ensure_connected_to_scheduler. Use this if you're experiencing persistent
        deserialization errors between consecutive runs.

        No-op in threads mode or when no client exists.  Re-raises any
        exception from the reconnect sequence after logging it.
        """
        if self.scheduler_mode != 'distributed' or not self._distributed_client:
            return

        try:
            # Capture the address first so we can reconnect to the same scheduler.
            scheduler_addr = self._distributed_client.scheduler.address
            self.logger.info(f"Restarting Dask client connection to {scheduler_addr}")

            # Close the existing client
            try:
                self._distributed_client.close()
            except Exception as e:
                # A failed close is not fatal; we still attempt a fresh connection.
                self.logger.warning(f"Error closing existing client: {e}")

            # Create a new client
            try:
                from distributed import Client  # type: ignore
            except Exception:
                from dask.distributed import Client  # type: ignore

            self._distributed_client = Client(scheduler_addr)
            self.logger.info(f"Successfully restarted Dask client connection")

            # Re-run the setup (compression, worker checks, etc.)
            self._ensure_connected_to_scheduler(force=True)
        except Exception as e:
            self.logger.error(f"Failed to restart Dask client: {e}")
            raise
|
|
533
|
+
|
|
534
|
+
def _build_process_graph(self, config: NetworkXGraphConfig) -> nx.DiGraph:
|
|
535
|
+
"""Build the main process-level DAG."""
|
|
536
|
+
process_graph = nx.DiGraph()
|
|
537
|
+
|
|
538
|
+
for process_config in config.processes:
|
|
539
|
+
process_graph.add_node(
|
|
540
|
+
process_config.name,
|
|
541
|
+
type=NodeType.PROCESS,
|
|
542
|
+
config=process_config
|
|
543
|
+
)
|
|
544
|
+
|
|
545
|
+
for process_config in config.processes:
|
|
546
|
+
depends_on = getattr(process_config, 'depends_on', [])
|
|
547
|
+
for dependency in depends_on:
|
|
548
|
+
process_graph.add_edge(dependency, process_config.name)
|
|
549
|
+
|
|
550
|
+
return process_graph
|
|
551
|
+
|
|
552
|
+
def _validate_process_dag(self, process_graph: nx.DiGraph) -> None:
|
|
553
|
+
"""Validate that the process-level graph is a DAG."""
|
|
554
|
+
if not nx.is_directed_acyclic_graph(process_graph):
|
|
555
|
+
cycles = list(nx.simple_cycles(process_graph))
|
|
556
|
+
raise ValueError(f"Process-level graph contains cycles: {cycles}")
|
|
557
|
+
|
|
558
|
+
# Removed unused step-config hashing; step-level caching uses stable context hash instead
|
|
559
|
+
|
|
560
|
+
def _compute_step_input_hash(self, step_def: StepDefinition, context: StepContext) -> str:
|
|
561
|
+
"""Compute a stable step input hash without parameter resolution.
|
|
562
|
+
Uses available context surface to approximate variability.
|
|
563
|
+
"""
|
|
564
|
+
if not self.state_manager:
|
|
565
|
+
return ""
|
|
566
|
+
try:
|
|
567
|
+
context_data = {
|
|
568
|
+
'step': getattr(step_def, 'name', None),
|
|
569
|
+
'process': getattr(context, 'current_process', None),
|
|
570
|
+
'step_results_keys': sorted(list((getattr(context, 'step_results', {}) or {}).keys())),
|
|
571
|
+
'iteration': getattr(context, 'iteration', 0),
|
|
572
|
+
}
|
|
573
|
+
return self.state_manager._compute_hash(context_data)
|
|
574
|
+
except Exception as e:
|
|
575
|
+
try:
|
|
576
|
+
self.logger.warning(f"Failed to compute context-based step hash: {e}")
|
|
577
|
+
except Exception:
|
|
578
|
+
pass
|
|
579
|
+
return ""
|
|
580
|
+
|
|
581
|
+
# Removed unused Dask step task; steps execute via the step wrapper inside process runners
|
|
582
|
+
|
|
583
|
+
    def _serialize_context_for_worker(self, context: StepContext) -> dict:
        """Serialize minimal context payload with only primitives for safe graph shipping.

        Builds a plain dict of project/run identifiers, global config, data
        paths and sanitized per-process step results so the payload can travel
        through the Dask graph without heavy or non-serializable objects.
        """
        try:
            data_paths = {k: str(v) for k, v in (context.data_paths or {}).items()}
        except Exception:
            data_paths = {}
        # Sanitize step_results: keep only JSON-serializable fields and avoid heavy objects
        sanitized_results: Dict[str, Dict[str, Any]] = {}
        for process_name, result in (context.step_results or {}).items():
            try:
                # Copy data and drop recursive/heavy fields (nested step maps, logs, large artifact pointers)
                _data = dict(result) if isinstance(result, dict) else {}
                _common_drop_keys = ['__step_results__', '__logs__', 'checkpoint_path', 'cache_path', 'artifacts']
                # Model objects are only dropped when shipping across processes.
                _heavy_drop_keys = ['model', 'saved_model'] if str(self.scheduler_mode) == 'distributed' else []
                for _k in (_common_drop_keys + _heavy_drop_keys):
                    try:
                        _data.pop(_k, None)
                    except Exception:
                        pass
                # Best-effort: attach cache path so workers can rehydrate dependencies locally without shipping heavy objects
                _cache_path = None
                try:
                    if self.state_manager:
                        ih, ch, fh = self._compute_process_lookup_hashes(context, process_name)
                        # Guard against missing kv_store in rare cases
                        _kvs = getattr(self.state_manager, 'kv_store', None)
                        if _kvs and hasattr(_kvs, 'get_process_cache_path'):
                            _cache_path = _kvs.get_process_cache_path(process_name, ih, ch, fh)
                except Exception:
                    _cache_path = None
                # Keep the same shape as original result surface for dependency injection.
                # Attach cache path as a top-level meta key to enable rehydration when needed.
                _data['cache_path'] = _cache_path
                sanitized_results[process_name] = _data
            except Exception:
                # Skip results that cannot even be shallow-copied.
                continue
        return {
            'project_id': context.project_id,
            'run_id': context.run_id,
            'global_config': dict(getattr(context, 'global_config', {}) or {}),
            'data_paths': data_paths,
            'checkpoint_dir': str(getattr(context, 'checkpoint_dir', 'artifacts/checkpoints')),
            'step_results': sanitized_results,
        }
|
|
627
|
+
|
|
628
|
+
def _compute_process_lookup_hashes(self, context: StepContext, process_name: str) -> tuple:
|
|
629
|
+
"""Compute (ih, ch, fh) via shared helper to keep driver/worker in lockstep."""
|
|
630
|
+
try:
|
|
631
|
+
from .process_hashing import compute_process_hashes
|
|
632
|
+
except Exception:
|
|
633
|
+
compute_process_hashes = None
|
|
634
|
+
|
|
635
|
+
# Build deterministic dependency_map from the process graph
|
|
636
|
+
dependency_map = {}
|
|
637
|
+
try:
|
|
638
|
+
if hasattr(self, 'process_graph'):
|
|
639
|
+
for n in list(self.process_graph.nodes):
|
|
640
|
+
try:
|
|
641
|
+
preds = list(self.process_graph.predecessors(n))
|
|
642
|
+
preds = sorted(set(preds))
|
|
643
|
+
dependency_map[n] = preds
|
|
644
|
+
except Exception:
|
|
645
|
+
dependency_map[n] = []
|
|
646
|
+
except Exception:
|
|
647
|
+
dependency_map = {}
|
|
648
|
+
|
|
649
|
+
if compute_process_hashes and self.state_manager:
|
|
650
|
+
# Use code_function mapping when available to resolve the correct process definition
|
|
651
|
+
try:
|
|
652
|
+
lookup_name = self._get_lookup_name(process_name) or process_name
|
|
653
|
+
except Exception:
|
|
654
|
+
lookup_name = process_name
|
|
655
|
+
ih, ch, fh = compute_process_hashes(
|
|
656
|
+
self.state_manager,
|
|
657
|
+
context,
|
|
658
|
+
process_name,
|
|
659
|
+
dependency_map,
|
|
660
|
+
lookup_name=lookup_name,
|
|
661
|
+
)
|
|
662
|
+
else:
|
|
663
|
+
ih = ch = fh = None
|
|
664
|
+
|
|
665
|
+
self.logger.debug(f"[HashTrace] side=driver process={process_name} ih={ih} ch={ch} fh={fh}")
|
|
666
|
+
|
|
667
|
+
return (ih, ch, fh)
|
|
668
|
+
|
|
669
|
+
def _get_cache_config_for_worker(self) -> Dict[str, Any]:
|
|
670
|
+
"""Extract cache configuration for worker state manager creation."""
|
|
671
|
+
try:
|
|
672
|
+
if self.state_manager is None:
|
|
673
|
+
return {}
|
|
674
|
+
|
|
675
|
+
cfg: Dict[str, Any] = {}
|
|
676
|
+
|
|
677
|
+
kv_store = getattr(self.state_manager, "kv_store", None)
|
|
678
|
+
if kv_store is not None:
|
|
679
|
+
kv_type = type(kv_store).__name__
|
|
680
|
+
cfg["kv_store_type"] = kv_type
|
|
681
|
+
if "GCP" in kv_type or "Firestore" in kv_type:
|
|
682
|
+
cfg["kv_store_config"] = {
|
|
683
|
+
"project_id": getattr(kv_store, "project_id", None),
|
|
684
|
+
"gcp_project": getattr(kv_store, "gcp_project", None),
|
|
685
|
+
"topic_name": getattr(kv_store, "topic_name", None),
|
|
686
|
+
"emulator_host": getattr(kv_store, "_emulator_host", None),
|
|
687
|
+
}
|
|
688
|
+
|
|
689
|
+
obj_store = getattr(self.state_manager, "object_store", None)
|
|
690
|
+
if obj_store is not None:
|
|
691
|
+
obj_type = type(obj_store).__name__
|
|
692
|
+
cfg["object_store_type"] = obj_type
|
|
693
|
+
if "GCS" in obj_type:
|
|
694
|
+
bucket_name = None
|
|
695
|
+
try:
|
|
696
|
+
bucket_obj = getattr(obj_store, "_bucket", None)
|
|
697
|
+
bucket_name = getattr(bucket_obj, "name", None) if bucket_obj is not None else None
|
|
698
|
+
except Exception:
|
|
699
|
+
bucket_name = None
|
|
700
|
+
cfg["object_store_config"] = {
|
|
701
|
+
"bucket": bucket_name,
|
|
702
|
+
"prefix": getattr(obj_store, "_prefix", None),
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
return cfg
|
|
706
|
+
except Exception:
|
|
707
|
+
return {}
|
|
708
|
+
|
|
709
|
+
def _get_lookup_name(self, proc_name: str) -> Optional[str]:
|
|
710
|
+
try:
|
|
711
|
+
if hasattr(self, 'process_graph') and self.process_graph.has_node(proc_name):
|
|
712
|
+
cfg = self.process_graph.nodes[proc_name].get('config')
|
|
713
|
+
return getattr(cfg, 'code_function', None)
|
|
714
|
+
except Exception:
|
|
715
|
+
return None
|
|
716
|
+
return None
|
|
717
|
+
|
|
718
|
+
def _get_logging_flag(self, proc_name: str) -> bool:
|
|
719
|
+
try:
|
|
720
|
+
from .step_system import get_process_registry # type: ignore
|
|
721
|
+
pr = get_process_registry()
|
|
722
|
+
pdef = pr.get_process(self._get_lookup_name(proc_name) or proc_name) if pr else None
|
|
723
|
+
return getattr(pdef, 'logging', True) if pdef else True
|
|
724
|
+
except Exception:
|
|
725
|
+
return True
|
|
726
|
+
|
|
727
|
+
def _repo_root(self) -> Path:
|
|
728
|
+
# Legacy: many code paths historically treated this as the repo root.
|
|
729
|
+
# For installed packages, use the workspace root (where projects/ lives).
|
|
730
|
+
try:
|
|
731
|
+
from .workspace import get_workspace_root
|
|
732
|
+
return get_workspace_root()
|
|
733
|
+
except Exception:
|
|
734
|
+
return Path.cwd()
|
|
735
|
+
|
|
736
|
+
def _resolve_path(self, p: str) -> Path:
|
|
737
|
+
try:
|
|
738
|
+
path = Path(p)
|
|
739
|
+
if path.is_absolute():
|
|
740
|
+
return path
|
|
741
|
+
return self._repo_root() / p
|
|
742
|
+
except Exception:
|
|
743
|
+
return Path(p)
|
|
744
|
+
|
|
745
|
+
def _get_reporting_cfg(self) -> dict:
|
|
746
|
+
try:
|
|
747
|
+
txt = os.environ.get('MLOPS_REPORTING_CONFIG') or ''
|
|
748
|
+
if not txt:
|
|
749
|
+
return {}
|
|
750
|
+
return json.loads(txt)
|
|
751
|
+
except Exception:
|
|
752
|
+
return {}
|
|
753
|
+
|
|
754
|
+
def _get_chart_spec(self, name: str) -> dict:
|
|
755
|
+
rcfg = self._get_reporting_cfg()
|
|
756
|
+
charts = rcfg.get('charts') or []
|
|
757
|
+
for item in charts:
|
|
758
|
+
try:
|
|
759
|
+
if isinstance(item, dict) and str(item.get('name')) == name and str(item.get('type', 'static')).lower() != 'dynamic':
|
|
760
|
+
return item
|
|
761
|
+
except Exception:
|
|
762
|
+
continue
|
|
763
|
+
return {}
|
|
764
|
+
|
|
765
|
+
def _compute_chart_function_hash(self, static_entrypoint: str) -> Optional[str]:
|
|
766
|
+
try:
|
|
767
|
+
p = self._resolve_path(static_entrypoint)
|
|
768
|
+
with open(p, 'rb') as f:
|
|
769
|
+
data = f.read()
|
|
770
|
+
return hashlib.sha256(data).hexdigest()
|
|
771
|
+
except Exception:
|
|
772
|
+
return None
|
|
773
|
+
|
|
774
|
+
def _compute_chart_config_hash(self, name: str) -> Optional[str]:
|
|
775
|
+
try:
|
|
776
|
+
rcfg = self._get_reporting_cfg()
|
|
777
|
+
global_args = rcfg.get('args') or []
|
|
778
|
+
theme = os.environ.get('MLOPS_CHART_THEME')
|
|
779
|
+
spec = self._get_chart_spec(name)
|
|
780
|
+
cfg_payload = {
|
|
781
|
+
'name': name,
|
|
782
|
+
'probe_paths': spec.get('probe_paths') or {},
|
|
783
|
+
'chart_args': spec.get('args') or [],
|
|
784
|
+
'global_args': global_args,
|
|
785
|
+
'theme': theme,
|
|
786
|
+
}
|
|
787
|
+
if self.state_manager:
|
|
788
|
+
return self.state_manager._compute_hash(cfg_payload)
|
|
789
|
+
else:
|
|
790
|
+
payload = json.dumps(cfg_payload, sort_keys=True, separators=(",", ":")).encode()
|
|
791
|
+
return hashlib.sha256(payload).hexdigest()
|
|
792
|
+
except Exception:
|
|
793
|
+
return None
|
|
794
|
+
|
|
795
|
+
def _maybe_apply_chart_hash_overrides(
|
|
796
|
+
self,
|
|
797
|
+
process_name: str,
|
|
798
|
+
config_hash: Optional[str],
|
|
799
|
+
function_hash: Optional[str],
|
|
800
|
+
) -> tuple[Optional[str], Optional[str]]:
|
|
801
|
+
"""Override hashes for chart nodes so cache keys track chart config + entrypoint content."""
|
|
802
|
+
try:
|
|
803
|
+
if not hasattr(self, "process_graph") or not self.process_graph.has_node(process_name):
|
|
804
|
+
return config_hash, function_hash
|
|
805
|
+
cfg = self.process_graph.nodes[process_name].get("config")
|
|
806
|
+
if getattr(cfg, "process_type", "process") != "chart":
|
|
807
|
+
return config_hash, function_hash
|
|
808
|
+
rcfg = self._get_reporting_cfg()
|
|
809
|
+
entrypoint = rcfg.get("static_entrypoint") or rcfg.get("entrypoint") or ""
|
|
810
|
+
ch_override = self._compute_chart_config_hash(process_name)
|
|
811
|
+
fh_override = self._compute_chart_function_hash(entrypoint) if entrypoint else None
|
|
812
|
+
return ch_override or config_hash, fh_override or function_hash
|
|
813
|
+
except Exception:
|
|
814
|
+
return config_hash, function_hash
|
|
815
|
+
|
|
816
|
+
def execute_graph(
    self,
    config: NetworkXGraphConfig,
    context: StepContext,
    run_id: Optional[str] = None,
    resume_from_process: Optional[str] = None,
    stop_after_process: bool | str = False,
) -> Dict[str, ExecutionResult]:
    """
    Execute the NetworkX-based graph using Dask's advanced scheduler.

    This is the main entry point that replaces the manual scheduling approach
    with Dask's task scheduler for automatic dependency resolution and parallelization.

    Args:
        config: Graph configuration; `config.execution` may carry a
            "failure_mode" setting (read in the post-process phase).
        context: Shared step context; `context.step_results` is hydrated with
            each process's (clean) result as execution proceeds.
        run_id: Pipeline run identifier used for state-manager bookkeeping.
        resume_from_process: When set, schedule this process and its
            descendants (ancestors are still scheduled as prerequisites).
        stop_after_process: Either a process name (run only that target plus
            its ancestors) or a truthy flag combined with resume_from_process.

    Returns:
        Mapping of process name to its ExecutionResult.

    Raises:
        RuntimeError: When a process fails (failure_mode treated as 'stop').
    """

    # --- Phase 1: build and validate the process DAG -----------------------
    process_graph = self._build_process_graph(config)
    # Persist the main process graph so hashing can access per-process code_function
    self.process_graph = process_graph
    self._validate_process_dag(process_graph)

    # Pipeline-level bookkeeping is skipped for single-process ("stop after") runs.
    if self.state_manager and run_id and not stop_after_process:
        self.state_manager.start_pipeline_execution(run_id, config.__dict__, self.cache_enabled)

    # NOTE(review): failure_mode_cfg is computed here but never used below;
    # the post-process phase re-reads failure_mode from config.execution.
    try:
        failure_mode_cfg = None
        try:
            failure_mode_cfg = (config.execution or {}).get("failure_mode")
        except Exception:
            failure_mode_cfg = None
    except Exception:
        pass

    self.logger.info(f"Executing {len(process_graph.nodes)} processes with Dask scheduler")
    self.logger.info(f"Process execution order will be determined by Dask: {list(nx.topological_sort(process_graph))}")
    if self.scheduler_mode == 'distributed':
        self.logger.info("Using distributed scheduler (external Dask Client)")
        self._ensure_connected_to_scheduler()
        # Wire distributed client into step_system so @step calls submit to workers
        try:
            from .step_system import set_distributed_client as _set_dc
            _set_dc(self._distributed_client)
        except Exception:
            pass
    else:
        self.logger.info(f"Using threaded scheduler with {self.n_workers} workers")
        # Ensure no distributed client is set in threaded mode
        try:
            from .step_system import set_distributed_client as _set_dc  # type: ignore
            _set_dc(None)
        except Exception:
            pass

    execution_results = {}

    try:
        is_distributed = self.scheduler_mode == 'distributed'
        if is_distributed:
            self._ensure_connected_to_scheduler()
        process_tasks: Dict[str, Any] = {}

        topo_order = list(nx.topological_sort(process_graph))

        # --- Phase 2: narrow scheduling scope for resume/single-process modes.
        # Limit scheduling for resume/single-process execution modes.
        stop_target: Optional[str] = None
        if isinstance(stop_after_process, str) and stop_after_process.strip():
            stop_target = stop_after_process.strip()
        elif stop_after_process and resume_from_process:
            stop_target = str(resume_from_process)

        targets: set[str] = set(topo_order)
        if stop_target and stop_target in process_graph:
            targets = {stop_target}
        elif resume_from_process and resume_from_process in process_graph:
            # Resume from a given process and continue downstream.
            targets = set(nx.descendants(process_graph, resume_from_process))
            targets.add(resume_from_process)

        # Every target also needs its ancestors (dependencies) scheduled.
        required_nodes: set[str] = set()
        for n in targets:
            required_nodes.add(n)
            try:
                required_nodes.update(nx.ancestors(process_graph, n))
            except Exception:
                pass

        # Preserve topological order while dropping out-of-scope nodes.
        nodes_order = [n for n in topo_order if n in required_nodes]

        # --- Phase 3: build one delayed task per process (cache-aware). -----
        for process_name in nodes_order:
            process_config = process_graph.nodes[process_name]['config']
            dependencies = list(process_graph.predecessors(process_name))
            dep_tasks = [process_tasks[dep] for dep in dependencies if dep in process_tasks]

            cached_task_created = False

            if self.cache_enabled and self.state_manager:
                # Unified enhanced hashing for cache lookup
                try:
                    process_input_hash, process_config_hash, composite_fhash = self._compute_process_lookup_hashes(context, process_name)
                    process_config_hash, composite_fhash = self._maybe_apply_chart_hash_overrides(
                        process_name, process_config_hash, composite_fhash
                    )
                except Exception:
                    process_input_hash = process_config_hash = composite_fhash = None

                # Fallback function hash: hash the registered process callable directly.
                if composite_fhash is None:
                    try:
                        from .step_system import get_process_registry
                        pr = get_process_registry()
                        pdef = pr.get_process(self._get_lookup_name(process_name) or process_name) if pr else None
                        orig_fn = getattr(pdef, 'original_func', None) if pdef else None
                        composite_fhash = self.state_manager._compute_function_hash(orig_fn or getattr(pdef, 'runner', None)) if pdef else None
                    except Exception:
                        composite_fhash = None

                try:
                    self.logger.info(
                        f"[CACHE] Lookup process={process_name} ih={process_input_hash} ch={process_config_hash} fh={composite_fhash}"
                    )
                    cached_data = self.state_manager.get_cached_process_result_with_metadata(
                        process_name,
                        input_hash=process_input_hash,
                        config_hash=process_config_hash,
                        function_hash=composite_fhash,
                    )
                except Exception:
                    cached_data = None

                # Threaded mode only: retry the lookup without hash constraints.
                if not is_distributed and cached_data is None:
                    try:
                        cached_data = self.state_manager.get_cached_process_result_with_metadata(process_name)
                    except Exception:
                        cached_data = None

                # Extract result and metadata from cached data
                cached_proc = None
                cached_run_id = None
                cached_metadata = {}
                if cached_data is not None:
                    try:
                        cached_proc, cached_run_id, cached_metadata = cached_data
                    except Exception:
                        # Fallback if metadata extraction fails
                        cached_proc = cached_data if not isinstance(cached_data, tuple) else None

                if cached_proc is not None:
                    # Cache hit: schedule a cheap placeholder task so downstream
                    # dependency edges still exist in the Dask graph.
                    self.logger.info(f"[CACHE] Hit for process={process_name}; scheduling placeholder")
                    if dep_tasks:
                        task = delayed(_return_placeholder_cached_process_execution_result_with_deps)(process_name, dep_tasks)
                    else:
                        task = delayed(_return_placeholder_cached_process_execution_result)(process_name)
                    process_tasks[process_name] = task
                    try:
                        context.step_results[process_name] = cached_proc
                    except Exception:
                        pass
                    cached_task_created = True

                    # Pre-record a 'cached' completion so the UI reflects cache hits immediately
                    # rather than after the compute() phase returns.
                    try:
                        if self.state_manager:
                            from .step_state_manager import ProcessExecutionResult as _ProcessExec  # local import to avoid import cycles
                            enable_logging = self._get_logging_flag(process_name)
                            # Extract cached metadata for frontend display
                            cached_exec_time = cached_metadata.get('execution_time', 0.0) if isinstance(cached_metadata, dict) else 0.0
                            cached_started = cached_metadata.get('started_at') if isinstance(cached_metadata, dict) else None
                            cached_ended = cached_metadata.get('ended_at') if isinstance(cached_metadata, dict) else None
                            self.state_manager.record_process_completion(
                                run_id or 'default',
                                _ProcessExec(
                                    process_name=process_name,
                                    success=True,
                                    result=cached_proc,
                                    execution_time=0.0,
                                    timestamp=datetime.now().isoformat(),
                                ),
                                input_hash=process_input_hash,
                                config_hash=process_config_hash,
                                function_hash=composite_fhash,
                                was_cached=True,
                                enable_logging=enable_logging,
                                cached_run_id=cached_run_id,
                                cached_started_at=cached_started,
                                cached_ended_at=cached_ended,
                                cached_execution_time=cached_exec_time,
                            )
                    except Exception:
                        # Best-effort; if this fails we'll still write completion in the post-process phase
                        pass

            if not cached_task_created:
                # Cache miss: build a serializable payload describing the process
                # and schedule a real worker execution task.
                self.logger.info(f"[CACHE] Miss or not loadable for process={process_name}; scheduling execution")
                ctx_payload = self._serialize_context_for_worker(context)
                # Provide dependency map for recursive signature hashing on workers
                try:
                    dependency_map = {n: list(process_graph.predecessors(n)) for n in nodes_order}
                except Exception:
                    dependency_map = {n: [] for n in nodes_order}
                proc_payload = {
                    'name': process_config.name,
                    'code_function': getattr(process_config, 'code_function', None),
                    'process_type': getattr(process_config, 'process_type', 'process'),
                    'cache_config': self._get_cache_config_for_worker() if self.cache_enabled and self.state_manager else None,
                    'has_state_manager': bool(self.state_manager and self.cache_enabled),
                    'logging': self._get_logging_flag(process_name),
                    'dependencies': dependencies,
                    'dependency_map': dependency_map,
                }

                # Attach chart metadata and precomputed hashes for chart nodes
                try:
                    if getattr(process_config, 'process_type', 'process') == 'chart':
                        rcfg = self._get_reporting_cfg()
                        entrypoint = rcfg.get('static_entrypoint') or rcfg.get('entrypoint') or ''
                        function_hash_override = self._compute_chart_function_hash(entrypoint) if entrypoint else None
                        config_hash_override = self._compute_chart_config_hash(process_config.name)
                        spec = self._get_chart_spec(process_config.name)
                        reporting_py = rcfg.get('reporting_python') or os.environ.get('MLOPS_REPORTING_PYTHON') or None
                        chart_spec = {
                            'name': process_config.name,
                            'probe_paths': (spec.get('probe_paths') or {}),
                            'args': list(rcfg.get('args') or []) + list(spec.get('args') or []),
                            'theme': os.environ.get('MLOPS_CHART_THEME'),
                            'entrypoint': entrypoint,
                            'reporting_python': reporting_py,
                        }
                        proc_payload['hash_overrides'] = {
                            'function_hash': function_hash_override,
                            'config_hash': config_hash_override,
                        }
                        proc_payload['chart_spec'] = chart_spec
                except Exception:
                    pass
                if dep_tasks:
                    task = delayed(_worker_execute_process_with_deps)(proc_payload, ctx_payload, dep_tasks, run_id)
                else:
                    task = delayed(_worker_execute_process_task)(proc_payload, ctx_payload, run_id)
                process_tasks[process_name] = task

        # --- Phase 4: run the Dask graph and collect per-process results. ---
        if process_tasks:
            # Execute according to scheduler mode
            if is_distributed:
                try:
                    futures = self._distributed_client.compute(list(process_tasks.values()))
                    results_values = self._distributed_client.gather(futures)
                except Exception:
                    # Fall back to the local scheduler if the client path fails.
                    results_values = compute(*process_tasks.values())
            else:
                results_values = compute(*process_tasks.values(), scheduler='threads', num_workers=self.n_workers)

            proc_results = dict(zip(process_tasks.keys(), results_values))
            execution_results.update(proc_results)

            # Surface worker-side logs in both modes
            try:
                for _pname, _pres in proc_results.items():
                    try:
                        _r = getattr(_pres, 'result', None)
                        if isinstance(_r, dict) and '__logs__' in _r and _r['__logs__']:
                            _logs = _r['__logs__']
                            try:
                                self.logger.info(f"[WorkerLogs][{_pname}] BEGIN")
                                for _line in str(_logs).splitlines():
                                    self.logger.info(f"[{_pname}] {_line}")
                                self.logger.info(f"[WorkerLogs][{_pname}] END")
                            except Exception:
                                pass
                    except Exception:
                        continue
            except Exception:
                pass

            # --- Phase 5: post-process results: errors, context hydration, cache recording.
            # Post-process results: errors, context hydration, cache recording
            for process_name, result in proc_results.items():
                # Rehydrate placeholders
                try:
                    if getattr(result, 'was_cached', False) and not getattr(result, 'result', None):
                        if self.cache_enabled and self.state_manager:
                            try:
                                ih, ch, fh = self._compute_process_lookup_hashes(context, process_name)
                                ch, fh = self._maybe_apply_chart_hash_overrides(process_name, ch, fh)
                                loaded = self.state_manager.get_cached_process_result(process_name, input_hash=ih, config_hash=ch, function_hash=fh)
                            except Exception:
                                loaded = None
                            if loaded is not None:
                                result.result = loaded
                except Exception:
                    pass

                if result.error is not None:
                    # NOTE: Historically we supported a "stop_and_resume" mode. In practice, users
                    # resume work by re-running and leveraging cache hits (process + step cache).
                    # Treat stop_and_resume as deprecated and equivalent to a simple stop-on-failure.
                    try:
                        failure_mode = str((config.execution or {}).get("failure_mode", "stop") or "stop").strip().lower()
                    except Exception:
                        failure_mode = "stop"
                    if failure_mode == "stop_and_resume":
                        try:
                            self.logger.warning(
                                "Deprecated failure_mode 'stop_and_resume' encountered; treating as 'stop'. "
                                "Re-run to reuse cached results."
                            )
                        except Exception:
                            pass
                    if self.state_manager and run_id and not stop_after_process:
                        self.state_manager.complete_pipeline_execution(run_id, False)
                    raise RuntimeError(f"Process {process_name} failed: {result.error}")

                # Update driver context with clean result
                if result.error is None and getattr(result, 'result', None) is not None:
                    try:
                        # Strip internal dunder keys (__logs__, __step_results__, ...) before exposing.
                        clean_result = {k: v for k, v in result.result.items() if not k.startswith('__')} if isinstance(result.result, dict) else result.result
                        context.step_results[process_name] = clean_result
                    except Exception:
                        pass

                if self.cache_enabled and self.state_manager:
                    try:
                        from .step_state_manager import ProcessExecutionResult as _ProcessExec
                        ih, ch, fh = None, None, None
                        try:
                            ih, ch, fh = self._compute_process_lookup_hashes(context, process_name)
                            ch, fh = self._maybe_apply_chart_hash_overrides(process_name, ch, fh)
                            self.logger.debug(f"[CACHE WRITE] process={process_name} ih={ih} ch={ch} fh={fh}")
                        except Exception as e:
                            self.logger.warning(f"[CACHE WRITE] Failed to compute hashes for {process_name}: {e}")
                        enable_logging = self._get_logging_flag(process_name)

                        # Determine success based on error field
                        is_success = result.error is None

                        self.state_manager.record_process_completion(
                            run_id or 'default',
                            _ProcessExec(
                                process_name=process_name,
                                success=is_success,
                                result=result.result if is_success else None,
                                execution_time=result.execution_time,
                                timestamp=datetime.now().isoformat(),
                            ),
                            input_hash=ih,
                            config_hash=ch,
                            function_hash=fh,
                            was_cached=bool(getattr(result, 'was_cached', False)),
                            enable_logging=enable_logging,
                        )
                    except Exception as e:
                        self.logger.warning(f"❌ Failed to record process completion for {process_name}: {e}")

                    # Step-level cache recording - only if process succeeded
                    # NOTE(review): is_success is assigned inside the try above; if that
                    # try raises before the assignment, this reference would NameError
                    # (caught only by the outer pipeline try) — TODO confirm intended.
                    if is_success:
                        try:
                            sr_map = result.result.get('__step_results__', {}) if isinstance(result.result, dict) else {}
                            if isinstance(sr_map, dict) and sr_map:
                                from .step_state_manager import StepExecutionResult as _StepExec
                                for _sname, _sres in sr_map.items():
                                    try:
                                        step_def = self.step_registry.get_step(_sname)
                                        if not step_def:
                                            continue
                                        # Build a throwaway context mirroring the worker-side view
                                        # so step input hashes match what the worker computed.
                                        tmp_ctx = None
                                        try:
                                            from .step_system import StepContext as _Ctx
                                            tmp_ctx = _Ctx(
                                                project_id=getattr(context, 'project_id', None),
                                                run_id=run_id or getattr(context, 'run_id', None),
                                                tracker=None,
                                                step_results=sr_map,
                                                global_config=getattr(context, 'global_config', {}) or {},
                                                data_paths=getattr(context, 'data_paths', {}) or {},
                                                checkpoint_dir=getattr(context, 'checkpoint_dir', None),
                                            )
                                            try:
                                                tmp_ctx.current_process = process_name  # type: ignore[attr-defined]
                                            except Exception:
                                                pass
                                        except Exception:
                                            tmp_ctx = context
                                        input_hash = self._compute_step_input_hash(step_def, tmp_ctx)
                                        try:
                                            function_hash = self.state_manager._compute_function_hash(step_def.original_func)
                                        except Exception:
                                            function_hash = None
                                        try:
                                            config_source = getattr(context, 'global_config', None) or {}
                                            config_hash = self.state_manager._compute_hash(config_source)
                                        except Exception:
                                            config_hash = None
                                        enable_logging = getattr(step_def, 'logging', True) if step_def else True
                                        # Extract execution time from step result metadata
                                        step_exec_time = 0.0
                                        if isinstance(_sres, dict):
                                            step_exec_time = float(_sres.get('__execution_time__', 0.0))
                                        self.state_manager.record_step_completion(
                                            run_id or 'default',
                                            _StepExec(
                                                step_name=_sname,
                                                success=True,
                                                result=_sres,
                                                execution_time=step_exec_time,
                                                timestamp=datetime.now().isoformat(),
                                            ),
                                            input_hash=input_hash,
                                            config_hash=config_hash,
                                            function_name=_sname,
                                            function_hash=function_hash,
                                            was_cached=bool(_sres.get('__was_cached__')) if isinstance(_sres, dict) else False,
                                            process_name=process_name,
                                            enable_logging=enable_logging,
                                        )
                                    except Exception:
                                        continue
                        except Exception:
                            pass

                # Log per-step cached hits summary
                try:
                    sr_map2 = result.result.get('__step_results__', {}) if isinstance(result.result, dict) else {}
                    if isinstance(sr_map2, dict) and sr_map2:
                        total_steps = len(sr_map2)
                        hits = 0
                        for __v in sr_map2.values():
                            try:
                                if isinstance(__v, dict) and __v.get('__was_cached__'):
                                    hits += 1
                            except Exception:
                                continue
                        mode_label = 'Dask Distributed' if is_distributed else 'Dask Threads'
                        self.logger.info(f"Process {process_name} cached steps: {hits}/{total_steps} [{mode_label}]")
                except Exception:
                    pass

        # --- Phase 6: close out the pipeline run and report cache stats. ----
        if self.state_manager and run_id and not stop_after_process:
            self.state_manager.complete_pipeline_execution(run_id, True)
            stats = self.state_manager.get_pipeline_stats(run_id)
            try:
                if stats and 'cache_hit_rate' in stats:
                    self.logger.info(f"Pipeline completed with Dask scheduler. Cache hit rate: {stats['cache_hit_rate']:.1%} "
                                     f"({stats.get('cache_hit_count', 0)}/{stats.get('completed_steps', 0)} steps)")
                else:
                    self.logger.info("Pipeline completed with Dask scheduler.")
            except Exception:
                self.logger.info("Pipeline completed with Dask scheduler.")
    except Exception:
        # Mark the run failed before re-raising so the UI/state store is consistent.
        if self.state_manager and run_id and not stop_after_process:
            self.state_manager.complete_pipeline_execution(run_id, False)
        raise

    return execution_results
|