wafer-cli 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +118 -0
- wafer/__init__.py +3 -0
- wafer/analytics.py +306 -0
- wafer/api_client.py +195 -0
- wafer/auth.py +432 -0
- wafer/autotuner.py +1080 -0
- wafer/billing.py +233 -0
- wafer/cli.py +7289 -0
- wafer/config.py +105 -0
- wafer/corpus.py +366 -0
- wafer/evaluate.py +4593 -0
- wafer/global_config.py +350 -0
- wafer/gpu_run.py +307 -0
- wafer/inference.py +148 -0
- wafer/kernel_scope.py +552 -0
- wafer/ncu_analyze.py +651 -0
- wafer/nsys_analyze.py +1042 -0
- wafer/nsys_profile.py +510 -0
- wafer/output.py +248 -0
- wafer/problems.py +357 -0
- wafer/rocprof_compute.py +490 -0
- wafer/rocprof_sdk.py +274 -0
- wafer/rocprof_systems.py +520 -0
- wafer/skills/wafer-guide/SKILL.md +129 -0
- wafer/ssh_keys.py +261 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +842 -0
- wafer/targets_ops.py +717 -0
- wafer/templates/__init__.py +0 -0
- wafer/templates/ask_docs.py +61 -0
- wafer/templates/optimize_kernel.py +71 -0
- wafer/templates/optimize_kernelbench.py +137 -0
- wafer/templates/trace_analyze.py +74 -0
- wafer/tracelens.py +218 -0
- wafer/wevin_cli.py +577 -0
- wafer/workspaces.py +852 -0
- wafer_cli-0.2.14.dist-info/METADATA +16 -0
- wafer_cli-0.2.14.dist-info/RECORD +41 -0
- wafer_cli-0.2.14.dist-info/WHEEL +5 -0
- wafer_cli-0.2.14.dist-info/entry_points.txt +2 -0
- wafer_cli-0.2.14.dist-info/top_level.txt +1 -0
wafer/targets.py
ADDED
|
@@ -0,0 +1,842 @@
|
|
|
1
|
+
"""Target management for Wafer CLI.
|
|
2
|
+
|
|
3
|
+
CRUD operations for GPU targets stored in ~/.wafer/targets/.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import tomllib
|
|
7
|
+
from dataclasses import asdict
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from wafer_core.utils.kernel_utils.targets.config import (
|
|
12
|
+
BaremetalTarget,
|
|
13
|
+
DigitalOceanTarget,
|
|
14
|
+
ModalTarget,
|
|
15
|
+
RunPodTarget,
|
|
16
|
+
TargetConfig,
|
|
17
|
+
VMTarget,
|
|
18
|
+
WorkspaceTarget,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Default paths
|
|
22
|
+
WAFER_DIR = Path.home() / ".wafer"
|
|
23
|
+
TARGETS_DIR = WAFER_DIR / "targets"
|
|
24
|
+
CONFIG_FILE = WAFER_DIR / "config.toml"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _ensure_dirs() -> None:
    """Create ~/.wafer/targets/ (and parents) if it does not already exist."""
    TARGETS_DIR.mkdir(parents=True, exist_ok=True)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _target_path(name: str) -> Path:
    """Return the on-disk TOML path for the target called *name*."""
    return TARGETS_DIR / (name + ".toml")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _parse_target(data: dict[str, Any]) -> TargetConfig:
    """Build a target dataclass from a parsed TOML dict.

    Args:
        data: Parsed TOML data; must carry a 'type' key selecting the flavor.

    Returns:
        TargetConfig (BaremetalTarget, VMTarget, ModalTarget, WorkspaceTarget,
        RunPodTarget, or DigitalOceanTarget)

    Raises:
        ValueError: If target type is unknown or required fields missing
    """
    target_type = data.get("type")
    if not target_type:
        raise ValueError(
            "Target must have 'type' field (baremetal, vm, modal, workspace, runpod, or digitalocean)"
        )

    # The 'type' discriminator is not a dataclass field — strip it.
    fields = {k: v for k, v in data.items() if k != "type"}

    # TOML gives us lists, but the dataclasses expect immutable tuples.
    for tuple_key in ("pip_packages", "gpu_ids"):
        if isinstance(fields.get(tuple_key), list):
            fields[tuple_key] = tuple(fields[tuple_key])

    constructors = {
        "baremetal": BaremetalTarget,
        "vm": VMTarget,
        "modal": ModalTarget,
        "workspace": WorkspaceTarget,
        "runpod": RunPodTarget,
        "digitalocean": DigitalOceanTarget,
    }
    constructor = constructors.get(target_type)
    if constructor is None:
        raise ValueError(
            f"Unknown target type: {target_type}. Must be baremetal, vm, modal, workspace, runpod, or digitalocean"
        )
    return constructor(**fields)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _serialize_target(target: TargetConfig) -> dict[str, Any]:
    """Turn a target dataclass into a TOML-ready dict.

    Args:
        target: Target config

    Returns:
        Dict with 'type' field added
    """
    data = asdict(target)

    # Stamp the discriminator; order mirrors _parse_target's dispatch.
    type_labels = (
        (BaremetalTarget, "baremetal"),
        (VMTarget, "vm"),
        (ModalTarget, "modal"),
        (WorkspaceTarget, "workspace"),
        (RunPodTarget, "runpod"),
        (DigitalOceanTarget, "digitalocean"),
    )
    for cls, label in type_labels:
        if isinstance(target, cls):
            data["type"] = label
            break

    # TOML serialization wants lists, not tuples.
    for key in ("pip_packages", "gpu_ids"):
        if isinstance(data.get(key), tuple):
            data[key] = list(data[key])

    # Drop an empty pip_packages entry to keep the file minimal.
    if "pip_packages" in data and not data["pip_packages"]:
        del data["pip_packages"]

    return data
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _write_toml(path: Path, data: dict[str, Any]) -> None:
|
|
125
|
+
"""Write dict as TOML file.
|
|
126
|
+
|
|
127
|
+
Simple TOML writer - handles flat dicts and lists.
|
|
128
|
+
"""
|
|
129
|
+
lines = []
|
|
130
|
+
for key, value in data.items():
|
|
131
|
+
if value is None:
|
|
132
|
+
continue # Skip None values
|
|
133
|
+
if isinstance(value, bool):
|
|
134
|
+
lines.append(f"{key} = {str(value).lower()}")
|
|
135
|
+
elif isinstance(value, int | float):
|
|
136
|
+
lines.append(f"{key} = {value}")
|
|
137
|
+
elif isinstance(value, str):
|
|
138
|
+
lines.append(f'{key} = "{value}"')
|
|
139
|
+
elif isinstance(value, list):
|
|
140
|
+
# Format list
|
|
141
|
+
if all(isinstance(v, int) for v in value):
|
|
142
|
+
lines.append(f"{key} = {value}")
|
|
143
|
+
else:
|
|
144
|
+
formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
|
|
145
|
+
lines.append(f"{key} = [{formatted}]")
|
|
146
|
+
|
|
147
|
+
path.write_text("\n".join(lines) + "\n")
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def load_target(name: str) -> TargetConfig:
    """Load a target config by name.

    Args:
        name: Target name (filename without .toml)

    Returns:
        Target config

    Raises:
        FileNotFoundError: If target doesn't exist
        ValueError: If target config is invalid
    """
    path = _target_path(name)
    if not path.exists():
        raise FileNotFoundError(f"Target not found: {name} (looked in {path})")

    with path.open("rb") as fh:
        parsed = tomllib.load(fh)
    return _parse_target(parsed)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def save_target(target: TargetConfig | dict[str, Any]) -> TargetConfig:
    """Persist a target config as ~/.wafer/targets/{name}.toml.

    Args:
        target: Target config (TargetConfig object or dict with 'type' field)

    Returns:
        The saved TargetConfig object
    """
    _ensure_dirs()

    # Dicts are normalized through the parser first.
    parsed = _parse_target(target) if isinstance(target, dict) else target

    _write_toml(_target_path(parsed.name), _serialize_target(parsed))
    return parsed
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def add_target_from_file(file_path: Path) -> TargetConfig:
    """Parse a TOML file and register it as a target.

    Args:
        file_path: Path to TOML file

    Returns:
        Parsed and saved target

    Raises:
        FileNotFoundError: If file doesn't exist
        ValueError: If file is invalid
    """
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    with file_path.open("rb") as fh:
        parsed = _parse_target(tomllib.load(fh))

    save_target(parsed)
    return parsed
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def list_targets() -> list[str]:
    """List all configured target names.

    Returns:
        Sorted list of target names
    """
    _ensure_dirs()
    names = [entry.stem for entry in TARGETS_DIR.glob("*.toml")]
    names.sort()
    return names
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def remove_target(name: str) -> None:
    """Delete a target's config file.

    Args:
        name: Target name to remove

    Raises:
        FileNotFoundError: If target doesn't exist
    """
    path = _target_path(name)
    if path.exists():
        path.unlink()
    else:
        raise FileNotFoundError(f"Target not found: {name}")
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def get_default_target() -> str | None:
    """Read the default target name from ~/.wafer/config.toml.

    Returns:
        Default target name, or None if not set
    """
    if not CONFIG_FILE.exists():
        return None

    with CONFIG_FILE.open("rb") as fh:
        config = tomllib.load(fh)
    return config.get("default_target")
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# ── Pool Management ─────────────────────────────────────────────────────────
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def get_pool(name: str) -> list[str]:
    """Return the target names belonging to a named pool.

    Pools are defined in ~/.wafer/config.toml:
        [pools.my-pool]
        targets = ["target-1", "target-2", "target-3"]

    Args:
        name: Pool name

    Returns:
        List of target names in the pool

    Raises:
        FileNotFoundError: If pool doesn't exist
        ValueError: If the pool exists but lists no targets
    """
    if not CONFIG_FILE.exists():
        raise FileNotFoundError(f"Pool not found: {name} (no config file)")

    with CONFIG_FILE.open("rb") as fh:
        config = tomllib.load(fh)

    pool = config.get("pools", {}).get(name)
    if pool is None:
        raise FileNotFoundError(
            f"Pool not found: {name}\n"
            f" Define pools in ~/.wafer/config.toml:\n"
            f" [pools.{name}]\n"
            f' targets = ["target-1", "target-2"]'
        )

    members = pool.get("targets", [])
    if not members:
        raise ValueError(f"Pool '{name}' has no targets defined")
    return members
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def get_target_type(name: str) -> str | None:
    """Peek at a target's 'type' field without constructing the dataclass.

    Args:
        name: Target name

    Returns:
        Target type string (runpod, digitalocean, baremetal, etc.) or None if not found
    """
    path = _target_path(name)
    if not path.exists():
        return None

    with path.open("rb") as fh:
        return tomllib.load(fh).get("type")
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def filter_pool_by_auth(target_names: list[str]) -> tuple[list[str], list[str]]:
    """Split pool targets into those we can authenticate against and the rest.

    Args:
        target_names: List of target names to filter

    Returns:
        Tuple of (usable_targets, skipped_targets)
    """
    from wafer_core.auth import get_api_key

    usable: list[str] = []
    skipped: list[str] = []

    for candidate in target_names:
        kind = get_target_type(candidate)
        if kind is None:
            # Nonexistent target — drop it rather than fail the whole pool.
            skipped.append(candidate)
        elif kind in ("runpod", "digitalocean") and not get_api_key(kind):
            # Cloud providers require a runtime API key.
            skipped.append(candidate)
        else:
            # baremetal/vm/workspace/modal need no runtime API key.
            usable.append(candidate)

    return usable, skipped
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def list_pools() -> list[str]:
    """List all configured pool names.

    Returns:
        Sorted list of pool names
    """
    if not CONFIG_FILE.exists():
        return []

    with CONFIG_FILE.open("rb") as fh:
        config = tomllib.load(fh)
    return sorted(config.get("pools", {}))
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def save_pool(name: str, targets: list[str]) -> None:
    """Create or overwrite a pool in ~/.wafer/config.toml.

    Args:
        name: Pool name
        targets: List of target names (must all exist)

    Raises:
        FileNotFoundError: If any target doesn't exist
    """
    # Reject the save if any member target is unknown.
    known = set(list_targets())
    missing = [t for t in targets if t not in known]
    if missing:
        raise FileNotFoundError(f"Targets not found: {', '.join(missing)}")

    _ensure_dirs()

    if CONFIG_FILE.exists():
        with CONFIG_FILE.open("rb") as fh:
            config = tomllib.load(fh)
    else:
        config = {}

    config.setdefault("pools", {})[name] = {"targets": targets}

    # The pools table is nested, so the pool-aware writer is required.
    _write_config_with_pools(config)
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def _write_config_with_pools(data: dict) -> None:
|
|
410
|
+
"""Write config file with pools support.
|
|
411
|
+
|
|
412
|
+
Handles the nested [pools.name] TOML structure and preserves
|
|
413
|
+
existing nested sections like [default], [api], [environments.*].
|
|
414
|
+
"""
|
|
415
|
+
lines = []
|
|
416
|
+
|
|
417
|
+
# Collect nested sections to write after top-level keys
|
|
418
|
+
nested_sections: dict[str, dict] = {}
|
|
419
|
+
|
|
420
|
+
# Write top-level keys first (except pools and nested dicts)
|
|
421
|
+
for key, value in data.items():
|
|
422
|
+
if key == "pools":
|
|
423
|
+
continue
|
|
424
|
+
if value is None:
|
|
425
|
+
continue
|
|
426
|
+
if isinstance(value, dict):
|
|
427
|
+
# Save nested sections for later
|
|
428
|
+
nested_sections[key] = value
|
|
429
|
+
elif isinstance(value, str):
|
|
430
|
+
lines.append(f'{key} = "{value}"')
|
|
431
|
+
elif isinstance(value, bool):
|
|
432
|
+
lines.append(f"{key} = {str(value).lower()}")
|
|
433
|
+
elif isinstance(value, int | float):
|
|
434
|
+
lines.append(f"{key} = {value}")
|
|
435
|
+
elif isinstance(value, list):
|
|
436
|
+
if all(isinstance(v, int) for v in value):
|
|
437
|
+
lines.append(f"{key} = {value}")
|
|
438
|
+
else:
|
|
439
|
+
formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
|
|
440
|
+
lines.append(f"{key} = [{formatted}]")
|
|
441
|
+
|
|
442
|
+
# Write nested sections (e.g., [default], [api], [environments.foo])
|
|
443
|
+
for section_name, section_data in nested_sections.items():
|
|
444
|
+
lines.append("")
|
|
445
|
+
lines.append(f"[{section_name}]")
|
|
446
|
+
for key, value in section_data.items():
|
|
447
|
+
if value is None:
|
|
448
|
+
continue
|
|
449
|
+
if isinstance(value, str):
|
|
450
|
+
lines.append(f'{key} = "{value}"')
|
|
451
|
+
elif isinstance(value, bool):
|
|
452
|
+
lines.append(f"{key} = {str(value).lower()}")
|
|
453
|
+
elif isinstance(value, int | float):
|
|
454
|
+
lines.append(f"{key} = {value}")
|
|
455
|
+
elif isinstance(value, list):
|
|
456
|
+
if all(isinstance(v, int) for v in value):
|
|
457
|
+
lines.append(f"{key} = {value}")
|
|
458
|
+
else:
|
|
459
|
+
formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
|
|
460
|
+
lines.append(f"{key} = [{formatted}]")
|
|
461
|
+
|
|
462
|
+
# Write pools
|
|
463
|
+
pools = data.get("pools", {})
|
|
464
|
+
for pool_name, pool_config in pools.items():
|
|
465
|
+
lines.append("")
|
|
466
|
+
lines.append(f"[pools.{pool_name}]")
|
|
467
|
+
targets = pool_config.get("targets", [])
|
|
468
|
+
formatted = ", ".join(f'"{t}"' for t in targets)
|
|
469
|
+
lines.append(f"targets = [{formatted}]")
|
|
470
|
+
|
|
471
|
+
CONFIG_FILE.write_text("\n".join(lines) + "\n")
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def set_default_target(name: str) -> None:
    """Set default target.

    Args:
        name: Target name (must exist)

    Raises:
        FileNotFoundError: If target doesn't exist
    """
    # Verify target exists
    if name not in list_targets():
        raise FileNotFoundError(f"Target not found: {name}")

    _ensure_dirs()

    # Load existing config or create new
    if CONFIG_FILE.exists():
        with open(CONFIG_FILE, "rb") as f:
            data = tomllib.load(f)
    else:
        data = {}

    data["default_target"] = name

    # Write back with the pools-aware writer. The previous _write_toml call
    # had no branch for dict values, so every [pools.*] table and nested
    # section loaded above was silently dropped from the rewritten file —
    # setting a default target erased all configured pools.
    _write_config_with_pools(data)
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def get_target_info(target: TargetConfig) -> dict[str, str]:
    """Get human-readable info about target.

    Args:
        target: Target config

    Returns:
        Dict of field name -> display value
    """
    info: dict[str, str] = {}

    def add_ssh_host_info(type_label: str) -> None:
        # Shared rendering for SSH-backed targets; Baremetal and VM expose
        # the same fields, so the previous copy-pasted branches are merged.
        info["Type"] = type_label
        info["SSH"] = target.ssh_target
        info["GPUs"] = ", ".join(str(g) for g in target.gpu_ids)
        info["NCU"] = "Yes" if target.ncu_available else "No"
        # Docker info
        if target.docker_image:
            info["Docker"] = target.docker_image
        if target.pip_packages:
            info["Packages"] = ", ".join(target.pip_packages)
        if target.torch_package:
            info["Torch"] = target.torch_package

    if isinstance(target, BaremetalTarget):
        add_ssh_host_info("Baremetal")
    elif isinstance(target, VMTarget):
        add_ssh_host_info("VM")
    elif isinstance(target, ModalTarget):
        info["Type"] = "Modal"
        info["App"] = target.modal_app_name
        info["GPU"] = target.gpu_type
        info["Timeout"] = f"{target.timeout_seconds}s"
        info["NCU"] = "No (Modal)"
    elif isinstance(target, WorkspaceTarget):
        info["Type"] = "Workspace"
        info["Workspace ID"] = target.workspace_id
        info["GPU"] = target.gpu_type
        info["Timeout"] = f"{target.timeout_seconds}s"
        info["NCU"] = "No (Workspace)"
    elif isinstance(target, RunPodTarget):
        info["Type"] = "RunPod"
        info["GPU Type"] = target.gpu_type_id
        info["GPU Count"] = str(target.gpu_count)
        info["Image"] = target.image
        info["Keep Alive"] = "Yes" if target.keep_alive else "No"
        info["NCU"] = "No (RunPod)"
    elif isinstance(target, DigitalOceanTarget):
        info["Type"] = "DigitalOcean"
        info["Region"] = target.region
        info["Size"] = target.size_slug
        info["Image"] = target.image
        info["Keep Alive"] = "Yes" if target.keep_alive else "No"
        info["NCU"] = "No (DigitalOcean)"

    # Common to every target flavor.
    info["Compute"] = target.compute_capability

    return info
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
# Probe script to run on target - checks available backends.
# NOTE: this is a runtime string, shipped verbatim over SSH and executed with
# `python -c` on the remote host (see probe_target_capabilities). It must stay
# self-contained (stdlib + optional triton/torch imports only) and emit exactly
# one JSON object on stdout. "\\n" below is intentional: it must arrive on the
# remote side as the two characters backslash-n inside the remote source.
_PROBE_SCRIPT = """
import json
import shutil
import sys

