expops 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- expops-0.1.3.dist-info/METADATA +826 -0
- expops-0.1.3.dist-info/RECORD +86 -0
- expops-0.1.3.dist-info/WHEEL +5 -0
- expops-0.1.3.dist-info/entry_points.txt +3 -0
- expops-0.1.3.dist-info/licenses/LICENSE +674 -0
- expops-0.1.3.dist-info/top_level.txt +1 -0
- mlops/__init__.py +0 -0
- mlops/__main__.py +11 -0
- mlops/_version.py +34 -0
- mlops/adapters/__init__.py +12 -0
- mlops/adapters/base.py +86 -0
- mlops/adapters/config_schema.py +89 -0
- mlops/adapters/custom/__init__.py +3 -0
- mlops/adapters/custom/custom_adapter.py +447 -0
- mlops/adapters/plugin_manager.py +113 -0
- mlops/adapters/sklearn/__init__.py +3 -0
- mlops/adapters/sklearn/adapter.py +94 -0
- mlops/cluster/__init__.py +3 -0
- mlops/cluster/controller.py +496 -0
- mlops/cluster/process_runner.py +91 -0
- mlops/cluster/providers.py +258 -0
- mlops/core/__init__.py +95 -0
- mlops/core/custom_model_base.py +38 -0
- mlops/core/dask_networkx_executor.py +1265 -0
- mlops/core/executor_worker.py +1239 -0
- mlops/core/experiment_tracker.py +81 -0
- mlops/core/graph_types.py +64 -0
- mlops/core/networkx_parser.py +135 -0
- mlops/core/payload_spill.py +278 -0
- mlops/core/pipeline_utils.py +162 -0
- mlops/core/process_hashing.py +216 -0
- mlops/core/step_state_manager.py +1298 -0
- mlops/core/step_system.py +956 -0
- mlops/core/workspace.py +99 -0
- mlops/environment/__init__.py +10 -0
- mlops/environment/base.py +43 -0
- mlops/environment/conda_manager.py +307 -0
- mlops/environment/factory.py +70 -0
- mlops/environment/pyenv_manager.py +146 -0
- mlops/environment/setup_env.py +31 -0
- mlops/environment/system_manager.py +66 -0
- mlops/environment/utils.py +105 -0
- mlops/environment/venv_manager.py +134 -0
- mlops/main.py +527 -0
- mlops/managers/project_manager.py +400 -0
- mlops/managers/reproducibility_manager.py +575 -0
- mlops/platform.py +996 -0
- mlops/reporting/__init__.py +16 -0
- mlops/reporting/context.py +187 -0
- mlops/reporting/entrypoint.py +292 -0
- mlops/reporting/kv_utils.py +77 -0
- mlops/reporting/registry.py +50 -0
- mlops/runtime/__init__.py +9 -0
- mlops/runtime/context.py +34 -0
- mlops/runtime/env_export.py +113 -0
- mlops/storage/__init__.py +12 -0
- mlops/storage/adapters/__init__.py +9 -0
- mlops/storage/adapters/gcp_kv_store.py +778 -0
- mlops/storage/adapters/gcs_object_store.py +96 -0
- mlops/storage/adapters/memory_store.py +240 -0
- mlops/storage/adapters/redis_store.py +438 -0
- mlops/storage/factory.py +199 -0
- mlops/storage/interfaces/__init__.py +6 -0
- mlops/storage/interfaces/kv_store.py +118 -0
- mlops/storage/path_utils.py +38 -0
- mlops/templates/premier-league/charts/plot_metrics.js +70 -0
- mlops/templates/premier-league/charts/plot_metrics.py +145 -0
- mlops/templates/premier-league/charts/requirements.txt +6 -0
- mlops/templates/premier-league/configs/cluster_config.yaml +13 -0
- mlops/templates/premier-league/configs/project_config.yaml +207 -0
- mlops/templates/premier-league/data/England CSV.csv +12154 -0
- mlops/templates/premier-league/models/premier_league_model.py +638 -0
- mlops/templates/premier-league/requirements.txt +8 -0
- mlops/templates/sklearn-basic/README.md +22 -0
- mlops/templates/sklearn-basic/charts/plot_metrics.py +85 -0
- mlops/templates/sklearn-basic/charts/requirements.txt +3 -0
- mlops/templates/sklearn-basic/configs/project_config.yaml +64 -0
- mlops/templates/sklearn-basic/data/train.csv +14 -0
- mlops/templates/sklearn-basic/models/model.py +62 -0
- mlops/templates/sklearn-basic/requirements.txt +10 -0
- mlops/web/__init__.py +3 -0
- mlops/web/server.py +585 -0
- mlops/web/ui/index.html +52 -0
- mlops/web/ui/mlops-charts.js +357 -0
- mlops/web/ui/script.js +1244 -0
- mlops/web/ui/styles.css +248 -0
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Optional, Protocol
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ExperimentTracker(Protocol):
    """
    Structural interface every experiment tracker is expected to satisfy.

    Besides the built-in KV-store metric logging in `mlops.core.step_system`,
    the platform can optionally hand run lifecycle information
    (params/metrics/artifacts/tags) to an object implementing this protocol.
    """

    def start_run(
        self,
        run_name: Optional[str] = None,
        run_id: Optional[str] = None,
        tags: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Begin a run; implementations return a run handle or a context manager.

        Args:
            run_name: Optional human-readable name for the run.
            run_id: Optional explicit identifier for the run.
            tags: Optional tags to attach to the run.
        """
        ...

    def end_run(self, status: Optional[str] = "FINISHED") -> None:
        """Close the currently active run, recording ``status``."""
        ...
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class NoOpExperimentTracker(ExperimentTracker):
    """
    Default tracker used when no experiment tracking backend is configured.

    It prints a few lifecycle messages (init / start / end) and otherwise records
    nothing: tags passed to ``start_run`` are accepted but ignored.

    Instances also work as a context manager. Entering starts a run if none is
    active. Exiting ends the run with status ``"FAILED"`` when an exception
    escaped the ``with`` body and ``"FINISHED"`` otherwise. The exception is not
    suppressed.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        # Any falsy config (None, {}) is normalized to a fresh empty dict.
        self.config: Dict[str, Any] = config or {}
        self.run_active: bool = False
        self.current_run_id: Optional[str] = None
        print(f"[NoOpTracker] Initialized with config: {self.config}")

    def start_run(
        self,
        run_name: Optional[str] = None,
        run_id: Optional[str] = None,
        tags: Optional[Dict[str, Any]] = None,
    ) -> "NoOpExperimentTracker":
        """Mark a run as active and return ``self`` so it can be used in ``with``.

        Args:
            run_name: Display name used only in the log line; ``"default_run"`` when falsy.
            run_id: Explicit run id; a random UUID4 string is generated when falsy.
            tags: Accepted for interface compatibility and intentionally ignored.

        Runs do not nest. Starting a run while another is active prints a
        warning and overwrites the previous run's state.
        """
        import uuid

        if self.run_active:
            print(
                f"[NoOpTracker] Warning: A run (ID: {self.current_run_id}) is already active. "
                "Starting a new nested run is not fully supported by NoOpTracker; state will be overridden."
            )

        self.current_run_id = run_id or str(uuid.uuid4())
        self.run_active = True

        run_display_name = run_name or "default_run"
        print(f"[NoOpTracker] Started run. Name: '{run_display_name}', ID: '{self.current_run_id}'")
        # `tags` is deliberately not recorded by the no-op tracker.
        return self

    def end_run(self, status: Optional[str] = "FINISHED") -> None:
        """End the active run (printing ``status``), or report that none is active."""
        if self.run_active:
            print(f"[NoOpTracker][RunID: {self.current_run_id}] Ended run with status: {status}")
            self.run_active = False
            self.current_run_id = None
        else:
            print("[NoOpTracker] No active run to end.")

    def __enter__(self) -> "NoOpExperimentTracker":
        # Reuse a run already started via start_run(); otherwise start a default one.
        if not self.run_active:
            self.start_run()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Returning None lets any in-flight exception propagate.
        status = "FAILED" if exc_type else "FINISHED"
        self.end_run(status=status)
|
|
80
|
+
|
|
81
|
+
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class NodeType(Enum):
    """Kinds of node that can appear in the execution graph.

    Each member's value is the lowercase string name of that node kind.
    """

    PROCESS = "process"  # process-level node
    STEP = "step"  # step-level node
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class ProcessConfig:
    """Settings describing a single process node in the pipeline graph.

    Attributes (in constructor order):
        name: Process identifier.
        depends_on: Names of processes this one depends on; ``None`` is
            normalized to an empty list after construction.
        parallel: Parallel-execution flag (defaults to True).
        code_function: Name of the function to execute, when it differs from ``name``.
        process_type: Process kind, e.g. ``"process"`` or a special type such as ``"chart"``.
    """

    name: str
    depends_on: list[str] | None = None
    parallel: bool = True
    code_function: Optional[str] = None
    process_type: str = "process"

    def __post_init__(self) -> None:
        # Swap the None sentinel for a per-instance list (avoids a shared mutable default).
        self.depends_on = [] if self.depends_on is None else self.depends_on
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class StepConfig:
    """Configuration for a step node.

    Attributes (in constructor order):
        name: Step identifier.
        type: Node kind tag, ``"step"`` by default.
        process: Associated process name, if any.
        inputs: Input names; ``None`` is normalized to ``[]`` after construction.
        outputs: Output names; ``None`` is normalized to ``[]`` after construction.
        loop_back_to: Optional loop-back target (interpreted by the executor, not here).
        condition: Optional condition string (interpreted elsewhere, not here).
        parallel: Parallel-execution flag (defaults to True).
    """

    name: str
    type: str = "step"
    process: Optional[str] = None
    inputs: list[str] | None = None
    outputs: list[str] | None = None
    loop_back_to: Optional[str] = None
    condition: Optional[str] = None
    parallel: bool = True

    def __post_init__(self) -> None:
        # Mirror ProcessConfig / NetworkXGraphConfig: replace the None sentinel
        # with a fresh per-instance list so consumers can iterate directly.
        if self.inputs is None:
            self.inputs = []
        if self.outputs is None:
            self.outputs = []
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class NetworkXGraphConfig:
    """Container for everything the NetworkX-based executor is configured with.

    After construction, ``None`` in any field is replaced by an empty container.
    Consumers can then iterate ``processes``/``steps`` and read ``execution``
    without None checks.
    """

    processes: list[ProcessConfig] | None = None
    steps: list[StepConfig] | None = None
    execution: dict[str, Any] | None = None

    def __post_init__(self) -> None:
        # Each fallback container is created per instance so configs never share state.
        self.processes = self.processes if self.processes is not None else []
        self.steps = self.steps if self.steps is not None else []
        self.execution = self.execution if self.execution is not None else {}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class ExecutionResult:
    """Outcome record for one executed graph node (a process or a step).

    Attributes (in constructor order):
        name: Name of the node that ran.
        result: Dictionary of values produced by the node, if any.
        execution_time: Recorded run time, 0.0 by default (unit presumably
            seconds; confirm against the executor).
        was_cached: Flag marking a result that came from cache.
        error: Error text, if any (``None`` by default).
    """

    name: str
    result: dict[str, Any] | None = None
    execution_time: float = 0.0
    was_cached: bool = False
    error: str | None = None
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
from .graph_types import NetworkXGraphConfig, ProcessConfig, StepConfig
|
|
7
|
+
from .step_system import get_step_registry
|
|
8
|
+
|
|
9
|
+
# Module-level logger named after this module; used by the parser below.
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class NetworkXPipelineParser:
    """Parser that turns a pipeline config mapping into a NetworkXGraphConfig.

    Process dependencies come from the optional ``process_adjlist`` entry, and
    per-process settings come from the ``processes`` list. The parser runs in
    manual-step mode: steps are executed inside process runners rather than
    scheduled as graph nodes, so ``self.steps`` stays empty.
    """

    def __init__(self) -> None:
        self.processes: Dict[str, ProcessConfig] = {}
        self.steps: Dict[str, StepConfig] = {}

    def parse_pipeline_config(self, pipeline_config: Dict[str, Any]) -> NetworkXGraphConfig:
        """Parse ``pipeline_config`` and return the resulting graph configuration.

        Parser state is reset on every call, so one instance can be reused.
        Processes are only parsed when the config contains a ``processes`` key.
        """
        self.processes = {}
        self.steps = {}

        if "processes" in pipeline_config:
            self._parse_networkx_format(pipeline_config)

        config = self._generate_networkx_config(pipeline_config)

        logger.info("Parsed pipeline with %d processes and %d steps", len(self.processes), len(self.steps))
        return config

    def _parse_networkx_format(self, pipeline_config: Dict[str, Any]) -> None:
        """Populate ``self.processes`` from ``process_adjlist`` and ``processes``.

        The adjacency list, if present, is parsed first and sets ``depends_on``.
        Entries in ``processes`` then create processes, or update ``parallel``,
        ``code_function`` and ``process_type`` on existing ones.
        """
        if "process_adjlist" in pipeline_config:
            self._parse_process_adjlist(pipeline_config["process_adjlist"])

        # A key present with a null value (e.g. an empty `processes:` in YAML) is treated as empty.
        for process_data in pipeline_config.get("processes") or []:
            process_name = process_data["name"]
            code_function = process_data.get("code_function")
            raw_type = process_data.get("type", "process")
            # A null/empty `type` means a regular process. Previously a null was
            # stringified into the bogus type "None" for newly created processes.
            proc_type = str(raw_type) if raw_type else "process"

            existing = self.processes.get(process_name)
            if existing is not None:
                existing.parallel = process_data.get("parallel", existing.parallel)
                if code_function:
                    existing.code_function = code_function
                # Only override the type when one is given (matters for chart processes).
                if raw_type:
                    existing.process_type = proc_type
                    logger.debug("Updated process '%s' type to '%s'", process_name, raw_type)
            else:
                self.processes[process_name] = ProcessConfig(
                    name=process_name,
                    parallel=process_data.get("parallel", True),
                    code_function=code_function,
                    process_type=proc_type,
                )

        # Manual-step mode: steps are executed inside process runners, not scheduled as nodes.

    def _parse_process_adjlist(self, adjlist_value: Any) -> None:
        """Parse a process-level DAG from an adjacency-list string or list of lines.

        Lines follow NetworkX adjlist semantics. The first whitespace-separated
        token is the source node, the remaining tokens are its targets, and
        anything after ``#`` is a comment. Every target gains a dependency on
        its source. Processes named here are created with default settings if
        they do not exist yet.

        Raises:
            ValueError: If ``adjlist_value`` is neither a string nor a list.
        """
        if isinstance(adjlist_value, str):
            lines = adjlist_value.splitlines()
        elif isinstance(adjlist_value, list):
            lines = [str(x) for x in adjlist_value]
        else:
            raise ValueError("process_adjlist must be a string or list of lines")

        for raw_line in lines:
            # Drop any trailing comment, then tokenize on whitespace.
            parts = raw_line.split("#", 1)[0].split()
            if not parts:
                continue
            source, *targets = parts

            if source not in self.processes:
                self.processes[source] = ProcessConfig(name=source, depends_on=[])

            for target in targets:
                target_cfg = self.processes.get(target)
                if target_cfg is None:
                    self.processes[target] = ProcessConfig(name=target, depends_on=[source])
                else:
                    deps = target_cfg.depends_on or []
                    if source not in deps:
                        target_cfg.depends_on = deps + [source]

    def _discover_steps_from_registry(self) -> None:
        """Manual-step mode: kept for compatibility; intentionally a no-op apart from logging."""
        try:
            step_registry = get_step_registry()
            registered_steps = step_registry.list_steps() if step_registry else []
            logger.debug(
                "Manual-step mode enabled; ignoring %d registered steps during parsing",
                len(registered_steps),
            )
        except Exception:
            logger.debug("Manual-step mode enabled; no step registry available")

    def _generate_networkx_config(self, pipeline_config: Dict[str, Any]) -> NetworkXGraphConfig:
        """Build the graph config from parsed processes/steps plus execution settings.

        Defaults: ``parallel=True``, ``failure_mode="stop"``, ``max_workers=4``.
        """
        # A null `execution` value (e.g. an empty `execution:` key) falls back to the defaults.
        execution_config = pipeline_config.get("execution") or {}
        execution = {
            "parallel": execution_config.get("parallel", True),
            "failure_mode": execution_config.get("failure_mode", "stop"),
            "max_workers": execution_config.get("max_workers", 4),
        }

        return NetworkXGraphConfig(
            processes=list(self.processes.values()),
            steps=list(self.steps.values()),
            execution=execution,
        )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def parse_networkx_pipeline_from_config(pipeline_config: Dict[str, Any]) -> NetworkXGraphConfig:
    """Parse ``pipeline_config`` with a fresh parser instance.

    Returns:
        The graph configuration produced by
        ``NetworkXPipelineParser.parse_pipeline_config``.
    """
    return NetworkXPipelineParser().parse_pipeline_config(pipeline_config)
|
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
import os
|
|
5
|
+
import uuid
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, Iterable, Tuple, Optional
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import numpy as np # type: ignore
|
|
12
|
+
except Exception: # pragma: no cover
|
|
13
|
+
np = None # type: ignore
|
|
14
|
+
|
|
15
|
+
if np is None:  # pragma: no cover - numpy unavailable
    SpillArray = None  # type: ignore
else:

    class SpillArray(np.ndarray):  # type: ignore[misc]
        """NumPy array subclass whose truthiness follows list semantics.

        ``bool(arr)`` is True exactly when the array has at least one element.
        A plain multi-element ``ndarray`` would raise instead. The optional
        ``_origin_type`` tag is propagated to views and slices.
        """

        def __new__(cls, input_array: "np.ndarray", origin_type: Optional[str] = None):
            instance = np.asarray(input_array).view(cls)
            instance._origin_type = origin_type  # type: ignore[attr-defined]
            return instance

        def __array_finalize__(self, obj):
            # numpy passes obj=None for explicit-constructor creation; otherwise inherit the tag.
            if obj is not None:
                self._origin_type = getattr(obj, "_origin_type", None)

        def __bool__(self) -> bool:  # pragma: no cover - trivial behaviour
            return self.size != 0
|
|
35
|
+
|
|
36
|
+
# Marker key identifying a dict as a spilled-payload reference; written by
# `spill_large_payloads` and checked by `hydrate_payload_refs` in this module.
PAYLOAD_REF_KEY = "__mlops_payload_ref__"
# Version number stored under "version" in every reference dict.
PAYLOAD_META_VERSION = 2
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get_logger() -> logging.Logger:
|
|
41
|
+
return logging.getLogger(__name__)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _coerce_array(value: Any) -> Optional["np.ndarray"]:
|
|
45
|
+
"""Best-effort conversion of supported payloads to a numpy array."""
|
|
46
|
+
if np is None:
|
|
47
|
+
return None
|
|
48
|
+
if isinstance(value, np.ndarray):
|
|
49
|
+
return value
|
|
50
|
+
if hasattr(value, "to_numpy"):
|
|
51
|
+
try:
|
|
52
|
+
arr = value.to_numpy()
|
|
53
|
+
return arr if isinstance(arr, np.ndarray) else None
|
|
54
|
+
except Exception:
|
|
55
|
+
return None
|
|
56
|
+
if isinstance(value, (list, tuple)):
|
|
57
|
+
try:
|
|
58
|
+
arr = np.asarray(value)
|
|
59
|
+
if arr.dtype == object:
|
|
60
|
+
return None
|
|
61
|
+
return arr
|
|
62
|
+
except Exception:
|
|
63
|
+
return None
|
|
64
|
+
return None
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _estimate_bytes(value: Any) -> int:
|
|
68
|
+
if np is not None and isinstance(value, np.ndarray):
|
|
69
|
+
return int(value.nbytes)
|
|
70
|
+
if isinstance(value, (bytes, bytearray)):
|
|
71
|
+
return len(value)
|
|
72
|
+
if isinstance(value, (list, tuple)):
|
|
73
|
+
try:
|
|
74
|
+
if all(isinstance(v, (int, float)) for v in value):
|
|
75
|
+
return len(value) * 8
|
|
76
|
+
except Exception:
|
|
77
|
+
return 0
|
|
78
|
+
return 0
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _should_spill(value: Any, threshold_bytes: int) -> bool:
    """Decide whether ``value`` is large enough to be written out of band.

    Only values that ``_coerce_array`` can convert qualify. They spill when
    their estimated size reaches ``threshold_bytes`` (inclusive).
    """
    candidate = _coerce_array(value)
    return candidate is not None and _estimate_bytes(candidate) >= threshold_bytes
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _serialize_array(arr: "np.ndarray", origin_type: Optional[str] = None) -> Tuple[bytes, Dict[str, Any]]:
|
|
90
|
+
buf = io.BytesIO()
|
|
91
|
+
np.savez_compressed(buf, data=arr)
|
|
92
|
+
meta = {
|
|
93
|
+
"shape": list(arr.shape),
|
|
94
|
+
"dtype": str(arr.dtype),
|
|
95
|
+
"approx_bytes": int(arr.nbytes),
|
|
96
|
+
"format": "npz",
|
|
97
|
+
"origin_type": origin_type or "ndarray",
|
|
98
|
+
}
|
|
99
|
+
return buf.getvalue(), meta
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _wrap_spilled_array(arr: "np.ndarray", origin_type: Optional[str]) -> Any:
|
|
103
|
+
if origin_type == "list":
|
|
104
|
+
return arr.tolist()
|
|
105
|
+
if origin_type == "tuple":
|
|
106
|
+
return tuple(arr.tolist())
|
|
107
|
+
if np is not None and isinstance(arr, np.ndarray) and SpillArray is not None:
|
|
108
|
+
try:
|
|
109
|
+
wrapped = arr.view(SpillArray)
|
|
110
|
+
wrapped._origin_type = origin_type # type: ignore[attr-defined]
|
|
111
|
+
return wrapped
|
|
112
|
+
except Exception:
|
|
113
|
+
return arr
|
|
114
|
+
return arr
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _build_payload_filename(run_id: Optional[str], process_name: Optional[str], key_path: Iterable[str]) -> str:
|
|
118
|
+
segments = ["payloads"]
|
|
119
|
+
if run_id:
|
|
120
|
+
segments.append(run_id)
|
|
121
|
+
if process_name:
|
|
122
|
+
segments.append(process_name)
|
|
123
|
+
path_tuple = tuple(key_path)
|
|
124
|
+
if path_tuple:
|
|
125
|
+
path_segment = "-".join(part or "part" for part in path_tuple)
|
|
126
|
+
else:
|
|
127
|
+
path_segment = "data"
|
|
128
|
+
segments.append(path_segment)
|
|
129
|
+
segments.append(str(uuid.uuid4()))
|
|
130
|
+
filename = "/".join(s.strip("/").replace(" ", "_") for s in segments)
|
|
131
|
+
return f"{filename}.npz"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _store_bytes(state_manager: Any, filename: str, payload: bytes) -> str:
|
|
135
|
+
"""Persist bytes via object store when configured, otherwise on the local cache path."""
|
|
136
|
+
if getattr(state_manager, "object_store", None):
|
|
137
|
+
try:
|
|
138
|
+
build_uri = getattr(state_manager, "_build_object_uri", None)
|
|
139
|
+
if callable(build_uri):
|
|
140
|
+
uri = build_uri(filename)
|
|
141
|
+
else:
|
|
142
|
+
uri = state_manager.object_store.build_uri(filename) # type: ignore[call-arg]
|
|
143
|
+
state_manager.object_store.put_bytes(uri, payload, content_type="application/octet-stream") # type: ignore[attr-defined]
|
|
144
|
+
return uri
|
|
145
|
+
except Exception as e:
|
|
146
|
+
_get_logger().warning(f"[PayloadSpill] Failed to put bytes to object store ({filename}): {e}")
|
|
147
|
+
cache_dir = getattr(state_manager, "cache_dir", Path("."))
|
|
148
|
+
local_path = Path(cache_dir) / filename
|
|
149
|
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
|
150
|
+
with open(local_path, "wb") as fout:
|
|
151
|
+
fout.write(payload)
|
|
152
|
+
return str(local_path.resolve())
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def spill_large_payloads(result: Dict[str, Any],
                         state_manager: Any,
                         run_id: Optional[str],
                         process_name: Optional[str],
                         threshold_bytes: int = 5_000_000) -> Dict[str, Any]:
    """
    Replace large numeric payloads inside a result dict with lightweight references.

    The result is walked recursively through dicts, lists and tuples. A value
    is spilled when ``_coerce_array`` can turn it into an array whose estimated
    size is at least ``threshold_bytes``. Spilling serializes it to ``.npz`` and
    stores it via ``_store_bytes``. The value is then replaced by a reference
    dict containing:
    - ``PAYLOAD_REF_KEY`` set to True;
    - the storage ``uri``;
    - the serialization ``meta``;
    - ``version``;
    - the ``/``-joined ``key_path``;
    - ``process`` and ``run_id``.
    ``hydrate_payload_refs`` reverses the substitution.

    Top-level keys of the returned dict are converted to ``str``. Nested
    containers keep their original keys. The input is returned unchanged when it
    is not a dict, when ``state_manager`` is None, or when numpy is missing.
    """
    if not isinstance(result, dict) or state_manager is None or np is None:
        return result

    def _spill_value(payload_value: Any, key_path: Tuple[str, ...]) -> Any:
        # Store one payload and return its reference dict; values that cannot be
        # converted to an array are returned unchanged.
        arr = _coerce_array(payload_value)
        if arr is None:
            return payload_value
        if isinstance(payload_value, list):
            origin_type = "list"
        elif isinstance(payload_value, tuple):
            origin_type = "tuple"
        elif isinstance(payload_value, np.ndarray):
            origin_type = "ndarray"
        else:
            origin_type = type(payload_value).__name__
        data_bytes, meta = _serialize_array(arr, origin_type=origin_type)
        path_tuple = key_path or ("payload",)
        filename = _build_payload_filename(run_id, process_name, path_tuple)
        uri = _store_bytes(state_manager, filename, data_bytes)
        ref = {
            PAYLOAD_REF_KEY: True,
            "uri": uri,
            "meta": meta,
            "version": PAYLOAD_META_VERSION,
            "key_path": "/".join(path_tuple),
            "process": process_name,
            "run_id": run_id,
        }
        _get_logger().info(
            "[PayloadSpill] Spilled payload for %s:%s -> %s (%s bytes)",
            process_name, ref["key_path"], uri, meta["approx_bytes"],
        )
        return ref

    def _process(value: Any, path: Tuple[str, ...]) -> Any:
        # Recursively rebuild `value`, spilling large leaves; `path` is all-str.
        if isinstance(value, dict):
            # Output keeps the original keys, but the path uses str(k) so non-str
            # keys (e.g. ints) don't break the "-"/"/" joins used for naming.
            return {k: _process(v, path + (str(k),)) for k, v in value.items()}
        if isinstance(value, list):
            if _should_spill(value, threshold_bytes):
                return _spill_value(value, path)
            try:
                return [_process(item, path + (str(idx),)) for idx, item in enumerate(value)]
            except Exception as e:
                # Best effort: keep the list as-is if spilling one of its items fails.
                _get_logger().warning("[PayloadSpill] Failed to process list at %s: %s", "/".join(path), e)
                return value
        if isinstance(value, tuple):
            if _should_spill(value, threshold_bytes):
                # Pass the tuple itself so origin_type is "tuple" and hydration
                # rebuilds a tuple (converting to list first recorded "list").
                return _spill_value(value, path)
            return tuple(_process(item, path + (str(idx),)) for idx, item in enumerate(value))
        if _should_spill(value, threshold_bytes):
            return _spill_value(value, path)
        return value

    return {str(k): _process(v, (str(k),)) for k, v in result.items()}
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def hydrate_payload_refs(data: Any, state_manager: Any) -> Any:
    """Recursively swap spilled-payload reference dicts for their loaded data.

    Dicts with a truthy ``PAYLOAD_REF_KEY`` entry are resolved through
    ``_load_payload``. Other dicts, lists and tuples are rebuilt with hydrated
    children, and anything else is returned as-is. Without a ``state_manager``
    or without numpy, ``data`` is returned untouched.
    """
    if state_manager is None or np is None:
        return data

    def _recurse(item: Any) -> Any:
        return hydrate_payload_refs(item, state_manager)

    if isinstance(data, dict):
        if not data.get(PAYLOAD_REF_KEY):
            return {key: _recurse(val) for key, val in data.items()}
        return _load_payload(data, state_manager)
    if isinstance(data, list):
        return list(map(_recurse, data))
    if isinstance(data, tuple):
        return tuple(map(_recurse, data))
    return data
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _load_payload(ref: Dict[str, Any], state_manager: Any) -> Any:
|
|
238
|
+
uri = ref.get("uri")
|
|
239
|
+
if not uri:
|
|
240
|
+
return ref
|
|
241
|
+
try:
|
|
242
|
+
payload_bytes = None
|
|
243
|
+
if str(uri).startswith("gs://") and getattr(state_manager, "object_store", None):
|
|
244
|
+
payload_bytes = state_manager.object_store.get_bytes(uri) # type: ignore[attr-defined]
|
|
245
|
+
else:
|
|
246
|
+
path = Path(uri)
|
|
247
|
+
if not path.is_absolute() and getattr(state_manager, "cache_dir", None):
|
|
248
|
+
path = Path(state_manager.cache_dir) / path
|
|
249
|
+
with open(path, "rb") as fin:
|
|
250
|
+
payload_bytes = fin.read()
|
|
251
|
+
if payload_bytes is None:
|
|
252
|
+
raise RuntimeError("No payload bytes resolved")
|
|
253
|
+
with np.load(io.BytesIO(payload_bytes), allow_pickle=False) as npz:
|
|
254
|
+
arr = npz["data"]
|
|
255
|
+
meta = ref.get("meta") if isinstance(ref.get("meta"), dict) else {}
|
|
256
|
+
origin_type = meta.get("origin_type") if isinstance(meta, dict) else None
|
|
257
|
+
return _wrap_spilled_array(arr, origin_type)
|
|
258
|
+
except Exception as e:
|
|
259
|
+
_get_logger().warning(f"[PayloadSpill] Failed to hydrate payload {uri}: {e}")
|
|
260
|
+
return ref
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
if __name__ == "__main__":  # pragma: no cover - developer self-test hook
    import tempfile

    logging.basicConfig(level=logging.INFO)
    if np is None:
        print("NumPy is not available; skipping payload spill self-test.")
    else:
        class _TempStateManager:
            def __init__(self) -> None:
                # Portable default (system temp dir) instead of a hard-coded /tmp path.
                default_dir = Path(tempfile.gettempdir()) / "payload_spill_test"
                self.cache_dir = Path(os.environ.get("PAYLOAD_SPILL_TMP", str(default_dir)))
                self.object_store = None

        sm = _TempStateManager()
        sample = {"X_test": np.random.rand(2000, 200)}  # 2000*200 float64 = 3.2 MB -> forces spill
        spilled = spill_large_payloads(sample, sm, run_id="selftest", process_name="demo", threshold_bytes=1_000_000)
        print(spilled)

        # Round-trip check: hydrating the reference must reproduce the original array.
        restored = hydrate_payload_refs(spilled, sm)
        if not np.array_equal(np.asarray(restored["X_test"]), sample["X_test"]):
            raise SystemExit("Payload spill self-test FAILED: hydrated data does not match the original.")
        print("Payload spill self-test passed: round-trip OK.")
|
|
277
|
+
|
|
278
|
+
|