wafer-core 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +22 -9
- wafer_core/lib/trace_compare/aligner.py +376 -0
- wafer_core/lib/trace_compare/analyzer.py +558 -159
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +307 -13
- wafer_core/lib/trace_compare/fusion_analyzer.py +280 -706
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +526 -227
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/METADATA +3 -1
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/RECORD +28 -10
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/WHEEL +0 -0
--- /dev/null
+++ wafer_core/targets/providers/baremetal.py
@@ -0,0 +1,72 @@
+"""Baremetal provider — degenerate case with no cloud API.
+
+Baremetal targets have no provisioning lifecycle. The "resource" is just the
+SSH endpoint from the spec. list_targets returns nothing (no API to query),
+and provision/terminate are errors.
+"""
+
+from __future__ import annotations
+
+from wafer_core.targets.types import Target, TargetSpec
+from wafer_core.utils.kernel_utils.targets.config import BaremetalTarget, VMTarget
+
+
+def target_from_ssh_spec(spec: BaremetalTarget | VMTarget) -> Target:
+    """Build a Target from a baremetal/VM spec's SSH info.
+
+    Since there's no cloud API, the resource_id is synthetic:
+    "baremetal:{host}:{port}" to make it unique and stable.
+    """
+    # Parse user@host:port
+    ssh_target = spec.ssh_target
+    assert ":" in ssh_target, f"ssh_target must include port, got: {ssh_target}"
+
+    user_host, port_str = ssh_target.rsplit(":", 1)
+    if "@" in user_host:
+        user, host = user_host.split("@", 1)
+    else:
+        user = "root"
+        host = user_host
+
+    port = int(port_str)
+
+    return Target(
+        resource_id=f"baremetal:{host}:{port}",
+        provider="baremetal",
+        status="running",  # Assumed running; TCP check happens elsewhere
+        public_ip=host,
+        ssh_port=port,
+        ssh_username=user,
+        gpu_type=spec.gpu_type,
+        name=spec.name,
+        spec_name=spec.name,
+    )
+
+
+class BaremetalProvider:
+    """Baremetal implementation of TargetProvider.
+
+    Baremetal has no cloud API. list_targets returns empty (no remote state
+    to query). Use target_from_ssh_spec() to build a Target from a spec
+    when you already know which spec you want.
+    """
+
+    async def list_targets(self) -> list[Target]:
+        """Baremetal has no API to list. Returns empty."""
+        return []
+
+    async def get_target(self, resource_id: str) -> Target | None:
+        """Baremetal has no API to query. Returns None."""
+        return None
+
+    async def provision(self, spec: TargetSpec) -> Target:
+        """Baremetal targets cannot be provisioned — they already exist."""
+        assert isinstance(spec, (BaremetalTarget, VMTarget)), (
+            f"BaremetalProvider.provision requires BaremetalTarget or VMTarget, "
+            f"got {type(spec).__name__}"
+        )
+        return target_from_ssh_spec(spec)
+
+    async def terminate(self, resource_id: str) -> bool:
+        """Baremetal targets cannot be terminated via API."""
+        return False
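A minimal usage sketch of the baremetal path (not part of the diff): it assumes BaremetalTarget can be constructed with the name/ssh_target/gpu_type fields the code reads; the values shown are hypothetical.

    # Hypothetical example — constructor kwargs and values are assumptions;
    # only the attributes read by target_from_ssh_spec are known from the diff.
    spec = BaremetalTarget(name="lab-mi300x", ssh_target="admin@10.0.0.5:2222", gpu_type="MI300X")
    target = target_from_ssh_spec(spec)
    # "admin@10.0.0.5:2222" parses to user="admin", host="10.0.0.5", port=2222,
    # so the synthetic resource_id is "baremetal:10.0.0.5:2222".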
--- /dev/null
+++ wafer_core/targets/providers/digitalocean.py
@@ -0,0 +1,164 @@
+"""DigitalOcean provider — adapts existing DO REST API to TargetProvider protocol."""
+
+from __future__ import annotations
+
+import logging
+import time
+from datetime import datetime, timezone
+
+from wafer_core.targets.digitalocean import (
+    DigitalOceanError,
+    _api_request_async,
+    _get_ssh_key_ids,
+    _wait_for_ssh,
+)
+from wafer_core.targets.digitalocean import (
+    terminate_droplet as _terminate_droplet,
+)
+from wafer_core.targets.types import Target, TargetSpec
+from wafer_core.utils.kernel_utils.targets.config import DigitalOceanTarget
+
+logger = logging.getLogger(__name__)
+
+
+def _parse_droplet_to_target(droplet: dict) -> Target:
+    """Parse a DigitalOcean API droplet response into a Target."""
+    droplet_id = str(droplet.get("id", ""))
+    droplet_name = droplet.get("name", "")
+    status_raw = droplet.get("status", "").lower()
+
+    # Map DO statuses to our values
+    # DO: new, active, off, archive
+    status_map = {
+        "new": "pending",
+        "active": "running",
+        "off": "stopped",
+        "archive": "terminated",
+    }
+    status = status_map.get(status_raw, status_raw)
+
+    # Extract public IP
+    public_ip = None
+    networks = droplet.get("networks", {})
+    for net in networks.get("v4", []):
+        if net.get("type") == "public":
+            public_ip = net.get("ip_address")
+            break
+
+    # Infer spec_name from naming convention: wafer-{spec_name}-{timestamp}
+    spec_name = None
+    if droplet_name.startswith("wafer-"):
+        parts = droplet_name.split("-")
+        if len(parts) >= 3:
+            spec_name = "-".join(parts[1:-1])
+
+    created_at = droplet.get("created_at")
+
+    # Extract GPU type from size slug
+    size = droplet.get("size", {})
+    size_slug = (
+        size.get("slug", "") if isinstance(size, dict) else str(droplet.get("size_slug", ""))
+    )
+    gpu_type = "MI300X" if "mi300x" in size_slug.lower() else "unknown"
+
+    return Target(
+        resource_id=droplet_id,
+        provider="digitalocean",
+        status=status,
+        public_ip=public_ip,
+        ssh_port=22,
+        ssh_username="root",
+        gpu_type=gpu_type,
+        name=droplet_name or None,
+        created_at=created_at,
+        spec_name=spec_name,
+    )
+
+
+class DigitalOceanProvider:
+    """DigitalOcean implementation of TargetProvider.
+
+    Wraps existing REST API calls for droplet management.
+    """
+
+    async def list_targets(self) -> list[Target]:
+        """List all droplets on the DigitalOcean account."""
+        try:
+            response = await _api_request_async("GET", "/droplets", params={"per_page": "200"})
+        except DigitalOceanError:
+            raise
+        except Exception as e:
+            logger.warning(f"Failed to list DigitalOcean droplets: {e}")
+            return []
+
+        droplets = response.get("droplets", [])
+        return [_parse_droplet_to_target(d) for d in droplets]
+
+    async def get_target(self, resource_id: str) -> Target | None:
+        """Get a specific droplet by ID."""
+        try:
+            response = await _api_request_async("GET", f"/droplets/{resource_id}")
+        except Exception as e:
+            logger.warning(f"Failed to get DigitalOcean droplet {resource_id}: {e}")
+            return None
+
+        droplet = response.get("droplet")
+        if not droplet:
+            return None
+
+        return _parse_droplet_to_target(droplet)
+
+    async def provision(self, spec: TargetSpec) -> Target:
+        """Provision a new DigitalOcean droplet from a spec.
+
+        Blocks until SSH is ready.
+        """
+        assert isinstance(spec, DigitalOceanTarget), (
+            f"DigitalOceanProvider.provision requires DigitalOceanTarget, got {type(spec).__name__}"
+        )
+
+        droplet_name = f"wafer-{spec.name}-{int(time.time())}"
+
+        ssh_key_ids = await _get_ssh_key_ids()
+        if not ssh_key_ids:
+            logger.warning("No SSH keys found - droplet may not be accessible")
+
+        create_data = {
+            "name": droplet_name,
+            "region": spec.region,
+            "size": spec.size_slug,
+            "image": spec.image,
+            "ssh_keys": ssh_key_ids,
+            "backups": False,
+            "ipv6": True,
+            "monitoring": True,
+        }
+
+        logger.info(f"Provisioning DigitalOcean droplet: {droplet_name}")
+        response = await _api_request_async("POST", "/droplets", data=create_data)
+
+        if not response or "droplet" not in response:
+            raise DigitalOceanError(f"Failed to create droplet: {response}")
+
+        droplet = response["droplet"]
+        droplet_id = str(droplet["id"])
+        logger.info(f"Droplet created: {droplet_id}")
+
+        public_ip = await _wait_for_ssh(droplet_id, spec.provision_timeout)
+
+        return Target(
+            resource_id=droplet_id,
+            provider="digitalocean",
+            status="running",
+            public_ip=public_ip,
+            ssh_port=22,
+            ssh_username="root",
+            gpu_type=spec.gpu_type,
+            name=droplet_name,
+            created_at=datetime.now(timezone.utc).isoformat(),
+            spec_name=spec.name,
+        )
+
+    async def terminate(self, resource_id: str) -> bool:
+        """Terminate a DigitalOcean droplet."""
+        return await _terminate_droplet(resource_id)
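A hedged sketch of driving DigitalOceanProvider from an asyncio entry point (not part of the diff); the printing loop is illustrative only.

    import asyncio

    async def show_droplets() -> None:
        # Illustrative only: enumerate droplets and print the fields the
        # provider maps onto Target.
        provider = DigitalOceanProvider()
        for t in await provider.list_targets():
            print(t.resource_id, t.status, t.public_ip, t.spec_name)

    asyncio.run(show_droplets())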
--- /dev/null
+++ wafer_core/targets/providers/runpod.py
@@ -0,0 +1,250 @@
+"""RunPod provider — adapts existing RunPod GraphQL API to TargetProvider protocol."""
+
+from __future__ import annotations
+
+import logging
+import time
+from datetime import datetime, timezone
+
+from wafer_core.targets.runpod import (
+    RunPodError,
+    _graphql_request_async,
+    _wait_for_ssh,
+)
+from wafer_core.targets.runpod import (
+    terminate_pod as _terminate_pod,
+)
+from wafer_core.targets.types import Target, TargetSpec
+from wafer_core.utils.kernel_utils.targets.config import RunPodTarget
+
+logger = logging.getLogger(__name__)
+
+
+def _parse_pod_to_target(pod: dict) -> Target | None:
+    """Parse a RunPod API pod response into a Target.
+
+    Returns None if the pod has no usable SSH info.
+    """
+    pod_id = pod.get("id", "")
+    pod_name = pod.get("name", "")
+    status_raw = pod.get("desiredStatus", "").lower()
+
+    # Map RunPod statuses to our status values
+    status = status_raw if status_raw else "unknown"
+
+    # Extract SSH info from runtime ports
+    public_ip = None
+    ssh_port = None
+    runtime = pod.get("runtime")
+    if runtime:
+        for port in runtime.get("ports") or []:
+            if port.get("privatePort") == 22 and port.get("isIpPublic"):
+                ip = port.get("ip")
+                # Skip proxy SSH (ssh.runpod.io), want direct IP
+                if ip and ip != "ssh.runpod.io":
+                    public_ip = ip
+                    ssh_port = port.get("publicPort")
+                    break
+
+    # Infer spec_name from pod naming convention: wafer-{spec_name}-{timestamp}
+    spec_name = None
+    if pod_name.startswith("wafer-"):
+        parts = pod_name.split("-")
+        if len(parts) >= 3:
+            spec_name = "-".join(parts[1:-1])
+
+    # Extract GPU type
+    gpu_type = ""
+    machine = pod.get("machine")
+    if machine:
+        gpu_type_info = machine.get("gpuType")
+        if gpu_type_info:
+            gpu_type = gpu_type_info.get("displayName", "")
+
+    cost = pod.get("costPerHr")
+
+    return Target(
+        resource_id=pod_id,
+        provider="runpod",
+        status=status,
+        public_ip=public_ip,
+        ssh_port=ssh_port,
+        ssh_username="root",
+        gpu_type=gpu_type,
+        name=pod_name or None,
+        created_at=None,  # RunPod API doesn't expose creation time in list query
+        spec_name=spec_name,
+        price_per_hour=float(cost) if cost else None,
+    )
+
+
+class RunPodProvider:
+    """RunPod implementation of TargetProvider.
+
+    Wraps existing GraphQL API calls:
+    - list_targets: myself { pods { ... } }
+    - get_target: pod(input: { podId }) { ... }
+    - provision: podFindAndDeployOnDemand
+    - terminate: podTerminate
+    """
+
+    async def list_targets(self) -> list[Target]:
+        """List all running pods on the RunPod account."""
+        query = """
+        query {
+            myself {
+                pods {
+                    id
+                    name
+                    desiredStatus
+                    costPerHr
+                    machine {
+                        podHostId
+                        gpuType {
+                            displayName
+                        }
+                    }
+                    runtime {
+                        ports {
+                            ip
+                            isIpPublic
+                            privatePort
+                            publicPort
+                            type
+                        }
+                    }
+                }
+            }
+        }
+        """
+
+        try:
+            data = await _graphql_request_async(query)
+        except RunPodError:
+            raise
+        except Exception as e:
+            logger.warning(f"Failed to list RunPod pods: {e}")
+            return []
+
+        pods = data.get("myself", {}).get("pods", [])
+        targets = []
+
+        for pod in pods:
+            target = _parse_pod_to_target(pod)
+            if target is not None:
+                targets.append(target)
+
+        return targets
+
+    async def get_target(self, resource_id: str) -> Target | None:
+        """Get a specific pod by ID."""
+        query = """
+        query pod($input: PodFilter!) {
+            pod(input: $input) {
+                id
+                name
+                desiredStatus
+                costPerHr
+                machine {
+                    podHostId
+                    gpuType {
+                        displayName
+                    }
+                }
+                runtime {
+                    ports {
+                        ip
+                        isIpPublic
+                        privatePort
+                        publicPort
+                        type
+                    }
+                }
+            }
+        }
+        """
+        variables = {"input": {"podId": resource_id}}
+
+        try:
+            data = await _graphql_request_async(query, variables)
+        except Exception as e:
+            logger.warning(f"Failed to get RunPod pod {resource_id}: {e}")
+            return None
+
+        pod = data.get("pod")
+        if not pod:
+            return None
+
+        return _parse_pod_to_target(pod)
+
+    async def provision(self, spec: TargetSpec) -> Target:
+        """Provision a new RunPod pod from a spec.
+
+        Blocks until SSH is ready.
+        """
+        assert isinstance(spec, RunPodTarget), (
+            f"RunPodProvider.provision requires RunPodTarget, got {type(spec).__name__}"
+        )
+
+        pod_name = f"wafer-{spec.name}-{int(time.time())}"
+
+        mutation = """
+        mutation podFindAndDeployOnDemand($input: PodFindAndDeployOnDemandInput!) {
+            podFindAndDeployOnDemand(input: $input) {
+                id
+                machineId
+                machine {
+                    podHostId
+                }
+            }
+        }
+        """
+
+        pod_input: dict = {
+            "gpuTypeId": spec.gpu_type_id,
+            "gpuCount": spec.gpu_count,
+            "cloudType": "SECURE",
+            "name": pod_name,
+            "supportPublicIp": True,
+            "containerDiskInGb": spec.container_disk_gb,
+            "minVcpuCount": 1,
+            "minMemoryInGb": 4,
+            "ports": "22/tcp",
+            "startSsh": True,
+            "startJupyter": False,
+            "env": [],
+        }
+
+        if spec.template_id:
+            pod_input["templateId"] = spec.template_id
+        else:
+            pod_input["imageName"] = spec.image
+
+        logger.info(f"Provisioning RunPod pod: {pod_name}")
+        data = await _graphql_request_async(mutation, {"input": pod_input})
+
+        pod_data = data.get("podFindAndDeployOnDemand")
+        if not pod_data:
+            raise RunPodError("No pod returned from deployment")
+
+        pod_id = pod_data["id"]
+        logger.info(f"Pod created: {pod_id}")
+
+        public_ip, ssh_port, ssh_username = await _wait_for_ssh(pod_id, spec.provision_timeout)
+
+        return Target(
+            resource_id=pod_id,
+            provider="runpod",
+            status="running",
+            public_ip=public_ip,
+            ssh_port=ssh_port,
+            ssh_username=ssh_username,
+            gpu_type=spec.gpu_type,
+            name=pod_name,
+            created_at=datetime.now(timezone.utc).isoformat(),
+            spec_name=spec.name,
+        )
+
+    async def terminate(self, resource_id: str) -> bool:
+        """Terminate a RunPod pod."""
+        return await _terminate_pod(resource_id)
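All three providers expose the same four async methods. A sketch of the TargetProvider protocol they appear to satisfy is below; the actual definition lives elsewhere in wafer_core.targets (not shown in this diff) and may differ.

    from typing import Protocol

    class TargetProviderSketch(Protocol):
        # Hypothetical reconstruction of the shared interface; only the four
        # method signatures are taken from the diff, the class name is not.
        async def list_targets(self) -> list[Target]: ...
        async def get_target(self, resource_id: str) -> Target | None: ...
        async def provision(self, spec: TargetSpec) -> Target: ...
        async def terminate(self, resource_id: str) -> bool: ...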
--- /dev/null
+++ wafer_core/targets/reconcile.py
@@ -0,0 +1,90 @@
+"""Reconciliation: compare TargetSpecs to live Targets.
+
+Pure function — no API calls, no side effects. Takes specs and targets as
+inputs, returns a ReconcileResult describing what's bound, what's orphaned,
+and what's unprovisioned.
+"""
+
+from __future__ import annotations
+
+from wafer_core.targets.types import ReconcileResult, Target, TargetSpec
+from wafer_core.utils.kernel_utils.targets.config import (
+    BaremetalTarget,
+    DigitalOceanTarget,
+    RunPodTarget,
+    VMTarget,
+)
+
+
+def _is_cloud_spec(spec: TargetSpec) -> bool:
+    """Check if a spec represents a cloud-provisioned resource.
+
+    Baremetal and VM specs don't have cloud-managed lifecycles,
+    so they're excluded from "unprovisioned" checks.
+    """
+    return isinstance(spec, (RunPodTarget, DigitalOceanTarget))
+
+
+def _spec_provider(spec: TargetSpec) -> str | None:
+    """Get the provider name for a spec, or None if not cloud-managed."""
+    if isinstance(spec, RunPodTarget):
+        return "runpod"
+    if isinstance(spec, DigitalOceanTarget):
+        return "digitalocean"
+    if isinstance(spec, (BaremetalTarget, VMTarget)):
+        return "baremetal"
+    return None
+
+
+def reconcile(
+    specs: list[TargetSpec],
+    targets: list[Target],
+    binding_hints: dict[str, str] | None = None,
+) -> ReconcileResult:
+    """Compare specs to live targets and classify each.
+
+    Matching rules (in priority order):
+    1. Target.spec_name matches Spec.name exactly (set by naming convention
+       or explicit binding).
+    2. binding_hints maps resource_id → spec_name (from local cache).
+    3. No match → target is unbound (orphan).
+
+    A cloud spec with no matching target is "unprovisioned".
+    Baremetal/VM specs are never "unprovisioned" (they don't have a cloud
+    lifecycle — the machine is always there or it isn't).
+
+    Args:
+        specs: All known TargetSpecs (loaded from TOML files).
+        targets: All live Targets (fetched from provider APIs).
+        binding_hints: Optional resource_id → spec_name cache for targets
+            whose spec_name can't be inferred from naming convention.
+
+    Returns:
+        ReconcileResult with bound, unbound, and unprovisioned lists.
+    """
+    hints = binding_hints or {}
+    spec_by_name = {s.name: s for s in specs}
+    claimed_spec_names: set[str] = set()
+
+    bound: list[tuple[TargetSpec, Target]] = []
+    unbound: list[Target] = []
+
+    for target in targets:
+        # Try to find the spec this target belongs to
+        resolved_spec_name = target.spec_name or hints.get(target.resource_id)
+
+        if resolved_spec_name and resolved_spec_name in spec_by_name:
+            spec = spec_by_name[resolved_spec_name]
+            bound.append((spec, target))
+            claimed_spec_names.add(resolved_spec_name)
+        else:
+            unbound.append(target)
+
+    # Cloud specs with no bound target are unprovisioned
+    unprovisioned = [s for s in specs if s.name not in claimed_spec_names and _is_cloud_spec(s)]
+
+    return ReconcileResult(
+        bound=bound,
+        unbound=unbound,
+        unprovisioned=unprovisioned,
+    )
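A hedged sketch of how reconcile() might be called (not part of the diff); the spec/target collections and the binding_hints entry are hypothetical placeholders.

    # Illustrative only — specs would come from the spec store, targets from the
    # provider APIs, and the hint key/value below are made-up placeholders.
    result = reconcile(
        specs=all_specs,
        targets=live_targets,
        binding_hints={"pod-abc123": "h100-dev"},  # resource_id -> spec_name
    )
    for spec, target in result.bound:
        print(f"bound: {spec.name} -> {target.resource_id} ({target.status})")
    for target in result.unbound:
        print(f"orphan: {target.resource_id}")
    for spec in result.unprovisioned:
        print(f"unprovisioned: {spec.name}")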