def probe():
    result = {
        "python_version": sys.version.split()[0],
        "backends": {},
        "packages": {},
    }

    # Check Triton
    try:
        import triton
        result["backends"]["triton"] = triton.__version__
    except ImportError:
        result["backends"]["triton"] = None

    # Check torch
    try:
        import torch
        result["packages"]["torch"] = torch.__version__
        result["backends"]["torch"] = torch.__version__
        result["cuda_available"] = torch.cuda.is_available()
        if torch.cuda.is_available():
            result["gpu_name"] = torch.cuda.get_device_name(0)
            props = torch.cuda.get_device_properties(0)
            result["compute_capability"] = f"{props.major}.{props.minor}"
    except ImportError:
        result["packages"]["torch"] = None

    # Check hipcc (AMD)
    hipcc = shutil.which("hipcc")
    result["backends"]["hipcc"] = hipcc

    # Check nvcc (NVIDIA)
    nvcc = shutil.which("nvcc")
    result["backends"]["nvcc"] = nvcc

    # Check ROCm version
    try:
        with open("/opt/rocm/.info/version", "r") as f:
            result["rocm_version"] = f.read().strip()
    except Exception:
        result["rocm_version"] = None

    # Check CUDA version from nvcc
    if nvcc:
        import subprocess
        try:
            out = subprocess.check_output([nvcc, "--version"], text=True)
            for line in out.split("\\n"):
                if "release" in line.lower():
                    # Parse "Cuda compilation tools, release 12.1, V12.1.105"
                    parts = line.split("release")
                    if len(parts) > 1:
                        result["cuda_version"] = parts[1].split(",")[0].strip()
                    break
        except Exception:
            pass

    print(json.dumps(result))

