expops-0.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. expops-0.1.3.dist-info/METADATA +826 -0
  2. expops-0.1.3.dist-info/RECORD +86 -0
  3. expops-0.1.3.dist-info/WHEEL +5 -0
  4. expops-0.1.3.dist-info/entry_points.txt +3 -0
  5. expops-0.1.3.dist-info/licenses/LICENSE +674 -0
  6. expops-0.1.3.dist-info/top_level.txt +1 -0
  7. mlops/__init__.py +0 -0
  8. mlops/__main__.py +11 -0
  9. mlops/_version.py +34 -0
  10. mlops/adapters/__init__.py +12 -0
  11. mlops/adapters/base.py +86 -0
  12. mlops/adapters/config_schema.py +89 -0
  13. mlops/adapters/custom/__init__.py +3 -0
  14. mlops/adapters/custom/custom_adapter.py +447 -0
  15. mlops/adapters/plugin_manager.py +113 -0
  16. mlops/adapters/sklearn/__init__.py +3 -0
  17. mlops/adapters/sklearn/adapter.py +94 -0
  18. mlops/cluster/__init__.py +3 -0
  19. mlops/cluster/controller.py +496 -0
  20. mlops/cluster/process_runner.py +91 -0
  21. mlops/cluster/providers.py +258 -0
  22. mlops/core/__init__.py +95 -0
  23. mlops/core/custom_model_base.py +38 -0
  24. mlops/core/dask_networkx_executor.py +1265 -0
  25. mlops/core/executor_worker.py +1239 -0
  26. mlops/core/experiment_tracker.py +81 -0
  27. mlops/core/graph_types.py +64 -0
  28. mlops/core/networkx_parser.py +135 -0
  29. mlops/core/payload_spill.py +278 -0
  30. mlops/core/pipeline_utils.py +162 -0
  31. mlops/core/process_hashing.py +216 -0
  32. mlops/core/step_state_manager.py +1298 -0
  33. mlops/core/step_system.py +956 -0
  34. mlops/core/workspace.py +99 -0
  35. mlops/environment/__init__.py +10 -0
  36. mlops/environment/base.py +43 -0
  37. mlops/environment/conda_manager.py +307 -0
  38. mlops/environment/factory.py +70 -0
  39. mlops/environment/pyenv_manager.py +146 -0
  40. mlops/environment/setup_env.py +31 -0
  41. mlops/environment/system_manager.py +66 -0
  42. mlops/environment/utils.py +105 -0
  43. mlops/environment/venv_manager.py +134 -0
  44. mlops/main.py +527 -0
  45. mlops/managers/project_manager.py +400 -0
  46. mlops/managers/reproducibility_manager.py +575 -0
  47. mlops/platform.py +996 -0
  48. mlops/reporting/__init__.py +16 -0
  49. mlops/reporting/context.py +187 -0
  50. mlops/reporting/entrypoint.py +292 -0
  51. mlops/reporting/kv_utils.py +77 -0
  52. mlops/reporting/registry.py +50 -0
  53. mlops/runtime/__init__.py +9 -0
  54. mlops/runtime/context.py +34 -0
  55. mlops/runtime/env_export.py +113 -0
  56. mlops/storage/__init__.py +12 -0
  57. mlops/storage/adapters/__init__.py +9 -0
  58. mlops/storage/adapters/gcp_kv_store.py +778 -0
  59. mlops/storage/adapters/gcs_object_store.py +96 -0
  60. mlops/storage/adapters/memory_store.py +240 -0
  61. mlops/storage/adapters/redis_store.py +438 -0
  62. mlops/storage/factory.py +199 -0
  63. mlops/storage/interfaces/__init__.py +6 -0
  64. mlops/storage/interfaces/kv_store.py +118 -0
  65. mlops/storage/path_utils.py +38 -0
  66. mlops/templates/premier-league/charts/plot_metrics.js +70 -0
  67. mlops/templates/premier-league/charts/plot_metrics.py +145 -0
  68. mlops/templates/premier-league/charts/requirements.txt +6 -0
  69. mlops/templates/premier-league/configs/cluster_config.yaml +13 -0
  70. mlops/templates/premier-league/configs/project_config.yaml +207 -0
  71. mlops/templates/premier-league/data/England CSV.csv +12154 -0
  72. mlops/templates/premier-league/models/premier_league_model.py +638 -0
  73. mlops/templates/premier-league/requirements.txt +8 -0
  74. mlops/templates/sklearn-basic/README.md +22 -0
  75. mlops/templates/sklearn-basic/charts/plot_metrics.py +85 -0
  76. mlops/templates/sklearn-basic/charts/requirements.txt +3 -0
  77. mlops/templates/sklearn-basic/configs/project_config.yaml +64 -0
  78. mlops/templates/sklearn-basic/data/train.csv +14 -0
  79. mlops/templates/sklearn-basic/models/model.py +62 -0
  80. mlops/templates/sklearn-basic/requirements.txt +10 -0
  81. mlops/web/__init__.py +3 -0
  82. mlops/web/server.py +585 -0
  83. mlops/web/ui/index.html +52 -0
  84. mlops/web/ui/mlops-charts.js +357 -0
  85. mlops/web/ui/script.js +1244 -0
  86. mlops/web/ui/styles.css +248 -0
mlops/core/experiment_tracker.py
@@ -0,0 +1,81 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict, Optional, Protocol
+
+
+ class ExperimentTracker(Protocol):
+     """
+     Minimal interface for an experiment tracker.
+
+     The platform uses this for optional experiment tracking (params/metrics/artifacts/tags)
+     alongside the built-in KV-store metric logging in `mlops.core.step_system`.
+     """
+
+     def start_run(
+         self,
+         run_name: Optional[str] = None,
+         run_id: Optional[str] = None,
+         tags: Optional[Dict[str, Any]] = None,
+     ) -> Any:
+         """Start a new run (returns context manager or run handle)."""
+         ...
+
+     def end_run(self, status: Optional[str] = "FINISHED") -> None:
+         """End the current active run."""
+         ...
+
+
+ class NoOpExperimentTracker(ExperimentTracker):
+     """
+     Default tracker: prints a few lifecycle messages but intentionally ignores metrics
+     to avoid noisy logs. Safe fallback when no experiment tracking backend is configured.
+     """
+
+     def __init__(self, config: Optional[Dict[str, Any]] = None):
+         self.config = config if config else {}
+         self.run_active = False
+         self.current_run_id = None
+         print(f"[NoOpTracker] Initialized with config: {self.config}")
+
+     def start_run(
+         self,
+         run_name: Optional[str] = None,
+         run_id: Optional[str] = None,
+         tags: Optional[Dict[str, Any]] = None,
+     ) -> "NoOpExperimentTracker":
+         import uuid
+
+         if self.run_active:
+             print(
+                 f"[NoOpTracker] Warning: A run (ID: {self.current_run_id}) is already active. "
+                 f"Starting a new nested run is not fully supported by NoOpTracker; state will be overridden."
+             )
+
+         self.current_run_id = run_id if run_id else str(uuid.uuid4())
+         self.run_active = True
+
+         run_display_name = run_name if run_name else "default_run"
+         print(f"[NoOpTracker] Started run. Name: '{run_display_name}', ID: '{self.current_run_id}'")
+         if tags:
+             # Intentionally ignore tags in the NoOp tracker.
+             pass
+         return self  # Return self to allow use as a context manager
+
+     def end_run(self, status: Optional[str] = "FINISHED") -> None:
+         if self.run_active:
+             print(f"[NoOpTracker][RunID: {self.current_run_id}] Ended run with status: {status}")
+             self.run_active = False
+             self.current_run_id = None
+         else:
+             print("[NoOpTracker] No active run to end.")
+
+     def __enter__(self):
+         if not self.run_active:
+             self.start_run()
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         status = "FAILED" if exc_type else "FINISHED"
+         self.end_run(status=status)
+
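For context on how this file is intended to be used: `NoOpExperimentTracker.start_run` returns `self` and the class implements `__enter__`/`__exit__`, so callers can drive it as a context manager. A minimal sketch, assuming the module is importable as `mlops.core.experiment_tracker`; the config keys are made up for illustration:

    from mlops.core.experiment_tracker import NoOpExperimentTracker

    tracker = NoOpExperimentTracker(config={"backend": "noop"})  # hypothetical config dict

    # start_run returns the tracker itself, so the run is closed automatically:
    # __exit__ ends it with FINISHED, or FAILED if an exception escapes the block.
    with tracker.start_run(run_name="baseline"):
        pass  # train / evaluate here; the no-op tracker deliberately ignores metrics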
mlops/core/graph_types.py
@@ -0,0 +1,64 @@
+ from __future__ import annotations
+
+ from typing import Any, Optional
+ from dataclasses import dataclass
+ from enum import Enum
+
+
+ class NodeType(Enum):
+     """Types of nodes in the execution graph."""
+     PROCESS = "process"
+     STEP = "step"
+
+
+ @dataclass
+ class ProcessConfig:
+     """Configuration for a process node."""
+     name: str
+     depends_on: list[str] | None = None
+     parallel: bool = True
+     code_function: Optional[str] = None  # Function name to execute (if different from name)
+     process_type: str = "process"  # e.g., "process" or special types like "chart"
+
+     def __post_init__(self) -> None:
+         if self.depends_on is None:
+             self.depends_on = []
+
+
+ @dataclass
+ class StepConfig:
+     """Configuration for a step node."""
+     name: str
+     type: str = "step"
+     process: Optional[str] = None
+     inputs: list[str] | None = None
+     outputs: list[str] | None = None
+     loop_back_to: Optional[str] = None
+     condition: Optional[str] = None
+     parallel: bool = True
+
+
+ @dataclass
+ class NetworkXGraphConfig:
+     """Configuration for NetworkX-based graph execution."""
+     processes: list[ProcessConfig] | None = None
+     steps: list[StepConfig] | None = None
+     execution: dict[str, Any] | None = None
+
+     def __post_init__(self) -> None:
+         if self.processes is None:
+             self.processes = []
+         if self.steps is None:
+             self.steps = []
+         if self.execution is None:
+             self.execution = {}
+
+
+ @dataclass
+ class ExecutionResult:
+     """Result of executing a node (process or step)."""
+     name: str
+     result: Optional[dict[str, Any]] = None  # Dictionary containing step/process results
+     execution_time: float = 0.0
+     was_cached: bool = False
+     error: Optional[str] = None
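These dataclasses are plain containers; only `ProcessConfig` and `NetworkXGraphConfig` normalise their `None` defaults to empty collections in `__post_init__`. A short sketch of that behaviour (illustrative only, process name invented):

    from mlops.core.graph_types import ProcessConfig, NetworkXGraphConfig, ExecutionResult

    p = ProcessConfig(name="train")          # hypothetical process name
    assert p.depends_on == []                # None is normalised to [] in __post_init__

    cfg = NetworkXGraphConfig(processes=[p])
    assert cfg.steps == [] and cfg.execution == {}

    r = ExecutionResult(name="train", execution_time=1.5)
    assert r.error is None                   # ExecutionResult keeps its defaults as-is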
mlops/core/networkx_parser.py
@@ -0,0 +1,135 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict
+ import logging
+
+ from .graph_types import NetworkXGraphConfig, ProcessConfig, StepConfig
+ from .step_system import get_step_registry
+
+ logger = logging.getLogger(__name__)
+
+
+ class NetworkXPipelineParser:
+     """Parser for NetworkX-based pipeline configurations with loop support."""
+
+     def __init__(self) -> None:
+         self.processes: Dict[str, ProcessConfig] = {}
+         self.steps: Dict[str, StepConfig] = {}
+
+     def parse_pipeline_config(self, pipeline_config: Dict[str, Any]) -> NetworkXGraphConfig:
+         """Parse pipeline config and return NetworkX graph configuration."""
+         self.processes = {}
+         self.steps = {}
+
+         if "processes" in pipeline_config:
+             self._parse_networkx_format(pipeline_config)
+
+         config = self._generate_networkx_config(pipeline_config)
+
+         logger.info(f"Parsed pipeline with {len(self.processes)} processes and {len(self.steps)} steps")
+         return config
+
+     def _parse_networkx_format(self, pipeline_config: Dict[str, Any]) -> None:
+         """Parse NetworkX configuration format with DAG flow support."""
+
+         if "process_adjlist" in pipeline_config:
+             self._parse_process_adjlist(pipeline_config["process_adjlist"])
+
+         for process_data in pipeline_config.get("processes", []):
+             process_name = process_data["name"]
+             code_function = process_data.get("code_function")  # Extract code_function if provided
+             proc_type = process_data.get("type", "process")
+
+             if process_name in self.processes:
+                 process_config = self.processes[process_name]
+                 process_config.parallel = process_data.get("parallel", process_config.parallel)
+                 # Update code_function if provided in config
+                 if code_function:
+                     process_config.code_function = code_function
+                 # Update process_type if provided (important for chart processes)
+                 if proc_type:
+                     process_config.process_type = str(proc_type)
+                     logger.debug(f"Updated process '{process_name}' type to '{proc_type}'")
+             else:
+                 process_config = ProcessConfig(
+                     name=process_name,
+                     parallel=process_data.get("parallel", True),
+                     code_function=code_function,  # Set code_function from config
+                     process_type=str(proc_type)
+                 )
+
+             self.processes[process_name] = process_config
+
+         # Manual-step mode: steps are executed inside process runners, not scheduled as nodes.
+
+     def _parse_process_adjlist(self, adjlist_value: Any) -> None:
+         """Parse process-level DAG from adjacency list string or list of lines.
+
+         Adjacency list lines follow NetworkX semantics: first token is the source node,
+         subsequent tokens are target nodes. Lines may include comments after a '#'.
+         """
+         if isinstance(adjlist_value, str):
+             lines = adjlist_value.splitlines()
+         elif isinstance(adjlist_value, list):
+             lines = [str(x) for x in adjlist_value]
+         else:
+             raise ValueError("process_adjlist must be a string or list of lines")
+
+         for raw_line in lines:
+             line = raw_line.strip()
+             if not line:
+                 continue
+             # Remove comments after '#'
+             if "#" in line:
+                 line = line.split("#", 1)[0].strip()
+                 if not line:
+                     continue
+             parts = line.split()
+             if not parts:
+                 continue
+             source = parts[0]
+             targets = parts[1:]
+
+             # Ensure source process exists
+             if source not in self.processes:
+                 self.processes[source] = ProcessConfig(name=source, depends_on=[])
+
+             # Add or update targets with dependency on source
+             for target in targets:
+                 if target not in self.processes:
+                     self.processes[target] = ProcessConfig(name=target, depends_on=[source])
+                 else:
+                     deps = self.processes[target].depends_on or []
+                     if source not in deps:
+                         self.processes[target].depends_on = deps + [source]
+
+     def _discover_steps_from_registry(self) -> None:
+         """Manual-step mode: keep for compatibility; intentionally a no-op."""
+         try:
+             step_registry = get_step_registry()
+             registered_steps = step_registry.list_steps() if step_registry else []
+             logger.debug(f"Manual-step mode enabled; ignoring {len(registered_steps)} registered steps during parsing")
+         except Exception:
+             logger.debug("Manual-step mode enabled; no step registry available")
+
+     def _generate_networkx_config(self, pipeline_config: Dict[str, Any]) -> NetworkXGraphConfig:
+         """Generate NetworkX configuration from parsed processes and steps."""
+
+         execution_config = pipeline_config.get("execution", {})
+         execution = {
+             "parallel": execution_config.get("parallel", True),
+             "failure_mode": execution_config.get("failure_mode", "stop"),
+             "max_workers": execution_config.get("max_workers", 4)
+         }
+
+         return NetworkXGraphConfig(
+             processes=list(self.processes.values()),
+             steps=list(self.steps.values()),
+             execution=execution
+         )
+
+
+ def parse_networkx_pipeline_from_config(pipeline_config: Dict[str, Any]) -> NetworkXGraphConfig:
+     """Parse pipeline configuration into NetworkX format."""
+     parser = NetworkXPipelineParser()
+     return parser.parse_pipeline_config(pipeline_config)
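The `process_adjlist` docstring above describes NetworkX-style adjacency lines: the first token on a line is the source process, the remaining tokens are its targets, and anything after `#` is stripped. A hedged usage sketch; the process names and config values below are invented for illustration:

    from mlops.core.networkx_parser import parse_networkx_pipeline_from_config

    pipeline_config = {
        # "load" feeds both "train" and "report"; the trailing comment is removed by the parser.
        "process_adjlist": "load train report  # process-level DAG",
        "processes": [
            {"name": "train", "parallel": True},
            {"name": "report", "type": "chart"},
        ],
        "execution": {"max_workers": 2},
    }

    graph = parse_networkx_pipeline_from_config(pipeline_config)
    # graph.processes holds ProcessConfig entries for load, train and report,
    # where train and report each list "load" in depends_on.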
mlops/core/payload_spill.py
@@ -0,0 +1,278 @@
+ from __future__ import annotations
+
+ import io
+ import os
+ import uuid
+ import logging
+ from pathlib import Path
+ from typing import Any, Dict, Iterable, Tuple, Optional
+
+ try:
+     import numpy as np  # type: ignore
+ except Exception:  # pragma: no cover
+     np = None  # type: ignore
+
+ if np is not None:
+
+     class SpillArray(np.ndarray):  # type: ignore[misc]
+         """ndarray subclass with list-like truthiness semantics."""
+
+         def __new__(cls, input_array: "np.ndarray", origin_type: Optional[str] = None):
+             obj = np.asarray(input_array).view(cls)
+             obj._origin_type = origin_type  # type: ignore[attr-defined]
+             return obj
+
+         def __array_finalize__(self, obj):
+             if obj is None:
+                 return
+             self._origin_type = getattr(obj, "_origin_type", None)
+
+         def __bool__(self) -> bool:  # pragma: no cover - trivial behaviour
+             return bool(self.size)
+
+ else:  # pragma: no cover - numpy unavailable
+     SpillArray = None  # type: ignore
+
+ PAYLOAD_REF_KEY = "__mlops_payload_ref__"
+ PAYLOAD_META_VERSION = 2
+
+
+ def _get_logger() -> logging.Logger:
+     return logging.getLogger(__name__)
+
+
+ def _coerce_array(value: Any) -> Optional["np.ndarray"]:
+     """Best-effort conversion of supported payloads to a numpy array."""
+     if np is None:
+         return None
+     if isinstance(value, np.ndarray):
+         return value
+     if hasattr(value, "to_numpy"):
+         try:
+             arr = value.to_numpy()
+             return arr if isinstance(arr, np.ndarray) else None
+         except Exception:
+             return None
+     if isinstance(value, (list, tuple)):
+         try:
+             arr = np.asarray(value)
+             if arr.dtype == object:
+                 return None
+             return arr
+         except Exception:
+             return None
+     return None
+
+
+ def _estimate_bytes(value: Any) -> int:
+     if np is not None and isinstance(value, np.ndarray):
+         return int(value.nbytes)
+     if isinstance(value, (bytes, bytearray)):
+         return len(value)
+     if isinstance(value, (list, tuple)):
+         try:
+             if all(isinstance(v, (int, float)) for v in value):
+                 return len(value) * 8
+         except Exception:
+             return 0
+     return 0
+
+
+ def _should_spill(value: Any, threshold_bytes: int) -> bool:
+     arr = _coerce_array(value)
+     if arr is None:
+         return False
+     approx = _estimate_bytes(arr)
+     return approx >= threshold_bytes
+
+
+ def _serialize_array(arr: "np.ndarray", origin_type: Optional[str] = None) -> Tuple[bytes, Dict[str, Any]]:
+     buf = io.BytesIO()
+     np.savez_compressed(buf, data=arr)
+     meta = {
+         "shape": list(arr.shape),
+         "dtype": str(arr.dtype),
+         "approx_bytes": int(arr.nbytes),
+         "format": "npz",
+         "origin_type": origin_type or "ndarray",
+     }
+     return buf.getvalue(), meta
+
+
+ def _wrap_spilled_array(arr: "np.ndarray", origin_type: Optional[str]) -> Any:
+     if origin_type == "list":
+         return arr.tolist()
+     if origin_type == "tuple":
+         return tuple(arr.tolist())
+     if np is not None and isinstance(arr, np.ndarray) and SpillArray is not None:
+         try:
+             wrapped = arr.view(SpillArray)
+             wrapped._origin_type = origin_type  # type: ignore[attr-defined]
+             return wrapped
+         except Exception:
+             return arr
+     return arr
+
+
+ def _build_payload_filename(run_id: Optional[str], process_name: Optional[str], key_path: Iterable[str]) -> str:
+     segments = ["payloads"]
+     if run_id:
+         segments.append(run_id)
+     if process_name:
+         segments.append(process_name)
+     path_tuple = tuple(key_path)
+     if path_tuple:
+         path_segment = "-".join(part or "part" for part in path_tuple)
+     else:
+         path_segment = "data"
+     segments.append(path_segment)
+     segments.append(str(uuid.uuid4()))
+     filename = "/".join(s.strip("/").replace(" ", "_") for s in segments)
+     return f"{filename}.npz"
+
+
+ def _store_bytes(state_manager: Any, filename: str, payload: bytes) -> str:
+     """Persist bytes via object store when configured, otherwise on the local cache path."""
+     if getattr(state_manager, "object_store", None):
+         try:
+             build_uri = getattr(state_manager, "_build_object_uri", None)
+             if callable(build_uri):
+                 uri = build_uri(filename)
+             else:
+                 uri = state_manager.object_store.build_uri(filename)  # type: ignore[call-arg]
+             state_manager.object_store.put_bytes(uri, payload, content_type="application/octet-stream")  # type: ignore[attr-defined]
+             return uri
+         except Exception as e:
+             _get_logger().warning(f"[PayloadSpill] Failed to put bytes to object store ({filename}): {e}")
+     cache_dir = getattr(state_manager, "cache_dir", Path("."))
+     local_path = Path(cache_dir) / filename
+     local_path.parent.mkdir(parents=True, exist_ok=True)
+     with open(local_path, "wb") as fout:
+         fout.write(payload)
+     return str(local_path.resolve())
+
+
+ def spill_large_payloads(result: Dict[str, Any],
+                          state_manager: Any,
+                          run_id: Optional[str],
+                          process_name: Optional[str],
+                          threshold_bytes: int = 5_000_000) -> Dict[str, Any]:
+     """
+     Replace large numeric payloads inside a result dict with lightweight references.
+     """
+     if not isinstance(result, dict) or state_manager is None or np is None:
+         return result
+
+     def _process(key: str, value: Any, path: Tuple[str, ...]) -> Any:
+         if isinstance(value, dict):
+             return {k: _process(k, v, path + (k,)) for k, v in value.items()}
+         if isinstance(value, list):
+             if _should_spill(value, threshold_bytes):
+                 return _spill_value(value, path)
+             try:
+                 return [_process(str(idx), item, path + (str(idx),)) for idx, item in enumerate(value)]
+             except Exception:
+                 pass
+         if isinstance(value, tuple):
+             if _should_spill(value, threshold_bytes):
+                 return _spill_value(list(value), path)
+             return tuple(_process(str(idx), item, path + (str(idx),)) for idx, item in enumerate(value))
+         if _should_spill(value, threshold_bytes):
+             return _spill_value(value, path)
+         return value
+
+     def _spill_value(payload_value: Any, key_path: Tuple[str, ...]) -> Dict[str, Any]:
+         arr = _coerce_array(payload_value)
+         if arr is None:
+             return payload_value  # type: ignore[return-value]
+         if isinstance(payload_value, list):
+             origin_type = "list"
+         elif isinstance(payload_value, tuple):
+             origin_type = "tuple"
+         elif np is not None and isinstance(payload_value, np.ndarray):
+             origin_type = "ndarray"
+         else:
+             origin_type = type(payload_value).__name__
+         data_bytes, meta = _serialize_array(arr, origin_type=origin_type)
+         path_tuple = key_path or ("payload",)
+         filename = _build_payload_filename(run_id, process_name, path_tuple)
+         uri = _store_bytes(state_manager, filename, data_bytes)
+         ref = {
+             PAYLOAD_REF_KEY: True,
+             "uri": uri,
+             "meta": meta,
+             "version": PAYLOAD_META_VERSION,
+             "key_path": "/".join(path_tuple),
+             "process": process_name,
+             "run_id": run_id,
+         }
+         try:
+             _get_logger().info(f"[PayloadSpill] Spilled payload for {process_name}:{ref['key_path']} -> {uri} ({meta['approx_bytes']} bytes)")
+         except Exception:
+             pass
+         return ref
+
+     new_result = {}
+     for k, v in result.items():
+         key = str(k)
+         new_result[key] = _process(key, v, (key,))
+     return new_result
+
+
+ def hydrate_payload_refs(data: Any, state_manager: Any) -> Any:
+     """Replace payload reference dicts with their hydrated numpy arrays."""
+     if state_manager is None or np is None:
+         return data
+     if isinstance(data, dict):
+         if data.get(PAYLOAD_REF_KEY):
+             return _load_payload(data, state_manager)
+         return {k: hydrate_payload_refs(v, state_manager) for k, v in data.items()}
+     if isinstance(data, list):
+         return [hydrate_payload_refs(v, state_manager) for v in data]
+     if isinstance(data, tuple):
+         return tuple(hydrate_payload_refs(v, state_manager) for v in data)
+     return data
+
+
+ def _load_payload(ref: Dict[str, Any], state_manager: Any) -> Any:
+     uri = ref.get("uri")
+     if not uri:
+         return ref
+     try:
+         payload_bytes = None
+         if str(uri).startswith("gs://") and getattr(state_manager, "object_store", None):
+             payload_bytes = state_manager.object_store.get_bytes(uri)  # type: ignore[attr-defined]
+         else:
+             path = Path(uri)
+             if not path.is_absolute() and getattr(state_manager, "cache_dir", None):
+                 path = Path(state_manager.cache_dir) / path
+             with open(path, "rb") as fin:
+                 payload_bytes = fin.read()
+         if payload_bytes is None:
+             raise RuntimeError("No payload bytes resolved")
+         with np.load(io.BytesIO(payload_bytes), allow_pickle=False) as npz:
+             arr = npz["data"]
+         meta = ref.get("meta") if isinstance(ref.get("meta"), dict) else {}
+         origin_type = meta.get("origin_type") if isinstance(meta, dict) else None
+         return _wrap_spilled_array(arr, origin_type)
+     except Exception as e:
+         _get_logger().warning(f"[PayloadSpill] Failed to hydrate payload {uri}: {e}")
+         return ref
+
+
+ if __name__ == "__main__":  # pragma: no cover - developer self-test hook
+     logging.basicConfig(level=logging.INFO)
+     if np is None:
+         print("NumPy is not available; skipping payload spill self-test.")
+     else:
+         class _TempStateManager:
+             def __init__(self) -> None:
+                 self.cache_dir = Path(os.environ.get("PAYLOAD_SPILL_TMP", "/tmp/payload_spill_test"))
+                 self.object_store = None
+
+         sm = _TempStateManager()
+         sample = {"X_test": np.random.rand(2000, 200)}  # ~3.2 MB -> forces spill
+         spilled = spill_large_payloads(sample, sm, run_id="selftest", process_name="demo", threshold_bytes=1_000_000)
+         print(spilled)
+
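Putting the two halves of this module together: `spill_large_payloads` swaps large numeric arrays for reference dicts keyed by `__mlops_payload_ref__`, and `hydrate_payload_refs` reverses the substitution. A minimal round-trip sketch, assuming a POSIX-style temp directory and a made-up stand-in for the real state manager (only the `cache_dir` and `object_store` attributes are read by these helpers):

    import numpy as np
    from pathlib import Path
    from mlops.core.payload_spill import (
        PAYLOAD_REF_KEY, hydrate_payload_refs, spill_large_payloads,
    )

    class FakeStateManager:                      # hypothetical stand-in, not part of the package
        cache_dir = Path("/tmp/spill_demo")      # spilled .npz files land under this directory
        object_store = None                      # no object store, so bytes go to the local cache

    result = {"predictions": np.zeros((1500, 1000))}   # ~12 MB of float64, above the 5 MB default
    spilled = spill_large_payloads(result, FakeStateManager(), run_id="run1", process_name="predict")
    assert spilled["predictions"][PAYLOAD_REF_KEY]     # the array was replaced by a reference dict

    restored = hydrate_payload_refs(spilled, FakeStateManager())
    assert restored["predictions"].shape == (1500, 1000)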