wafer-core 0.1.27__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/aligner.py +13 -6
- wafer_core/lib/trace_compare/analyzer.py +12 -3
- wafer_core/lib/trace_compare/classifier.py +18 -9
- wafer_core/lib/trace_compare/fusion_analyzer.py +424 -275
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.29.dist-info}/METADATA +1 -1
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.29.dist-info}/RECORD +19 -9
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.29.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
"""RunPod provider — adapts existing RunPod GraphQL API to TargetProvider protocol."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import time
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
|
|
9
|
+
from wafer_core.targets.runpod import (
|
|
10
|
+
RunPodError,
|
|
11
|
+
_graphql_request_async,
|
|
12
|
+
_wait_for_ssh,
|
|
13
|
+
)
|
|
14
|
+
from wafer_core.targets.runpod import (
|
|
15
|
+
terminate_pod as _terminate_pod,
|
|
16
|
+
)
|
|
17
|
+
from wafer_core.targets.types import Target, TargetSpec
|
|
18
|
+
from wafer_core.utils.kernel_utils.targets.config import RunPodTarget
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _parse_pod_to_target(pod: dict) -> Target | None:
    """Convert a RunPod GraphQL pod object into a Target.

    SSH details are taken from the first runtime port entry that maps
    private port 22 to a public IP other than the ``ssh.runpod.io`` proxy.
    When no such entry exists the Target is still returned, with
    ``public_ip`` and ``ssh_port`` left as None.

    Args:
        pod: Pod dict from the ``pods`` / ``pod`` queries. Fields read:
            ``id``, ``name``, ``desiredStatus``, ``costPerHr``,
            ``machine.gpuType.displayName`` and ``runtime.ports``.

    Returns:
        A Target with ``provider="runpod"``. This implementation always
        returns a Target; the ``| None`` in the signature is kept so
        callers that already check for None keep working.
    """
    pod_id = pod.get("id", "")
    # The API can send explicit nulls; a .get() default only covers missing
    # keys, so normalise with ``or`` before calling string methods.
    pod_name = pod.get("name") or ""
    status = (pod.get("desiredStatus") or "").lower() or "unknown"

    # Extract direct SSH info from the runtime port mappings.
    public_ip = None
    ssh_port = None
    runtime = pod.get("runtime")
    if runtime:
        for port in runtime.get("ports") or []:
            if port.get("privatePort") != 22 or not port.get("isIpPublic"):
                continue
            ip = port.get("ip")
            # Skip proxy SSH (ssh.runpod.io); we want the direct IP.
            if ip and ip != "ssh.runpod.io":
                public_ip = ip
                ssh_port = port.get("publicPort")
                break

    # Infer spec_name from the naming convention wafer-{spec_name}-{timestamp}.
    # The spec name itself may contain dashes, so only the first and last
    # dash-separated pieces are stripped.
    spec_name = None
    if pod_name.startswith("wafer-"):
        parts = pod_name.split("-")
        if len(parts) >= 3:
            spec_name = "-".join(parts[1:-1])

    # GPU display name, when the machine info is present.
    gpu_type = ""
    machine = pod.get("machine")
    if machine:
        gpu_type_info = machine.get("gpuType")
        if gpu_type_info:
            gpu_type = gpu_type_info.get("displayName", "")

    cost = pod.get("costPerHr")

    return Target(
        resource_id=pod_id,
        provider="runpod",
        status=status,
        public_ip=public_ip,
        ssh_port=ssh_port,
        ssh_username="root",
        gpu_type=gpu_type,
        name=pod_name or None,
        created_at=None,  # RunPod API doesn't expose creation time in list query
        spec_name=spec_name,
        # A missing or zero cost is reported as None (unchanged behaviour).
        price_per_hour=float(cost) if cost else None,
    )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class RunPodProvider:
    """RunPod implementation of TargetProvider.

    Wraps existing GraphQL API calls:
    - list_targets: myself { pods { ... } }
    - get_target: pod(input: { podId }) { ... }
    - provision: podFindAndDeployOnDemand
    - terminate: podTerminate
    """

    async def list_targets(self) -> list[Target]:
        """List the pods on the RunPod account.

        The query applies no status filter; each Target carries the pod's
        ``desiredStatus`` so callers can filter themselves.

        Returns:
            One Target per pod, or an empty list when the request fails
            with anything other than RunPodError.

        Raises:
            RunPodError: propagated unchanged from the GraphQL helper.
        """
        query = """
        query {
            myself {
                pods {
                    id
                    name
                    desiredStatus
                    costPerHr
                    machine {
                        podHostId
                        gpuType {
                            displayName
                        }
                    }
                    runtime {
                        ports {
                            ip
                            isIpPublic
                            privatePort
                            publicPort
                            type
                        }
                    }
                }
            }
        }
        """

        try:
            data = await _graphql_request_async(query)
        except RunPodError:
            raise
        except Exception as e:
            logger.warning(f"Failed to list RunPod pods: {e}")
            return []

        # ``myself`` / ``pods`` may come back as explicit nulls; a .get()
        # default would not cover that, so fall back with ``or``.
        account = data.get("myself") or {}
        pods = account.get("pods") or []

        targets = []
        for pod in pods:
            target = _parse_pod_to_target(pod)
            if target is not None:
                targets.append(target)

        return targets

    async def get_target(self, resource_id: str) -> Target | None:
        """Get a specific pod by ID.

        Args:
            resource_id: RunPod pod id.

        Returns:
            The parsed Target, or None when the request fails or the API
            returns no pod for that id.
        """
        query = """
        query pod($input: PodFilter!) {
            pod(input: $input) {
                id
                name
                desiredStatus
                costPerHr
                machine {
                    podHostId
                    gpuType {
                        displayName
                    }
                }
                runtime {
                    ports {
                        ip
                        isIpPublic
                        privatePort
                        publicPort
                        type
                    }
                }
            }
        }
        """
        variables = {"input": {"podId": resource_id}}

        # NOTE(review): unlike list_targets, RunPodError is swallowed here as
        # well, so API errors are indistinguishable from "pod not found" --
        # confirm this is intended.
        try:
            data = await _graphql_request_async(query, variables)
        except Exception as e:
            logger.warning(f"Failed to get RunPod pod {resource_id}: {e}")
            return None

        pod = data.get("pod")
        if not pod:
            return None

        return _parse_pod_to_target(pod)

    async def provision(self, spec: TargetSpec) -> Target:
        """Provision a new RunPod pod from a spec.

        Blocks until SSH is ready.

        Args:
            spec: Must be a RunPodTarget. Fields read: ``name``,
                ``gpu_type_id``, ``gpu_count``, ``container_disk_gb``,
                ``template_id``, ``image``, ``provision_timeout`` and
                ``gpu_type``.

        Returns:
            A Target describing the new pod, with status "running".

        Raises:
            AssertionError: if ``spec`` is not a RunPodTarget.
            RunPodError: if the deployment mutation returns no pod.
        """
        assert isinstance(spec, RunPodTarget), (
            f"RunPodProvider.provision requires RunPodTarget, got {type(spec).__name__}"
        )

        # The name follows wafer-{spec_name}-{timestamp}; _parse_pod_to_target
        # relies on this convention to recover spec_name from live pods.
        pod_name = f"wafer-{spec.name}-{int(time.time())}"

        mutation = """
        mutation podFindAndDeployOnDemand($input: PodFindAndDeployOnDemandInput!) {
            podFindAndDeployOnDemand(input: $input) {
                id
                machineId
                machine {
                    podHostId
                }
            }
        }
        """

        pod_input: dict = {
            "gpuTypeId": spec.gpu_type_id,
            "gpuCount": spec.gpu_count,
            "cloudType": "SECURE",
            "name": pod_name,
            "supportPublicIp": True,
            "containerDiskInGb": spec.container_disk_gb,
            "minVcpuCount": 1,
            "minMemoryInGb": 4,
            "ports": "22/tcp",
            "startSsh": True,
            "startJupyter": False,
            "env": [],
        }

        # A template takes precedence over an explicit image name.
        if spec.template_id:
            pod_input["templateId"] = spec.template_id
        else:
            pod_input["imageName"] = spec.image

        logger.info(f"Provisioning RunPod pod: {pod_name}")
        data = await _graphql_request_async(mutation, {"input": pod_input})

        pod_data = data.get("podFindAndDeployOnDemand")
        if not pod_data:
            raise RunPodError("No pod returned from deployment")

        pod_id = pod_data["id"]
        logger.info(f"Pod created: {pod_id}")

        # NOTE(review): if this wait raises, the pod created above is not
        # terminated here -- confirm whether _wait_for_ssh or the caller is
        # responsible for cleaning it up.
        public_ip, ssh_port, ssh_username = await _wait_for_ssh(pod_id, spec.provision_timeout)

        return Target(
            resource_id=pod_id,
            provider="runpod",
            status="running",
            public_ip=public_ip,
            ssh_port=ssh_port,
            ssh_username=ssh_username,
            gpu_type=spec.gpu_type,
            name=pod_name,
            created_at=datetime.now(timezone.utc).isoformat(),
            spec_name=spec.name,
        )

    async def terminate(self, resource_id: str) -> bool:
        """Terminate a RunPod pod; returns the result of ``terminate_pod``."""
        return await _terminate_pod(resource_id)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Reconciliation: compare TargetSpecs to live Targets.
|
|
2
|
+
|
|
3
|
+
Pure function — no API calls, no side effects. Takes specs and targets as
|
|
4
|
+
inputs, returns a ReconcileResult describing what's bound, what's orphaned,
|
|
5
|
+
and what's unprovisioned.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from wafer_core.targets.types import ReconcileResult, Target, TargetSpec
|
|
11
|
+
from wafer_core.utils.kernel_utils.targets.config import (
|
|
12
|
+
BaremetalTarget,
|
|
13
|
+
DigitalOceanTarget,
|
|
14
|
+
RunPodTarget,
|
|
15
|
+
VMTarget,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _is_cloud_spec(spec: TargetSpec) -> bool:
    """Return True when *spec* is a RunPod or DigitalOcean spec.

    Only these spec types are treated as cloud-provisioned; every other
    spec type (baremetal, VM, ...) yields False and is therefore never
    reported as "unprovisioned".
    """
    cloud_spec_types = (RunPodTarget, DigitalOceanTarget)
    return isinstance(spec, cloud_spec_types)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _spec_provider(spec: TargetSpec) -> str | None:
    """Map a spec to its provider name.

    Returns "runpod", "digitalocean" or "baremetal" (the latter for both
    BaremetalTarget and VMTarget), and None for any other spec type.
    """
    # Checked in order; the first matching entry wins.
    provider_table = (
        (RunPodTarget, "runpod"),
        (DigitalOceanTarget, "digitalocean"),
        ((BaremetalTarget, VMTarget), "baremetal"),
    )
    for spec_types, provider_name in provider_table:
        if isinstance(spec, spec_types):
            return provider_name
    return None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def reconcile(
    specs: list[TargetSpec],
    targets: list[Target],
    binding_hints: dict[str, str] | None = None,
) -> ReconcileResult:
    """Compare specs to live targets and classify each.

    Matching rules (in priority order):
    1. Target.spec_name matches Spec.name exactly (set by naming convention
       or explicit binding).
    2. binding_hints maps resource_id -> spec_name (from local cache).
    3. No match -> target is unbound (orphan).

    Rule 2 is consulted whenever rule 1 fails to name a known spec,
    including the case where ``Target.spec_name`` is set but refers to a
    spec that does not exist.

    A cloud spec with no matching target is "unprovisioned".
    Baremetal/VM specs are never "unprovisioned" (they don't have a cloud
    lifecycle -- the machine is always there or it isn't).

    Args:
        specs: All known TargetSpecs (loaded from TOML files).
        targets: All live Targets (fetched from provider APIs).
        binding_hints: Optional resource_id -> spec_name cache for targets
            whose spec_name can't be inferred from naming convention.

    Returns:
        ReconcileResult with bound, unbound, and unprovisioned lists.
    """
    hints = binding_hints or {}
    spec_by_name = {s.name: s for s in specs}
    claimed_spec_names: set[str] = set()

    bound: list[tuple[TargetSpec, Target]] = []
    unbound: list[Target] = []

    for target in targets:
        # Candidates in priority order: the target's own spec_name first,
        # then the cached hint. Take the first one naming a known spec.
        candidates = (target.spec_name, hints.get(target.resource_id))
        resolved_spec_name = next(
            (name for name in candidates if name and name in spec_by_name),
            None,
        )

        if resolved_spec_name is None:
            unbound.append(target)
            continue

        bound.append((spec_by_name[resolved_spec_name], target))
        claimed_spec_names.add(resolved_spec_name)

    # Cloud specs with no bound target are unprovisioned.
    unprovisioned = [s for s in specs if s.name not in claimed_spec_names and _is_cloud_spec(s)]

    return ReconcileResult(
        bound=bound,
        unbound=unbound,
        unprovisioned=unprovisioned,
    )
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
"""Spec store: CRUD for TargetSpec TOML files.
|
|
2
|
+
|
|
3
|
+
Specs live in ~/.wafer/specs/{name}.toml. On first access, auto-migrates
|
|
4
|
+
from the old ~/.wafer/targets/ directory if specs/ doesn't exist yet.
|
|
5
|
+
|
|
6
|
+
This module provides the same operations as the old targets.py but under
|
|
7
|
+
the "spec" vocabulary. The CLI-layer targets.py still works and delegates
|
|
8
|
+
here where needed.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
import shutil
|
|
15
|
+
from dataclasses import asdict
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import tomllib
|
|
20
|
+
|
|
21
|
+
from wafer_core.utils.kernel_utils.targets.config import (
|
|
22
|
+
BaremetalTarget,
|
|
23
|
+
DigitalOceanTarget,
|
|
24
|
+
LocalTarget,
|
|
25
|
+
ModalTarget,
|
|
26
|
+
RunPodTarget,
|
|
27
|
+
TargetConfig,
|
|
28
|
+
VMTarget,
|
|
29
|
+
WorkspaceTarget,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
WAFER_DIR = Path.home() / ".wafer"
|
|
35
|
+
SPECS_DIR = WAFER_DIR / "specs"
|
|
36
|
+
OLD_TARGETS_DIR = WAFER_DIR / "targets"
|
|
37
|
+
CONFIG_FILE = WAFER_DIR / "config.toml"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _ensure_specs_dir() -> None:
    """Make sure SPECS_DIR exists, seeding it from OLD_TARGETS_DIR if possible.

    Does nothing when SPECS_DIR is already present. Otherwise, if the old
    targets directory holds at least one ``*.toml`` file, the whole
    directory is copied to SPECS_DIR; if not, an empty SPECS_DIR is created.
    """
    if SPECS_DIR.exists():
        return

    has_legacy_specs = OLD_TARGETS_DIR.exists() and any(OLD_TARGETS_DIR.glob("*.toml"))
    if not has_legacy_specs:
        SPECS_DIR.mkdir(parents=True, exist_ok=True)
        return

    logger.info(
        f"Migrating {OLD_TARGETS_DIR} -> {SPECS_DIR} (target configs are now called 'specs')"
    )
    # Copy rather than move: the old directory is left in place because
    # other code may still read from it.
    shutil.copytree(OLD_TARGETS_DIR, SPECS_DIR)
    logger.info(
        f"Migration complete. Old directory preserved at {OLD_TARGETS_DIR}. "
        "You can safely delete it once 'wafer specs list' works."
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _spec_path(name: str) -> Path:
    """Return the path of the TOML file for spec *name* inside SPECS_DIR."""
    filename = f"{name}.toml"
    return SPECS_DIR / filename
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# ── Parsing ──────────────────────────────────────────────────────────────────
|
|
65
|
+
|
|
66
|
+
# Maps the ``type`` value found in a spec file to the config class built from it.
_TYPE_MAP: dict[str, type] = dict(
    baremetal=BaremetalTarget,
    vm=VMTarget,
    modal=ModalTarget,
    workspace=WorkspaceTarget,
    runpod=RunPodTarget,
    digitalocean=DigitalOceanTarget,
    local=LocalTarget,
)

# Inverse lookup used when serializing: config class -> ``type`` string.
_TYPE_REVERSE: dict[type, str] = {
    spec_cls: type_name for type_name, spec_cls in _TYPE_MAP.items()
}
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def parse_spec(data: dict[str, Any]) -> TargetConfig:
    """Build a spec object from a parsed TOML mapping.

    The ``type`` entry selects the config class via _TYPE_MAP; all other
    entries are passed to that class as keyword arguments.

    Raises:
        ValueError: if ``type`` is missing/empty or not a known spec type.
    """
    target_type = data.get("type")
    if not target_type:
        raise ValueError("Spec must have 'type' field")

    cls = _TYPE_MAP.get(target_type)
    if cls is None:
        known_types = ", ".join(sorted(_TYPE_MAP))
        raise ValueError(f"Unknown spec type: {target_type}. Must be one of: {known_types}")

    # Everything except the discriminator becomes a constructor argument.
    fields = dict(data)
    fields.pop("type")

    # TOML arrays arrive as lists; these two fields are converted to tuples.
    for tuple_field in ("pip_packages", "gpu_ids"):
        if isinstance(fields.get(tuple_field), list):
            fields[tuple_field] = tuple(fields[tuple_field])

    return cls(**fields)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def serialize_spec(spec: TargetConfig) -> dict[str, Any]:
    """Turn a spec dataclass into a dict suitable for the TOML writer.

    Adds a ``type`` entry looked up from _TYPE_REVERSE ("unknown" when the
    class is not registered), converts the ``pip_packages`` / ``gpu_ids``
    tuples to lists, and omits ``pip_packages`` when it is empty.
    """
    payload = asdict(spec)
    payload["type"] = _TYPE_REVERSE.get(type(spec), "unknown")

    # Tuples -> lists so they are written as TOML arrays.
    for list_field in ("pip_packages", "gpu_ids"):
        current = payload.get(list_field)
        if isinstance(current, tuple):
            payload[list_field] = list(current)

    # An empty pip_packages entry is left out of the file entirely.
    if "pip_packages" in payload and not payload["pip_packages"]:
        payload.pop("pip_packages")

    return payload
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# ── CRUD ─────────────────────────────────────────────────────────────────────
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def load_spec(name: str) -> TargetConfig:
    """Load the spec called *name*.

    Looks for ``{name}.toml`` in SPECS_DIR first and, if absent, in
    OLD_TARGETS_DIR for backwards compatibility.

    Raises:
        FileNotFoundError: if neither location has the file.
    """
    _ensure_specs_dir()

    primary = _spec_path(name)
    if primary.exists():
        source = primary
    else:
        # Fall back to the pre-migration location.
        legacy = OLD_TARGETS_DIR / f"{name}.toml"
        if not legacy.exists():
            raise FileNotFoundError(f"Spec not found: {name} (looked in {SPECS_DIR})")
        source = legacy

    with open(source, "rb") as handle:
        raw = tomllib.load(handle)

    return parse_spec(raw)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def save_spec(spec: TargetConfig) -> None:
    """Serialize *spec* and write it to ``{spec.name}.toml`` in SPECS_DIR."""
    _ensure_specs_dir()

    payload = serialize_spec(spec)
    _write_toml(_spec_path(spec.name), payload)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def list_spec_names() -> list[str]:
    """Return the sorted stems of every ``*.toml`` file in SPECS_DIR."""
    _ensure_specs_dir()
    names = [toml_file.stem for toml_file in SPECS_DIR.glob("*.toml")]
    names.sort()
    return names
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def remove_spec(name: str) -> None:
    """Delete the spec file for *name*.

    Raises:
        FileNotFoundError: if no such spec file exists in SPECS_DIR.
    """
    spec_file = _spec_path(name)
    if spec_file.exists():
        spec_file.unlink()
        return
    raise FileNotFoundError(f"Spec not found: {name}")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def load_all_specs() -> list[TargetConfig]:
    """Load every spec listed by list_spec_names().

    A spec that raises while loading is skipped with a logged warning, so
    one bad file does not prevent the rest from loading.
    """
    loaded: list[TargetConfig] = []
    for spec_name in list_spec_names():
        try:
            spec = load_spec(spec_name)
        except Exception as e:
            logger.warning(f"Failed to load spec {spec_name}: {e}")
        else:
            loaded.append(spec)
    return loaded
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# ── TOML writer ──────────────────────────────────────────────────────────────
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _write_toml(path: Path, data: dict[str, Any]) -> None:
|
|
182
|
+
"""Write dict as flat TOML file."""
|
|
183
|
+
lines = []
|
|
184
|
+
for key, value in data.items():
|
|
185
|
+
if value is None:
|
|
186
|
+
continue
|
|
187
|
+
if isinstance(value, bool):
|
|
188
|
+
lines.append(f"{key} = {str(value).lower()}")
|
|
189
|
+
elif isinstance(value, int | float):
|
|
190
|
+
lines.append(f"{key} = {value}")
|
|
191
|
+
elif isinstance(value, str):
|
|
192
|
+
lines.append(f'{key} = "{value}"')
|
|
193
|
+
elif isinstance(value, list):
|
|
194
|
+
if all(isinstance(v, int) for v in value):
|
|
195
|
+
lines.append(f"{key} = {value}")
|
|
196
|
+
else:
|
|
197
|
+
formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
|
|
198
|
+
lines.append(f"{key} = [{formatted}]")
|
|
199
|
+
|
|
200
|
+
path.write_text("\n".join(lines) + "\n")
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Target state cache: bindings and labels for live resources.
|
|
2
|
+
|
|
3
|
+
Cache file: ~/.wafer/target_state.json
|
|
4
|
+
|
|
5
|
+
Bindings map resource_id -> spec_name (performance hint for reconciliation).
|
|
6
|
+
Labels map resource_id -> {key: value} (probed software versions).
|
|
7
|
+
|
|
8
|
+
The provider API is always the source of truth for whether a resource exists.
|
|
9
|
+
This cache stores metadata that's expensive to recompute (SSH probes, name inference).
|
|
10
|
+
|
|
11
|
+
Format:
|
|
12
|
+
{
|
|
13
|
+
"bindings": {
|
|
14
|
+
"<resource_id>": {
|
|
15
|
+
"spec_name": "<spec_name>",
|
|
16
|
+
"provider": "<provider>",
|
|
17
|
+
"bound_at": "<ISO timestamp>"
|
|
18
|
+
}
|
|
19
|
+
},
|
|
20
|
+
"labels": {
|
|
21
|
+
"<resource_id>": {
|
|
22
|
+
"rocm_version": "7.0.2",
|
|
23
|
+
"python_version": "3.12",
|
|
24
|
+
...
|
|
25
|
+
}
|
|
26
|
+
}
|
|
27
|
+
}
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
import json
|
|
33
|
+
import logging
|
|
34
|
+
from dataclasses import asdict, dataclass
|
|
35
|
+
from pathlib import Path
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
WAFER_DIR = Path.home() / ".wafer"
|
|
40
|
+
STATE_FILE = WAFER_DIR / "target_state.json"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(frozen=True)
class BindingEntry:
    """Immutable cache record linking a resource to a spec.

    Instances are stored in the state file keyed by resource_id; the
    resource_id itself is not part of the record.
    """

    # Name of the spec the resource is bound to.
    spec_name: str
    # Provider identifier stored alongside the binding.
    provider: str
    # Time the binding was recorded, as an ISO timestamp string.
    bound_at: str
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# ---------------------------------------------------------------------------
|
|
53
|
+
# Raw file I/O
|
|
54
|
+
# ---------------------------------------------------------------------------
|
|
55
|
+
|
|
56
|
+
def _load_state() -> dict:
    """Load the full state file.

    Returns:
        The parsed JSON object, or an empty dict when the file is missing,
        is not valid UTF-8/JSON, or its top level is not a JSON object.
        Other I/O errors (e.g. permission denied) are not swallowed, so a
        transient read failure cannot make a later save discard the cache.
    """
    # Read directly instead of checking exists() first, so a file removed
    # between the check and the read is still treated as "missing".
    try:
        raw = STATE_FILE.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {}
    except UnicodeDecodeError as e:
        logger.warning(f"Corrupted state cache, ignoring: {e}")
        return {}

    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError) as e:
        logger.warning(f"Corrupted state cache, ignoring: {e}")
        return {}

    # Callers use .get() and item assignment on the result, so anything
    # other than a JSON object (list, null, number, ...) counts as corrupted.
    if not isinstance(data, dict):
        logger.warning(
            f"Corrupted state cache, ignoring: expected object, got {type(data).__name__}"
        )
        return {}

    return data
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _save_state(data: dict) -> None:
    """Write the full state file, replacing any previous contents.

    The JSON is written to a sibling ``.tmp`` file and then renamed over
    STATE_FILE, so an interrupted write cannot leave a truncated state
    file behind.

    Args:
        data: JSON-serializable state (the ``bindings`` / ``labels`` dict).
    """
    STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(data, indent=2) + "\n"

    tmp_path = STATE_FILE.with_name(STATE_FILE.name + ".tmp")
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(STATE_FILE)
    except OSError:
        # Don't leave a stray temp file next to the cache on failure.
        tmp_path.unlink(missing_ok=True)
        raise
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# ---------------------------------------------------------------------------
|
|
75
|
+
# Bindings
|
|
76
|
+
# ---------------------------------------------------------------------------
|
|
77
|
+
|
|
78
|
+
def load_bindings() -> dict[str, BindingEntry]:
    """Load the binding cache from disk.

    Returns:
        resource_id -> BindingEntry. Entries that cannot be turned into a
        BindingEntry are skipped with a warning; a ``bindings`` section
        that is missing, null or not an object yields an empty dict.
    """
    bindings_raw = _load_state().get("bindings") or {}
    if not isinstance(bindings_raw, dict):
        logger.warning("Ignoring malformed 'bindings' section in state cache")
        return {}

    result: dict[str, BindingEntry] = {}
    for rid, entry in bindings_raw.items():
        try:
            # TypeError covers missing/extra fields and non-mapping entries.
            result[rid] = BindingEntry(**entry)
        except TypeError:
            logger.warning(f"Skipping malformed binding for {rid}")
    return result
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def save_bindings(bindings: dict[str, BindingEntry]) -> None:
    """Replace the ``bindings`` section of the state file with *bindings*.

    The rest of the state (e.g. labels) is re-read from disk and written
    back unchanged.
    """
    state = _load_state()
    serialized = {}
    for rid, entry in bindings.items():
        serialized[rid] = asdict(entry)
    state["bindings"] = serialized
    _save_state(state)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def add_binding(resource_id: str, entry: BindingEntry) -> None:
    """Store *entry* under *resource_id*, overwriting any existing binding."""
    current = load_bindings()
    current.update({resource_id: entry})
    save_bindings(current)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def remove_binding(resource_id: str) -> None:
    """Drop the binding for *resource_id*; nothing is written if absent."""
    current = load_bindings()
    if resource_id not in current:
        return
    current.pop(resource_id)
    save_bindings(current)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_binding_hints() -> dict[str, str]:
    """Return the cached bindings reduced to resource_id -> spec_name."""
    hints: dict[str, str] = {}
    for rid, entry in load_bindings().items():
        hints[rid] = entry.spec_name
    return hints
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# ---------------------------------------------------------------------------
|
|
120
|
+
# Labels
|
|
121
|
+
# ---------------------------------------------------------------------------
|
|
122
|
+
|
|
123
|
+
def load_all_labels() -> dict[str, dict[str, str]]:
    """Load all cached labels.

    Returns:
        resource_id -> labels dict. A ``labels`` section that is missing,
        null or not an object yields an empty dict, so callers can always
        call ``.get()`` on the result.
    """
    labels = _load_state().get("labels")
    return labels if isinstance(labels, dict) else {}
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def load_labels(resource_id: str) -> dict[str, str]:
    """Return the cached labels for *resource_id*, or an empty dict if none."""
    all_labels = load_all_labels()
    return all_labels.get(resource_id, {})
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def save_labels(resource_id: str, labels: dict[str, str]) -> None:
    """Save labels for a resource, preserving bindings and other labels.

    Args:
        resource_id: Resource whose labels are replaced.
        labels: New label mapping stored under that resource.
    """
    data = _load_state()
    # Recreate the section when it is missing or malformed (e.g. null),
    # instead of failing on the item assignment below.
    if not isinstance(data.get("labels"), dict):
        data["labels"] = {}
    data["labels"][resource_id] = labels
    _save_state(data)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def remove_labels(resource_id: str) -> None:
    """Remove cached labels for a resource.

    No-op (nothing is written) when the resource has no labels or the
    ``labels`` section is missing or malformed.
    """
    data = _load_state()
    labels = data.get("labels")
    if not isinstance(labels, dict) or resource_id not in labels:
        return
    # ``labels`` is the dict stored in ``data``, so deleting here updates
    # the state that gets written back.
    del labels[resource_id]
    _save_state(data)
|