wafer-core 0.1.27__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/aligner.py +13 -6
- wafer_core/lib/trace_compare/analyzer.py +12 -3
- wafer_core/lib/trace_compare/fusion_analyzer.py +392 -284
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.28.dist-info}/METADATA +1 -1
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.28.dist-info}/RECORD +18 -8
- {wafer_core-0.1.27.dist-info → wafer_core-0.1.28.dist-info}/WHEEL +0 -0
wafer_core/targets/__init__.py
CHANGED
|
@@ -1,5 +1,43 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Target system: specs (config) + targets (live resources) + reconciliation.
|
|
2
2
|
|
|
3
|
+
New API (preferred):
|
|
4
|
+
from wafer_core.targets import Target, TargetSpec, ReconcileResult
|
|
5
|
+
from wafer_core.targets.providers import get_provider
|
|
6
|
+
from wafer_core.targets.reconcile import reconcile
|
|
7
|
+
from wafer_core.targets.spec_store import load_spec, list_spec_names
|
|
8
|
+
from wafer_core.targets.state_cache import get_binding_hints
|
|
9
|
+
|
|
10
|
+
Legacy API (still works, will be deprecated):
|
|
11
|
+
from wafer_core.targets import RunPodTarget, runpod_ssh_context, ...
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
# ── New types ────────────────────────────────────────────────────────────────
|
|
15
|
+
from wafer_core.targets.digitalocean import (
|
|
16
|
+
DigitalOceanError,
|
|
17
|
+
DigitalOceanSSHInfo,
|
|
18
|
+
cleanup_all_droplets,
|
|
19
|
+
cleanup_digitalocean_target,
|
|
20
|
+
digitalocean_ssh_context,
|
|
21
|
+
get_droplet_state,
|
|
22
|
+
list_running_droplets,
|
|
23
|
+
)
|
|
24
|
+
from wafer_core.targets.runpod import (
|
|
25
|
+
RunPodError,
|
|
26
|
+
RunPodSSHInfo,
|
|
27
|
+
cleanup_all_pods,
|
|
28
|
+
cleanup_target,
|
|
29
|
+
get_pod_state,
|
|
30
|
+
list_running_pods,
|
|
31
|
+
runpod_ssh_context,
|
|
32
|
+
)
|
|
33
|
+
from wafer_core.targets.types import (
|
|
34
|
+
ReconcileResult,
|
|
35
|
+
Target,
|
|
36
|
+
TargetProvider,
|
|
37
|
+
TargetSpec,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# ── Legacy re-exports (unchanged, for backwards compatibility) ───────────────
|
|
3
41
|
from wafer_core.utils.kernel_utils.targets import (
|
|
4
42
|
BaremetalTarget,
|
|
5
43
|
DigitalOceanTarget,
|
|
@@ -18,26 +56,14 @@ from wafer_core.utils.kernel_utils.targets import (
|
|
|
18
56
|
select_target_for_operation,
|
|
19
57
|
target_to_deployment_config,
|
|
20
58
|
)
|
|
21
|
-
from wafer_core.targets.runpod import (
|
|
22
|
-
RunPodError,
|
|
23
|
-
RunPodSSHInfo,
|
|
24
|
-
cleanup_all_pods,
|
|
25
|
-
cleanup_target,
|
|
26
|
-
get_pod_state,
|
|
27
|
-
list_running_pods,
|
|
28
|
-
runpod_ssh_context,
|
|
29
|
-
)
|
|
30
|
-
from wafer_core.targets.digitalocean import (
|
|
31
|
-
DigitalOceanError,
|
|
32
|
-
DigitalOceanSSHInfo,
|
|
33
|
-
cleanup_all_droplets,
|
|
34
|
-
cleanup_digitalocean_target,
|
|
35
|
-
digitalocean_ssh_context,
|
|
36
|
-
get_droplet_state,
|
|
37
|
-
list_running_droplets,
|
|
38
|
-
)
|
|
39
59
|
|
|
40
60
|
__all__ = [
|
|
61
|
+
# New API
|
|
62
|
+
"Target",
|
|
63
|
+
"TargetSpec",
|
|
64
|
+
"TargetProvider",
|
|
65
|
+
"ReconcileResult",
|
|
66
|
+
# Legacy: target config types
|
|
41
67
|
"BaremetalTarget",
|
|
42
68
|
"VMTarget",
|
|
43
69
|
"ModalTarget",
|
|
@@ -54,7 +80,7 @@ __all__ = [
|
|
|
54
80
|
"check_target_available",
|
|
55
81
|
"find_free_gpu",
|
|
56
82
|
"run_operation_on_target",
|
|
57
|
-
# RunPod provisioning
|
|
83
|
+
# Legacy: RunPod provisioning
|
|
58
84
|
"RunPodError",
|
|
59
85
|
"RunPodSSHInfo",
|
|
60
86
|
"runpod_ssh_context",
|
|
@@ -62,7 +88,7 @@ __all__ = [
|
|
|
62
88
|
"cleanup_all_pods",
|
|
63
89
|
"list_running_pods",
|
|
64
90
|
"get_pod_state",
|
|
65
|
-
# DigitalOcean provisioning
|
|
91
|
+
# Legacy: DigitalOcean provisioning
|
|
66
92
|
"DigitalOceanError",
|
|
67
93
|
"DigitalOceanSSHInfo",
|
|
68
94
|
"digitalocean_ssh_context",
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""Pool queries: filter live targets by GPU type, provider, and labels.
|
|
2
|
+
|
|
3
|
+
A pool is a predicate over live targets, not a hardcoded list.
|
|
4
|
+
Pool queries are defined in ~/.wafer/config.toml:
|
|
5
|
+
|
|
6
|
+
[pools.mi300x]
|
|
7
|
+
gpu_type = "MI300X"
|
|
8
|
+
|
|
9
|
+
[pools.mi300x-rocm7]
|
|
10
|
+
gpu_type = "MI300X"
|
|
11
|
+
labels.rocm_version = "7.0.2"
|
|
12
|
+
|
|
13
|
+
[pools.runpod-only]
|
|
14
|
+
provider = "runpod"
|
|
15
|
+
|
|
16
|
+
Matching: a target matches a pool query if all specified fields match.
|
|
17
|
+
Fields not specified in the query are ignored (match anything).
|
|
18
|
+
Label matching is AND — all required labels must be present and equal.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from dataclasses import dataclass, field
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
|
|
26
|
+
from wafer_core.targets.types import Target
|
|
27
|
+
|
|
28
|
+
# Per-user wafer state directory and the TOML config file inside it.
WAFER_DIR = Path.home().joinpath(".wafer")
CONFIG_FILE = WAFER_DIR.joinpath("config.toml")

# PoolQuery attributes that compare one-to-one with same-named Target attributes.
_TARGET_FIELDS = ("gpu_type", "provider", "status")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(frozen=True)
|
|
36
|
+
class PoolQuery:
|
|
37
|
+
"""Predicate for filtering live targets.
|
|
38
|
+
|
|
39
|
+
All specified fields must match (AND semantics).
|
|
40
|
+
None means "don't care" for that field.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
gpu_type: str | None = None
|
|
44
|
+
provider: str | None = None
|
|
45
|
+
status: str | None = "running"
|
|
46
|
+
labels: dict[str, str] = field(default_factory=dict)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def match_targets(query: PoolQuery, targets: list[Target]) -> list[Target]:
    """Return the targets accepted by ``query``, preserving input order.

    Pure function: neither ``query`` nor any target is modified.
    """
    return [candidate for candidate in targets if _matches(query, candidate)]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _matches(query: PoolQuery, target: Target) -> bool:
    """Return True when ``target`` satisfies every constraint set on ``query``.

    The scalar fields (gpu_type, provider, status) are only compared when the
    query sets them (non-None). Every entry in ``query.labels`` must appear in
    ``target.labels`` with an equal value.
    """
    # Scalar constraints, checked in a fixed order; None means "accept any".
    for attr in ("gpu_type", "provider", "status"):
        wanted = getattr(query, attr)
        if wanted is not None and getattr(target, attr) != wanted:
            return False

    # Label constraints: each required key must be present with an equal value.
    return not any(
        target.labels.get(key) != value for key, value in query.labels.items()
    )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def load_pool_query(name: str) -> PoolQuery:
    """Load the query-format pool ``name`` from ~/.wafer/config.toml.

    Reads the keys ``gpu_type``, ``provider``, ``status`` (default
    ``"running"``) and the ``labels`` sub-table. Label keys and values are
    coerced to ``str``.

    Raises:
        KeyError: the pool is not defined.
        TypeError: the pool entry, or its ``labels`` entry, is not a table.
        ValueError: the pool uses the legacy ``targets = [...]`` list format,
            which is not a query (use ``is_query_pool`` to tell them apart).
    """
    pools = _load_pools_section()
    if name not in pools:
        available = ", ".join(sorted(pools)) if pools else "(none)"
        raise KeyError(f"Pool {name!r} not found. Available: {available}")

    raw = pools[name]
    # Explicit raises instead of asserts: this validates user-edited config,
    # and asserts are stripped under ``python -O``.
    if not isinstance(raw, dict):
        raise TypeError(f"Pool {name!r} must be a table, got {type(raw).__name__}")
    if "targets" in raw:
        # A legacy target-list pool has none of the query keys. Parsed as a
        # query it would become a filter-less predicate matching every running
        # target on every provider.
        raise ValueError(
            f"Pool {name!r} uses the legacy 'targets = [...]' format and is not a pool query"
        )

    labels_raw = raw.get("labels", {})
    if not isinstance(labels_raw, dict):
        raise TypeError(
            f"Pool {name!r} labels must be a table, got {type(labels_raw).__name__}"
        )

    return PoolQuery(
        gpu_type=raw.get("gpu_type"),
        provider=raw.get("provider"),
        status=raw.get("status", "running"),
        labels={str(k): str(v) for k, v in labels_raw.items()},
    )
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def list_pool_names() -> list[str]:
    """Return the names of every pool under [pools] in config.toml, sorted.

    Query-format and legacy target-list pools are both included. A missing
    config file or [pools] table yields an empty list.
    """
    return sorted(_load_pools_section())
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def is_query_pool(name: str) -> bool:
    """Tell whether pool ``name`` is a PoolQuery rather than a target list.

    Legacy format: [pools.name] targets = ["t1", "t2"]
    Query format:  [pools.name] gpu_type = "MI300X"

    Returns:
        True only for an existing pool whose entry is a table without a
        ``targets`` key. False for missing pools, non-table entries, and the
        legacy format.
    """
    pools = _load_pools_section()
    if name not in pools:
        return False
    entry = pools[name]
    # The legacy format is recognised by its "targets" list of names.
    return isinstance(entry, dict) and "targets" not in entry
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
async def resolve_pool(name: str) -> list[Target]:
    """Resolve a pool query to live targets.

    Lists targets from every cloud provider concurrently and overlays the
    labels returned by ``load_all_labels()`` (keyed by ``resource_id``). The
    result is then filtered with the pool's query.

    A provider whose listing fails (missing API key, network error, ...) is
    skipped, with a logged warning, so one broken provider does not hide the
    targets of the others.

    Returns:
        Matching Target objects sorted by ``resource_id`` for determinism.

    Raises:
        KeyError: the pool is not defined (from ``load_pool_query``).
    """
    import logging
    from dataclasses import replace

    import trio

    from wafer_core.targets.providers import get_all_cloud_providers
    from wafer_core.targets.state_cache import load_all_labels
    from wafer_core.targets.types import TargetProvider

    logger = logging.getLogger(__name__)
    query = load_pool_query(name)

    # Fetch all live targets; each task appends to the shared list.
    all_targets: list[Target] = []

    async def _fetch(prov_name: str, prov_impl: TargetProvider) -> None:
        try:
            targets = await prov_impl.list_targets()
        except Exception as exc:
            # Best-effort by design, but leave a trace. Silently dropping a
            # provider makes the pool look empty with no hint why.
            logger.warning(
                "Skipping provider %s while resolving pool %r: %s", prov_name, name, exc
            )
            return
        all_targets.extend(targets)

    async with trio.open_nursery() as nursery:
        for prov_name, prov_impl in get_all_cloud_providers():
            nursery.start_soon(_fetch, prov_name, prov_impl)

    # Overlay cached labels, keyed by resource_id, onto the live targets.
    cached_labels = load_all_labels()
    hydrated = [
        replace(t, labels=cached_labels[t.resource_id])
        if t.resource_id in cached_labels
        else t
        for t in all_targets
    ]

    # Filter, then sort for a deterministic order.
    matched = match_targets(query, hydrated)
    matched.sort(key=lambda t: t.resource_id)
    return matched
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _load_pools_section() -> dict:
|
|
174
|
+
"""Read the [pools] section from config.toml. Returns empty dict if missing."""
|
|
175
|
+
if not CONFIG_FILE.exists():
|
|
176
|
+
return {}
|
|
177
|
+
|
|
178
|
+
import tomllib
|
|
179
|
+
|
|
180
|
+
data = tomllib.loads(CONFIG_FILE.read_text())
|
|
181
|
+
return data.get("pools", {})
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""SSH probe: detect software labels on a live target.
|
|
2
|
+
|
|
3
|
+
Runs a Python script on the target via SSH that reports installed
|
|
4
|
+
software versions. Returns a flat dict[str, str] of labels.
|
|
5
|
+
|
|
6
|
+
Only called at provision time or manually via `wafer targets probe`.
|
|
7
|
+
Results are cached in target_state.json — probe is never implicit.
|
|
8
|
+
|
|
9
|
+
Uses subprocess ssh (not asyncssh) to match existing codebase patterns.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
import logging
|
|
16
|
+
import subprocess
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)

# Probe script runs on the target machine via SSH (fed to `python3` on stdin).
# Prints a JSON dict to stdout. Must work with stock Python 3.10+.
# Every individual probe is best-effort: a missing or broken tool only omits
# its label instead of aborting the whole script.
_PROBE_SCRIPT = r"""
import json, shutil, subprocess, sys

def probe():
    result = {}

    # Python version
    result["python_version"] = ".".join(map(str, sys.version_info[:2]))

    # ROCm version from filesystem
    try:
        with open("/opt/rocm/.info/version") as f:
            result["rocm_version"] = f.read().strip().split("-")[0]
    except Exception:
        pass

    # CUDA version from nvcc (bounded so a wedged toolchain cannot hang the probe)
    nvcc = shutil.which("nvcc")
    if nvcc:
        try:
            out = subprocess.check_output([nvcc, "--version"], text=True, timeout=30)
            for line in out.split("\n"):
                if "release" in line.lower():
                    parts = line.split("release")
                    if len(parts) > 1:
                        result["cuda_version"] = parts[1].split(",")[0].strip()
                        break
        except Exception:
            pass

    # PyTorch version. Importing torch can fail with more than ImportError
    # (e.g. OSError from missing GPU shared libraries); that must not crash
    # the probe and lose every other label.
    try:
        import torch
        result["pytorch_version"] = torch.__version__.split("+")[0]
    except Exception:
        pass

    # Triton version (same reasoning as torch)
    try:
        import triton
        result["triton_version"] = triton.__version__
    except Exception:
        pass

    print(json.dumps(result))

probe()
"""
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def probe_target_labels(
    host: str,
    port: int,
    username: str,
    ssh_key_path: str | None = None,
    timeout: int = 60,
) -> dict[str, str]:
    """SSH into a target, run the probe script, and return its labels.

    The probe script is piped to ``python3`` on the remote host over stdin.
    Host-key checking is disabled and nothing is written to known_hosts (see
    the ``-o`` options below).

    Args:
        host: Hostname or IP address of the target.
        port: SSH port.
        username: Remote login user.
        ssh_key_path: Private key passed to ``ssh -i``. If None, ssh's
            defaults apply.
        timeout: Seconds before the whole ssh invocation is abandoned.

    Returns:
        Flat ``{label: value}`` dict with keys and values coerced to ``str``.

    Raises:
        subprocess.TimeoutExpired: ssh did not finish within ``timeout``.
        RuntimeError: ssh or the remote python exited non-zero, or the probe
            output was not a JSON object.
        json.JSONDecodeError: the probe output was not valid JSON.

    Raises on SSH failure — caller decides how to handle.
    """
    ssh_args = [
        "ssh",
        "-p", str(port),
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-o", "LogLevel=ERROR",
        "-o", "ConnectTimeout=10",
    ]
    if ssh_key_path:
        ssh_args.extend(["-i", ssh_key_path])

    ssh_args.append(f"{username}@{host}")
    ssh_args.append("python3")

    result = subprocess.run(
        ssh_args,
        input=_PROBE_SCRIPT,
        capture_output=True,
        text=True,
        timeout=timeout,
    )

    if result.returncode != 0:
        stderr = result.stderr.strip()
        raise RuntimeError(f"Probe failed (exit {result.returncode}): {stderr}")

    # The probe prints its JSON object as the final line. Anything printed
    # before it (e.g. by remote shell startup files) is ignored rather than
    # breaking the parse.
    lines = result.stdout.strip().splitlines()
    labels = json.loads(lines[-1] if lines else "")
    # Explicit check rather than assert: this is remote input, and asserts
    # are stripped under ``python -O``.
    if not isinstance(labels, dict):
        raise RuntimeError(f"Probe returned {type(labels).__name__}, expected dict")

    return {str(k): str(v) for k, v in labels.items()}
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Provider registry for GPU resource management.
|
|
2
|
+
|
|
3
|
+
Each provider implements TargetProvider protocol: list, get, provision, terminate.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from wafer_core.targets.providers.baremetal import BaremetalProvider
|
|
9
|
+
from wafer_core.targets.providers.digitalocean import DigitalOceanProvider
|
|
10
|
+
from wafer_core.targets.providers.runpod import RunPodProvider
|
|
11
|
+
from wafer_core.targets.types import TargetProvider
|
|
12
|
+
|
|
13
|
+
# Registry of provider name -> provider class. get_provider() and
# get_all_cloud_providers() instantiate these on demand.
_PROVIDERS: dict[str, type] = dict(
    runpod=RunPodProvider,
    digitalocean=DigitalOceanProvider,
    baremetal=BaremetalProvider,
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_provider(name: str) -> TargetProvider:
    """Instantiate the provider registered under ``name``.

    Each call returns a fresh instance.

    Raises:
        KeyError: ``name`` is not a registered provider.
    """
    if name not in _PROVIDERS:
        known = ", ".join(sorted(_PROVIDERS))
        raise KeyError(f"Unknown provider: {name!r}. Available: {known}")
    return _PROVIDERS[name]()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_all_cloud_providers() -> list[tuple[str, TargetProvider]]:
    """Instantiate every registered provider that has a remote API to list.

    Baremetal is skipped because it has no remote API to query. Entries
    follow the registry's insertion order.

    Returns:
        ``(name, provider_instance)`` pairs, one fresh instance each.
    """
    cloud: list[tuple[str, TargetProvider]] = []
    for provider_name, provider_cls in _PROVIDERS.items():
        if provider_name == "baremetal":
            continue
        cloud.append((provider_name, provider_cls()))
    return cloud
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
__all__ = [
|
|
41
|
+
"BaremetalProvider",
|
|
42
|
+
"DigitalOceanProvider",
|
|
43
|
+
"RunPodProvider",
|
|
44
|
+
"get_all_cloud_providers",
|
|
45
|
+
"get_provider",
|
|
46
|
+
]
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""Baremetal provider — degenerate case with no cloud API.
|
|
2
|
+
|
|
3
|
+
Baremetal targets have no provisioning lifecycle. The "resource" is just the
|
|
4
|
+
SSH endpoint from the spec. list_targets returns nothing (no API to query),
|
|
5
|
+
provision just wraps the spec's SSH endpoint in a Target, and terminate is a no-op that returns False.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from wafer_core.targets.types import Target, TargetSpec
|
|
11
|
+
from wafer_core.utils.kernel_utils.targets.config import BaremetalTarget, VMTarget
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def target_from_ssh_spec(spec: BaremetalTarget | VMTarget) -> Target:
    """Build a Target from a baremetal/VM spec's ``ssh_target``.

    ``ssh_target`` has the form ``[user@]host:port``; the user defaults to
    ``root`` when omitted. Since there's no cloud API, the resource_id is
    synthetic: "baremetal:{host}:{port}", which makes it unique and stable.

    Raises:
        ValueError: ``ssh_target`` has no ``:port`` suffix, or the port is
            not an integer.
    """
    ssh_target = spec.ssh_target
    # Split on the last ':' so the port is always the final component.
    user_host, sep, port_str = ssh_target.rpartition(":")
    # Explicit raise rather than assert: this validates user config, and
    # asserts are stripped under ``python -O``.
    if not sep:
        raise ValueError(f"ssh_target must include port, got: {ssh_target}")

    user, at, host = user_host.partition("@")
    if not at:
        user, host = "root", user_host

    port = int(port_str)  # ValueError for a non-numeric port

    return Target(
        resource_id=f"baremetal:{host}:{port}",
        provider="baremetal",
        status="running",  # Assumed running; TCP check happens elsewhere
        public_ip=host,
        ssh_port=port,
        ssh_username=user,
        gpu_type=spec.gpu_type,
        name=spec.name,
        spec_name=spec.name,
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class BaremetalProvider:
    """TargetProvider for machines reached directly over SSH (no cloud API).

    There is no remote inventory to query, so ``list_targets`` and
    ``get_target`` report nothing and ``terminate`` always returns False.
    ``provision`` creates nothing: it wraps the spec's existing SSH endpoint
    in a Target via ``target_from_ssh_spec()``.
    """

    async def list_targets(self) -> list[Target]:
        """Return an empty list: there is no API to enumerate baremetal hosts."""
        no_targets: list[Target] = []
        return no_targets

    async def get_target(self, resource_id: str) -> Target | None:
        """Return None: baremetal hosts cannot be looked up by resource_id."""
        return None

    async def provision(self, spec: TargetSpec) -> Target:
        """Return a Target for an already-existing baremetal/VM host.

        Nothing is created; the Target is derived from the spec's SSH endpoint.
        """
        accepted_specs = (BaremetalTarget, VMTarget)
        assert isinstance(spec, accepted_specs), (
            f"BaremetalProvider.provision requires BaremetalTarget or VMTarget, "
            f"got {type(spec).__name__}"
        )
        return target_from_ssh_spec(spec)

    async def terminate(self, resource_id: str) -> bool:
        """Return False: baremetal hosts have no API to shut them down."""
        return False
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""DigitalOcean provider — adapts existing DO REST API to TargetProvider protocol."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import time
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
|
|
9
|
+
from wafer_core.targets.digitalocean import (
|
|
10
|
+
DigitalOceanError,
|
|
11
|
+
_api_request_async,
|
|
12
|
+
_get_ssh_key_ids,
|
|
13
|
+
_wait_for_ssh,
|
|
14
|
+
)
|
|
15
|
+
from wafer_core.targets.digitalocean import (
|
|
16
|
+
terminate_droplet as _terminate_droplet,
|
|
17
|
+
)
|
|
18
|
+
from wafer_core.targets.types import Target, TargetSpec
|
|
19
|
+
from wafer_core.utils.kernel_utils.targets.config import DigitalOceanTarget
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _parse_droplet_to_target(droplet: dict) -> Target:
    """Convert one droplet object from the DigitalOcean API into a Target.

    - status: DO ``new``/``active``/``off``/``archive`` map to
      ``pending``/``running``/``stopped``/``terminated``. Any other value is
      passed through lower-cased.
    - public_ip: the first ``networks.v4`` entry of type ``public``, else None.
    - spec_name: inferred from the ``wafer-{spec_name}-{timestamp}`` names that
      ``DigitalOceanProvider.provision`` assigns, else None.
    - gpu_type: ``"MI300X"`` when the size slug mentions mi300x, else
      ``"unknown"``.

    SSH is always reported as ``root`` on port 22.
    """
    droplet_id = str(droplet.get("id", ""))
    # `or` guards tolerate explicit nulls as well as missing keys.
    droplet_name = droplet.get("name") or ""
    status_raw = (droplet.get("status") or "").lower()

    # Map DO statuses to our values
    # DO: new, active, off, archive
    status_map = {
        "new": "pending",
        "active": "running",
        "off": "stopped",
        "archive": "terminated",
    }
    status = status_map.get(status_raw, status_raw)

    # Extract public IP
    public_ip = None
    networks = droplet.get("networks") or {}
    for net in networks.get("v4") or []:
        if net.get("type") == "public":
            public_ip = net.get("ip_address")
            break

    # Infer spec_name from naming convention: wafer-{spec_name}-{timestamp}.
    # Requiring a numeric timestamp suffix (provision() uses int(time.time()))
    # keeps unrelated droplets that merely start with "wafer-" from being
    # attributed to a spec.
    spec_name = None
    if droplet_name.startswith("wafer-"):
        parts = droplet_name.split("-")
        if len(parts) >= 3 and parts[-1].isdigit():
            spec_name = "-".join(parts[1:-1])

    created_at = droplet.get("created_at")

    # Extract GPU type from the size slug. Prefer the nested ``size`` object,
    # but fall back to the top-level ``size_slug`` when that object is absent
    # or has no slug.
    size = droplet.get("size")
    size_slug = (size.get("slug") or "") if isinstance(size, dict) else ""
    if not size_slug:
        size_slug = str(droplet.get("size_slug") or "")
    gpu_type = "MI300X" if "mi300x" in size_slug.lower() else "unknown"

    return Target(
        resource_id=droplet_id,
        provider="digitalocean",
        status=status,
        public_ip=public_ip,
        ssh_port=22,
        ssh_username="root",
        gpu_type=gpu_type,
        name=droplet_name or None,
        created_at=created_at,
        spec_name=spec_name,
    )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class DigitalOceanProvider:
    """DigitalOcean implementation of TargetProvider.

    Wraps the REST helpers from ``wafer_core.targets.digitalocean`` for
    droplet listing, lookup, provisioning and termination.
    """

    async def list_targets(self) -> list[Target]:
        """List all droplets on the DigitalOcean account, following pagination.

        Returns:
            One Target per droplet, in API order.

        Raises:
            DigitalOceanError: propagated unchanged from the API helper.

        Any other failure is logged as a warning and yields an empty list.
        """
        droplets: list[dict] = []
        page = 1
        try:
            while True:
                response = await _api_request_async(
                    "GET", "/droplets", params={"per_page": "200", "page": str(page)}
                )
                batch = response.get("droplets", [])
                droplets.extend(batch)
                # One request returns a single page; the API advertises further
                # pages via links.pages.next. Also stopping on an empty page
                # guards against looping forever on a malformed response.
                next_link = ((response.get("links") or {}).get("pages") or {}).get("next")
                if not batch or not next_link:
                    break
                page += 1
        except DigitalOceanError:
            raise
        except Exception as e:
            logger.warning(f"Failed to list DigitalOcean droplets: {e}")
            return []

        return [_parse_droplet_to_target(d) for d in droplets]

    async def get_target(self, resource_id: str) -> Target | None:
        """Get a specific droplet by ID.

        Returns None when the lookup fails for any reason (logged as a
        warning) or the response carries no droplet.
        """
        try:
            response = await _api_request_async("GET", f"/droplets/{resource_id}")
        except Exception as e:
            logger.warning(f"Failed to get DigitalOcean droplet {resource_id}: {e}")
            return None

        droplet = response.get("droplet")
        if not droplet:
            return None

        return _parse_droplet_to_target(droplet)

    async def provision(self, spec: TargetSpec) -> Target:
        """Provision a new DigitalOcean droplet from a spec.

        Blocks until SSH is ready. If waiting for SSH fails, the new droplet
        is terminated (best-effort) before the error is re-raised, because the
        caller never receives its resource_id and could not clean it up.

        Raises:
            DigitalOceanError: the create request returned no droplet.
        """
        assert isinstance(spec, DigitalOceanTarget), (
            f"DigitalOceanProvider.provision requires DigitalOceanTarget, got {type(spec).__name__}"
        )

        droplet_name = f"wafer-{spec.name}-{int(time.time())}"

        ssh_key_ids = await _get_ssh_key_ids()
        if not ssh_key_ids:
            logger.warning("No SSH keys found - droplet may not be accessible")

        create_data = {
            "name": droplet_name,
            "region": spec.region,
            "size": spec.size_slug,
            "image": spec.image,
            "ssh_keys": ssh_key_ids,
            "backups": False,
            "ipv6": True,
            "monitoring": True,
        }

        logger.info(f"Provisioning DigitalOcean droplet: {droplet_name}")
        response = await _api_request_async("POST", "/droplets", data=create_data)

        if not response or "droplet" not in response:
            raise DigitalOceanError(f"Failed to create droplet: {response}")

        droplet = response["droplet"]
        droplet_id = str(droplet["id"])
        logger.info(f"Droplet created: {droplet_id}")

        try:
            public_ip = await _wait_for_ssh(droplet_id, spec.provision_timeout)
        except Exception:
            # The droplet exists but will never be returned to the caller, so
            # destroy it rather than leak it, then surface the original error.
            logger.warning(f"Droplet {droplet_id} did not become reachable; terminating it")
            try:
                await _terminate_droplet(droplet_id)
            except Exception as cleanup_err:
                logger.warning(f"Failed to terminate droplet {droplet_id}: {cleanup_err}")
            raise

        return Target(
            resource_id=droplet_id,
            provider="digitalocean",
            status="running",
            public_ip=public_ip,
            ssh_port=22,
            ssh_username="root",
            gpu_type=spec.gpu_type,
            name=droplet_name,
            created_at=datetime.now(timezone.utc).isoformat(),
            spec_name=spec.name,
        )

    async def terminate(self, resource_id: str) -> bool:
        """Terminate a DigitalOcean droplet; returns the helper's success flag."""
        return await _terminate_droplet(resource_id)
|