if __name__ == "__main__":
    probe()
"""
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
class ProbeError(Exception):
    """Error during target probing with actionable context."""
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
async def probe_target_capabilities(target: TargetConfig) -> dict[str, Any]:
    """Probe a target to discover available compilation backends.

    Connects to the target and runs ``_PROBE_SCRIPT`` to check:
    - Triton availability
    - torch availability
    - HIP/CUDA compiler
    - ROCm/CUDA version
    - GPU info

    Only RunPod, Baremetal, and VM targets are supported; other target
    types raise immediately.

    Args:
        target: Target config

    Returns:
        Dict with capabilities info (the JSON emitted by the probe script)

    Raises:
        ProbeError: With actionable error message on failure
    """
    import json
    import subprocess

    if isinstance(target, RunPodTarget):
        import trio_asyncio
        from wafer_core.targets.runpod import RunPodError, get_pod_state, runpod_ssh_context

        # Check if pod exists before trying to connect; also reused below to
        # build a more specific SSH-failure message when a pod is known.
        pod_state = get_pod_state(target.name)

        try:
            # Need trio_asyncio.open_loop() for asyncssh bridge used by runpod_ssh_context
            async with trio_asyncio.open_loop():
                async with runpod_ssh_context(target) as ssh_info:
                    ssh_target = f"{ssh_info.user}@{ssh_info.host}"
                    port = ssh_info.port
                    key_path = target.ssh_key

                    # Find Python and run probe using subprocess (simpler than async ssh)
                    def run_ssh_cmd(cmd: str) -> tuple[int, str, str]:
                        # Runs one command over ssh; host-key checking is
                        # disabled because pod IPs/ports are ephemeral.
                        try:
                            result = subprocess.run(
                                [
                                    "ssh",
                                    "-o",
                                    "StrictHostKeyChecking=no",
                                    "-o",
                                    "UserKnownHostsFile=/dev/null",
                                    "-o",
                                    "ConnectTimeout=30",
                                    "-i",
                                    str(key_path),
                                    "-p",
                                    str(port),
                                    ssh_target,
                                    cmd,
                                ],
                                capture_output=True,
                                text=True,
                                timeout=60,
                            )
                            return result.returncode, result.stdout, result.stderr
                        except subprocess.TimeoutExpired:
                            raise ProbeError(
                                f"SSH connection timed out\n"
                                f" Host: {ssh_target}:{port}\n"
                                f" Hint: The pod may be starting up. Try again in 30 seconds."
                            ) from None

                    # Find Python: prefer the conda interpreters baked into
                    # common GPU images, fall back to plain python3 on PATH.
                    python_exe = "python3"
                    for candidate in [
                        "/opt/conda/envs/py_3.10/bin/python3",
                        "/opt/conda/bin/python3",
                    ]:
                        code, out, _ = run_ssh_cmd(f"{candidate} --version 2>/dev/null && echo OK")
                        if code == 0 and "OK" in out:
                            python_exe = candidate
                            break

                    # Run probe script; single quotes inside the script are
                    # escaped for embedding in a single-quoted shell argument.
                    escaped_script = _PROBE_SCRIPT.replace("'", "'\"'\"'")
                    code, out, err = run_ssh_cmd(f"{python_exe} -c '{escaped_script}'")
                    if code != 0:
                        raise ProbeError(
                            f"Probe script failed on target\n"
                            f" Exit code: {code}\n"
                            f" Error: {err.strip() if err else 'unknown'}"
                        )

                    try:
                        return json.loads(out)
                    except json.JSONDecodeError as e:
                        raise ProbeError(
                            f"Failed to parse probe output\n Error: {e}\n Output: {out[:200]}..."
                        ) from None

        except RunPodError as e:
            # RunPod API errors (provisioning, pod not found, etc.)
            raise ProbeError(f"RunPod error for target '{target.name}'\n {e}") from None
        except OSError as e:
            # SSH connection errors; message differs depending on whether a
            # pod was already known for this target.
            if pod_state:
                raise ProbeError(
                    f"SSH connection failed to target '{target.name}'\n"
                    f" Host: {pod_state.ssh_username}@{pod_state.public_ip}:{pod_state.ssh_port}\n"
                    f" Error: {e}\n"
                    f" Hint: Check if the pod is still running with 'wafer config targets pods'"
                ) from None
            raise ProbeError(
                f"SSH connection failed to target '{target.name}'\n"
                f" Error: {e}\n"
                f" Hint: No pod found. One will be provisioned on next probe attempt."
            ) from None

    elif isinstance(target, (BaremetalTarget, VMTarget)):
        import subprocess

        # Parse ssh_target (user@host:port or user@host).
        # NOTE(review): the ':'-based split assumes no IPv6 literals in the
        # host part — confirm if IPv6 targets are ever configured.
        ssh_target = target.ssh_target
        if ":" in ssh_target.split("@")[-1]:
            host_port = ssh_target.split("@")[-1]
            host = host_port.rsplit(":", 1)[0]
            port = host_port.rsplit(":", 1)[1]
            user = ssh_target.split("@")[0]
            ssh_target = f"{user}@{host}"
        else:
            host = ssh_target.split("@")[-1]
            port = "22"
            user = ssh_target.split("@")[0]

        key_path = target.ssh_key

        def run_ssh_cmd(cmd: str) -> tuple[int, str, str]:
            # Same ssh runner as the RunPod branch, but with a host-oriented
            # timeout hint instead of a pod-startup hint.
            try:
                result = subprocess.run(
                    [
                        "ssh",
                        "-o",
                        "StrictHostKeyChecking=no",
                        "-o",
                        "UserKnownHostsFile=/dev/null",
                        "-o",
                        "ConnectTimeout=30",
                        "-i",
                        str(key_path),
                        "-p",
                        port,
                        ssh_target,
                        cmd,
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                return result.returncode, result.stdout, result.stderr
            except subprocess.TimeoutExpired:
                raise ProbeError(
                    f"SSH connection timed out\n"
                    f" Host: {ssh_target}:{port}\n"
                    f" Hint: Check if the host is reachable and SSH is running."
                ) from None

        # Test SSH connection first so auth/reachability failures surface
        # before we blame the probe script.
        code, out, err = run_ssh_cmd("echo OK")
        if code != 0:
            raise ProbeError(
                f"SSH connection failed to target '{target.name}'\n"
                f" Host: {user}@{host}:{port}\n"
                f" Key: {key_path}\n"
                f" Error: {err.strip() if err else 'connection refused or timeout'}\n"
                f" Hint: Verify the host is reachable and the SSH key is authorized."
            )

        # Run probe script (assumes python3 is on PATH on the target).
        escaped_script = _PROBE_SCRIPT.replace("'", "'\"'\"'")
        code, out, err = run_ssh_cmd(f"python3 -c '{escaped_script}'")
        if code != 0:
            raise ProbeError(
                f"Probe script failed on target '{target.name}'\n"
                f" Exit code: {code}\n"
                f" Error: {err.strip() if err else 'unknown'}\n"
                f" Hint: Ensure python3 is installed on the target."
            )

        try:
            return json.loads(out)
        except json.JSONDecodeError as e:
            raise ProbeError(
                f"Failed to parse probe output from '{target.name}'\n"
                f" Error: {e}\n"
                f" Output: {out[:200]}..."
            ) from None

    else:
        raise ProbeError(
            f"Probing not supported for target type: {type(target).__name__}\n"
            f" Supported types: RunPod, Baremetal, VM"
        )
|