wafer-cli 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +18 -7
- wafer/api_client.py +4 -0
- wafer/cli.py +1177 -278
- wafer/corpus.py +158 -32
- wafer/evaluate.py +75 -6
- wafer/kernel_scope.py +132 -31
- wafer/nsys_analyze.py +903 -73
- wafer/nsys_profile.py +511 -0
- wafer/output.py +241 -0
- wafer/skills/wafer-guide/SKILL.md +13 -0
- wafer/ssh_keys.py +261 -0
- wafer/targets_ops.py +718 -0
- wafer/wevin_cli.py +127 -18
- wafer/workspaces.py +232 -184
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.10.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.10.dist-info}/RECORD +19 -15
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.10.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.10.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.10.dist-info}/top_level.txt +0 -0
wafer/targets_ops.py
ADDED
|
@@ -0,0 +1,718 @@
|
|
|
1
|
+
"""Target operations for exec/ssh/sync commands.
|
|
2
|
+
|
|
3
|
+
This module provides the business logic for running commands on targets,
|
|
4
|
+
getting SSH credentials, and syncing files. It handles:
|
|
5
|
+
- RunPod: Auto-provision pod, get SSH credentials
|
|
6
|
+
- DigitalOcean: Auto-provision droplet, get SSH credentials
|
|
7
|
+
- Baremetal/VM: Direct SSH with configured credentials
|
|
8
|
+
- Workspace: Delegate to workspace API
|
|
9
|
+
- Modal/Local: Not supported (no SSH access)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
import shlex
|
|
16
|
+
import subprocess
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from dataclasses import dataclass, replace
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from wafer_core.utils.kernel_utils.targets.config import (
|
|
24
|
+
BaremetalTarget,
|
|
25
|
+
DigitalOceanTarget,
|
|
26
|
+
RunPodTarget,
|
|
27
|
+
TargetConfig,
|
|
28
|
+
VMTarget,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(frozen=True)
class TargetSSHInfo:
    """SSH connection info for a target."""

    # Hostname or IP address to connect to.
    host: str
    # SSH port (DigitalOcean uses 22; RunPod assigns a per-pod port).
    port: int
    # Remote username to authenticate as.
    user: str
    # Private key file path, already tilde-expanded.
    key_path: Path
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TargetExecError(Exception):
    """Raised when a target operation (exec/ssh/sync) fails."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _expand_key_path(ssh_key: str) -> Path:
|
|
51
|
+
"""Expand SSH key path (synchronous, fast operation)."""
|
|
52
|
+
return Path(ssh_key).expanduser()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _parse_ssh_target(ssh_target: str) -> tuple[str, str, int]:
|
|
56
|
+
"""Parse ssh_target string into (user, host, port).
|
|
57
|
+
|
|
58
|
+
Format: user@host:port
|
|
59
|
+
"""
|
|
60
|
+
# Split user@host:port
|
|
61
|
+
if "@" not in ssh_target:
|
|
62
|
+
raise ValueError(f"Invalid ssh_target format: {ssh_target} (expected user@host:port)")
|
|
63
|
+
|
|
64
|
+
user, rest = ssh_target.split("@", 1)
|
|
65
|
+
|
|
66
|
+
if ":" not in rest:
|
|
67
|
+
raise ValueError(f"Invalid ssh_target format: {ssh_target} (expected user@host:port)")
|
|
68
|
+
|
|
69
|
+
host, port_str = rest.rsplit(":", 1)
|
|
70
|
+
|
|
71
|
+
try:
|
|
72
|
+
port = int(port_str)
|
|
73
|
+
except ValueError as e:
|
|
74
|
+
raise ValueError(f"Invalid port in ssh_target: {port_str}") from e
|
|
75
|
+
|
|
76
|
+
return user, host, port
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
async def get_target_ssh_info(target: TargetConfig) -> TargetSSHInfo:
    """Get SSH connection info for a target.

    For RunPod/DigitalOcean: Provisions if needed, returns SSH info.
    For Baremetal/VM: Returns configured SSH info directly.
    For Modal/Local/Workspace: Raises (no SSH access).

    Args:
        target: Target configuration.

    Returns:
        TargetSSHInfo with host, port, user, key_path.

    Raises:
        TargetExecError: If the target type doesn't support SSH.
    """
    from wafer_core.utils.kernel_utils.targets.config import (
        BaremetalTarget,
        DigitalOceanTarget,
        LocalTarget,
        ModalTarget,
        RunPodTarget,
        VMTarget,
        WorkspaceTarget,
    )

    # SSH-capable target types, in dispatch order.
    if isinstance(target, RunPodTarget):
        return await _get_runpod_ssh_info(target)
    if isinstance(target, DigitalOceanTarget):
        return await _get_digitalocean_ssh_info(target)
    if isinstance(target, (BaremetalTarget, VMTarget)):
        return _get_direct_ssh_info(target)

    # Everything below has no SSH access; raise with guidance.
    if isinstance(target, WorkspaceTarget):
        raise TargetExecError(
            f"WorkspaceTarget '{target.name}' uses API-based access.\n"
            "Use 'wafer workspaces exec/ssh/sync' instead."
        )
    if isinstance(target, ModalTarget):
        raise TargetExecError(
            f"ModalTarget '{target.name}' is serverless and has no SSH access.\n"
            "Use 'wafer evaluate' to run code on Modal targets."
        )
    if isinstance(target, LocalTarget):
        raise TargetExecError(
            f"LocalTarget '{target.name}' runs locally and has no SSH.\n"
            "Run commands directly on this machine."
        )
    raise TargetExecError(f"Unknown target type: {type(target).__name__}")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
async def _get_runpod_ssh_info(target: RunPodTarget) -> TargetSSHInfo:
    """Get SSH info for RunPod target, provisioning if needed.

    Reuses an already-running pod recorded under this target's name;
    otherwise provisions a fresh pod and leaves it running so the
    caller's exec/ssh/sync can use it.
    """
    from wafer_core.targets.runpod import check_pod_running, get_pod_state, runpod_ssh_context

    key_path = _expand_key_path(target.ssh_key)

    # Check if pod already exists and is running
    existing = get_pod_state(target.name)
    if existing and await check_pod_running(existing.pod_id):
        # Reuse existing pod
        return TargetSSHInfo(
            host=existing.public_ip,
            port=existing.ssh_port,
            user=existing.ssh_username,
            key_path=key_path,
        )

    # Need to provision - use the context manager but don't terminate
    # We'll provision and keep the pod running for the exec/ssh/sync operation
    # The user can run `wafer config targets cleanup` to terminate later

    # Temporarily override keep_alive to True so we don't terminate after getting info
    target_keep_alive = replace(target, keep_alive=True)

    # NOTE(review): returning from inside the `async with` exits the context
    # immediately; keep_alive=True is what must stop the exit path from
    # tearing the pod down — confirm runpod_ssh_context honors it on exit.
    async with runpod_ssh_context(target_keep_alive) as ssh_info:
        return TargetSSHInfo(
            host=ssh_info.host,
            port=ssh_info.port,
            user=ssh_info.user,
            key_path=key_path,
        )
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
async def _get_digitalocean_ssh_info(target: DigitalOceanTarget) -> TargetSSHInfo:
    """Get SSH info for DigitalOcean target, provisioning if needed.

    Reuses an already-running droplet recorded under this target's name;
    otherwise provisions one and leaves it running for the caller.
    """
    from wafer_core.targets.digitalocean import (
        check_droplet_running,
        digitalocean_ssh_context,
        get_droplet_state,
    )

    key_path = _expand_key_path(target.ssh_key)

    # Check if droplet already exists and is running
    existing = get_droplet_state(target.name)
    if existing and await check_droplet_running(existing.droplet_id):
        # Reuse existing droplet
        return TargetSSHInfo(
            host=existing.public_ip,
            port=22,  # DigitalOcean uses standard SSH port
            user=existing.ssh_username,
            key_path=key_path,
        )

    # Need to provision - use the context manager but don't terminate
    target_keep_alive = replace(target, keep_alive=True)

    # NOTE(review): returning from inside the `async with` exits the context
    # immediately; keep_alive=True must prevent teardown on exit — confirm
    # digitalocean_ssh_context honors it.
    async with digitalocean_ssh_context(target_keep_alive) as ssh_info:
        return TargetSSHInfo(
            host=ssh_info.host,
            port=ssh_info.port,
            user=ssh_info.user,
            key_path=key_path,
        )
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _get_direct_ssh_info(target: BaremetalTarget | VMTarget) -> TargetSSHInfo:
    """Build SSH info for a Baremetal/VM target from its static config.

    No provisioning happens: the configured ``ssh_target`` string is parsed
    and the key path is validated.

    Raises:
        TargetExecError: If the configured SSH key file does not exist.
        ValueError: If ``ssh_target`` is not of the form ``user@host:port``.
    """
    username, hostname, ssh_port = _parse_ssh_target(target.ssh_target)

    key = _expand_key_path(target.ssh_key)
    if not key.exists():
        raise TargetExecError(f"SSH key not found: {key}")

    return TargetSSHInfo(host=hostname, port=ssh_port, user=username, key_path=key)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def exec_on_target_sync(
    ssh_info: TargetSSHInfo,
    command: str,
    timeout_seconds: int | None = None,
) -> int:
    """Run *command* on the target over SSH and block until it finishes.

    Host key checking is disabled because cloud targets get fresh host
    keys on every provision; stdout/stderr pass through to the caller's
    terminal.

    Args:
        ssh_info: SSH connection info.
        command: Shell command to execute remotely.
        timeout_seconds: Optional wall-clock timeout.

    Returns:
        Exit code from the remote command.

    Raises:
        TargetExecError: If the command does not finish within the timeout.
    """
    argv = [
        "ssh",
        "-i", str(ssh_info.key_path),
        "-p", str(ssh_info.port),
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-o", "LogLevel=ERROR",
        f"{ssh_info.user}@{ssh_info.host}",
        command,
    ]

    try:
        return subprocess.run(argv, timeout=timeout_seconds).returncode
    except subprocess.TimeoutExpired as e:
        raise TargetExecError(f"Command timed out after {timeout_seconds}s") from e
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def sync_to_target(
    ssh_info: TargetSSHInfo,
    local_path: Path,
    remote_path: str | None = None,
    on_progress: Callable[[str], None] | None = None,
) -> int:
    """Sync a local file or directory to the target via rsync-over-SSH.

    Args:
        ssh_info: SSH connection info.
        local_path: Local file or directory to sync.
        remote_path: Remote destination; defaults to ``/tmp/{basename}``.
        on_progress: Optional callback invoked with progress messages.

    Returns:
        Number of files transferred (parsed heuristically from rsync output).

    Raises:
        TargetExecError: If rsync exits non-zero.
    """
    dest = remote_path if remote_path is not None else f"/tmp/{local_path.name}"

    # rsync shells out to this ssh invocation for transport.
    transport = (
        f"ssh -i {ssh_info.key_path} -p {ssh_info.port} "
        f"-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR"
    )

    src = str(local_path.resolve())
    if local_path.is_dir():
        # Trailing slash makes rsync copy the directory's *contents*
        # rather than nesting the directory itself under dest.
        src = src.rstrip("/") + "/"

    if on_progress:
        on_progress(f"Syncing {local_path} to {ssh_info.host}:{dest}")

    proc = subprocess.run(
        [
            "rsync",
            "-avz",
            "--progress",
            "-e",
            transport,
            src,
            f"{ssh_info.user}@{ssh_info.host}:{dest}",
        ],
        capture_output=True,
        text=True,
    )

    if proc.returncode != 0:
        raise TargetExecError(f"rsync failed: {proc.stderr}")

    # Heuristic: transferred files are the stdout lines that are not rsync's
    # own status/summary lines (which start with space, dot, or keywords).
    skip_prefixes = (" ", ".", "sent", "total", "building")
    transferred = sum(
        1
        for line in proc.stdout.splitlines()
        if line and not line.startswith(skip_prefixes)
    )

    if on_progress:
        on_progress(f"Synced {transferred} files")

    return transferred
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def parse_scp_path(path: str) -> tuple[str | None, str]:
|
|
320
|
+
"""Parse scp-style path into (target_name, path).
|
|
321
|
+
|
|
322
|
+
Returns (None, path) for local paths, (target_name, remote_path) for remote.
|
|
323
|
+
|
|
324
|
+
Examples:
|
|
325
|
+
"./local/file" -> (None, "./local/file")
|
|
326
|
+
"target:/remote/path" -> ("target", "/remote/path")
|
|
327
|
+
"my-target:/tmp/foo" -> ("my-target", "/tmp/foo")
|
|
328
|
+
"""
|
|
329
|
+
if ":" in path:
|
|
330
|
+
# Check if it looks like a Windows path (e.g., C:\...)
|
|
331
|
+
if len(path) >= 2 and path[1] == ":" and path[0].isalpha():
|
|
332
|
+
return (None, path)
|
|
333
|
+
target, remote_path = path.split(":", 1)
|
|
334
|
+
return (target, remote_path)
|
|
335
|
+
return (None, path)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _has_glob_chars(path: str) -> bool:
|
|
339
|
+
"""Check if path contains glob characters."""
|
|
340
|
+
return any(c in path for c in "*?[]")
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _sanitize_glob_pattern(pattern: str) -> str:
|
|
344
|
+
"""Sanitize a glob pattern for safe shell execution.
|
|
345
|
+
|
|
346
|
+
Escapes dangerous shell metacharacters while preserving glob characters (* ? [ ]).
|
|
347
|
+
This prevents command injection while allowing glob expansion.
|
|
348
|
+
"""
|
|
349
|
+
# Characters that could enable command injection
|
|
350
|
+
dangerous_chars = {
|
|
351
|
+
";": r"\;",
|
|
352
|
+
"$": r"\$",
|
|
353
|
+
"`": r"\`",
|
|
354
|
+
"|": r"\|",
|
|
355
|
+
"&": r"\&",
|
|
356
|
+
"(": r"\(",
|
|
357
|
+
")": r"\)",
|
|
358
|
+
"{": r"\{",
|
|
359
|
+
"}": r"\}",
|
|
360
|
+
"<": r"\<",
|
|
361
|
+
">": r"\>",
|
|
362
|
+
"\n": "", # Remove newlines entirely
|
|
363
|
+
"\r": "",
|
|
364
|
+
}
|
|
365
|
+
result = pattern
|
|
366
|
+
for char, escaped in dangerous_chars.items():
|
|
367
|
+
result = result.replace(char, escaped)
|
|
368
|
+
return result
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _expand_remote_glob(ssh_info: TargetSSHInfo, pattern: str) -> list[str]:
    """Expand a glob pattern on the remote host.

    Returns the matching paths (one per line of ``ls -1d`` output); an
    empty list when nothing matches or the command fails.
    """
    # Block command injection while keeping glob characters intact.
    safe = _sanitize_glob_pattern(pattern)

    # `ls -1d` expands the glob and prints one entry per line; -d keeps
    # directories from being expanded into their contents. stderr is
    # discarded so a non-match yields empty output rather than noise.
    argv = [
        "ssh",
        "-i", str(ssh_info.key_path),
        "-p", str(ssh_info.port),
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-o", "LogLevel=ERROR",
        f"{ssh_info.user}@{ssh_info.host}",
        f"ls -1d {safe} 2>/dev/null",
    ]

    proc = subprocess.run(argv, capture_output=True, text=True)

    listing = proc.stdout.strip()
    if proc.returncode != 0 or not listing:
        return []
    return listing.split("\n")
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def _scp_single_file(
    ssh_info: TargetSSHInfo,
    remote_path: str,
    local_dest: str,
    recursive: bool,
) -> None:
    """Download one remote file (or directory, when *recursive*) via scp.

    Raises:
        TargetExecError: If scp exits non-zero.
    """
    argv = [
        "scp",
        "-i", str(ssh_info.key_path),
        "-P", str(ssh_info.port),  # scp uses uppercase -P for the port
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-o", "LogLevel=ERROR",
    ]
    if recursive:
        argv.append("-r")
    argv += [f"{ssh_info.user}@{ssh_info.host}:{remote_path}", local_dest]

    proc = subprocess.run(argv, capture_output=True, text=True)
    if proc.returncode != 0:
        raise TargetExecError(f"scp failed for {remote_path}: {proc.stderr}")
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def _scp_glob_download(
    ssh_info: TargetSSHInfo,
    remote_pattern: str,
    local_dest: str,
    recursive: bool,
) -> None:
    """Download every remote file matching *remote_pattern*.

    The glob is expanded on the remote host first; a pattern with no
    matches logs a warning and is not treated as an error.
    """
    matches = _expand_remote_glob(ssh_info, remote_pattern)
    if not matches:
        logger.warning(f"No files matched pattern: {remote_pattern}")
        return

    for path in matches:
        _scp_single_file(ssh_info, path, local_dest, recursive)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def scp_transfer(
    ssh_info: TargetSSHInfo,
    source: str,
    dest: str,
    is_download: bool,
    recursive: bool = False,
) -> None:
    """Transfer files via scp. Supports glob patterns for downloads.

    Args:
        ssh_info: SSH connection info.
        source: Source path (local for upload, remote for download).
        dest: Destination path (remote for upload, local for download).
        is_download: True to copy remote -> local, False for local -> remote.
        recursive: Whether to copy directories recursively.

    Raises:
        TargetExecError: If scp fails.
    """
    # Glob downloads are expanded on the remote host and fetched one by one.
    if is_download and _has_glob_chars(source):
        return _scp_glob_download(ssh_info, source, dest, recursive)

    argv = [
        "scp",
        "-i", str(ssh_info.key_path),
        "-P", str(ssh_info.port),
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-o", "LogLevel=ERROR",
    ]
    if recursive:
        argv.append("-r")

    remote = f"{ssh_info.user}@{ssh_info.host}"
    if is_download:
        # remote -> local
        argv += [f"{remote}:{source}", dest]
    else:
        # local -> remote
        argv += [source, f"{remote}:{dest}"]

    proc = subprocess.run(argv, capture_output=True, text=True)
    if proc.returncode != 0:
        raise TargetExecError(f"scp failed: {proc.stderr}")
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
# =============================================================================
|
|
518
|
+
# Tool Registry for `wafer targets ensure`
|
|
519
|
+
# =============================================================================
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
@dataclass(frozen=True)
class ToolSpec:
    """Specification for a tool that can be installed on a target."""

    name: str  # Canonical tool name (also the TOOL_REGISTRY key)
    check_cmd: str  # Command to check if installed (exit 0 = installed)
    install_cmd: str | None  # Command to install (None = can't auto-install)
    verify_cmd: str | None = None  # Command to verify after install
    platform: str = "any"  # "amd", "nvidia", or "any"
    description: str = ""  # Human-readable summary shown to users
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
# Registry consumed by `wafer targets ensure`: maps tool name -> ToolSpec.
# check_cmd runs first; install_cmd only if the check fails (and is None
# for tools that ship with the base ROCm/CUDA platform).
TOOL_REGISTRY: dict[str, ToolSpec] = {
    # AMD Tools
    "rocprof-compute": ToolSpec(
        name="rocprof-compute",
        check_cmd="which rocprof-compute",
        # rocprofiler-compute requires ROCm >= 6.3 and apt install (not pip)
        # For older ROCm, users need to upgrade or install manually
        install_cmd="apt-get update && apt-get install -y rocprofiler-compute && python3 -m pip install -r /opt/rocm/libexec/rocprofiler-compute/requirements.txt",
        verify_cmd="rocprof-compute --version",
        platform="amd",
        description="AMD GPU profiling (roofline, memory, etc.) - requires ROCm >= 6.3",
    ),
    "rocprof-systems": ToolSpec(
        name="rocprof-systems",
        check_cmd="which rocprof-systems",
        # rocprofiler-systems also requires apt install on ROCm >= 6.3
        install_cmd="apt-get update && apt-get install -y rocprofiler-systems && python3 -m pip install -r /opt/rocm/libexec/rocprofiler-systems/requirements.txt",
        verify_cmd="rocprof-systems --version",
        platform="amd",
        description="AMD system-wide tracing - requires ROCm >= 6.3",
    ),
    "rocprof": ToolSpec(
        name="rocprof",
        check_cmd="which rocprof",
        install_cmd=None,  # Part of ROCm base install
        platform="amd",
        description="AMD kernel profiling (part of ROCm)",
    ),
    # NVIDIA Tools
    "ncu": ToolSpec(
        name="ncu",
        check_cmd="which ncu",
        install_cmd=None,  # Part of CUDA toolkit
        platform="nvidia",
        description="NVIDIA Nsight Compute (part of CUDA toolkit)",
    ),
    "nsys": ToolSpec(
        name="nsys",
        check_cmd="which nsys",
        install_cmd=None,  # Part of CUDA toolkit
        platform="nvidia",
        description="NVIDIA Nsight Systems (part of CUDA toolkit)",
    ),
    "nvtx": ToolSpec(
        name="nvtx",
        check_cmd='python -c "import nvtx"',
        install_cmd="pip install nvtx",
        verify_cmd='python -c "import nvtx; print(nvtx.__version__)"',
        platform="nvidia",
        description="NVIDIA Tools Extension (Python)",
    ),
    # Cross-platform Python packages
    "triton": ToolSpec(
        name="triton",
        check_cmd='python -c "import triton"',
        install_cmd="pip install triton",
        verify_cmd='python -c "import triton; print(triton.__version__)"',
        platform="any",
        description="OpenAI Triton compiler",
    ),
    "torch": ToolSpec(
        name="torch",
        check_cmd='python -c "import torch"',
        install_cmd="pip install torch",
        verify_cmd='python -c "import torch; print(torch.__version__)"',
        platform="any",
        description="PyTorch",
    ),
}
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
def get_target_platform(target: TargetConfig) -> str:
    """Return ``"amd"`` or ``"nvidia"`` inferred from the target config.

    RunPod/DigitalOcean targets are always AMD MI300X; LocalTarget carries
    an explicit vendor field; Baremetal/VM fall back to gpu_type /
    compute_capability hints, defaulting to "nvidia".
    """
    from wafer_core.utils.kernel_utils.targets.config import (
        DigitalOceanTarget,
        LocalTarget,
        RunPodTarget,
    )

    # Cloud targets we provision are always AMD MI300X boxes.
    if isinstance(target, (RunPodTarget, DigitalOceanTarget)):
        return "amd"

    # LocalTarget declares its vendor explicitly.
    if isinstance(target, LocalTarget):
        return target.vendor

    # Baremetal/VM: look for AMD hints in the optional metadata fields.
    if "MI300" in getattr(target, "gpu_type", ""):
        return "amd"
    if getattr(target, "compute_capability", "") == "9.4":  # gfx942 = MI300X
        return "amd"

    # Default to nvidia for any other compute capability.
    return "nvidia"
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
@dataclass
class EnsureResult:
    """Result of ensure_tool operation."""

    tool: str  # Tool name that was requested
    already_installed: bool  # True if the check command passed before any install
    installed: bool  # True if an install was performed during this call
    verified: bool  # True if the tool is confirmed working
    error: str | None = None  # Human-readable failure reason, if any
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def ensure_tool(
    ssh_info: TargetSSHInfo,
    tool: str,
    force: bool = False,
    timeout: int = 300,
) -> EnsureResult:
    """Ensure a tool is installed on target.

    Args:
        ssh_info: SSH connection info.
        tool: Tool name from TOOL_REGISTRY.
        force: If True, reinstall even if present.
        timeout: Timeout (seconds) for the install command.

    Returns:
        EnsureResult describing what was checked, installed, and verified.
    """
    spec = TOOL_REGISTRY.get(tool)
    if spec is None:
        available = ", ".join(sorted(TOOL_REGISTRY.keys()))
        return EnsureResult(
            tool=tool,
            already_installed=False,
            installed=False,
            verified=False,
            error=f"Unknown tool: {tool}. Available: {available}",
        )

    # Fast path: the tool already answers its check command.
    if not force and exec_on_target_sync(ssh_info, spec.check_cmd, timeout_seconds=30) == 0:
        return EnsureResult(
            tool=tool,
            already_installed=True,
            installed=False,
            verified=True,
        )

    # Tools shipped with the base platform (ROCm/CUDA) can't be auto-installed.
    if spec.install_cmd is None:
        return EnsureResult(
            tool=tool,
            already_installed=False,
            installed=False,
            verified=False,
            error=f"{tool} cannot be auto-installed. It's part of the base platform (ROCm/CUDA).",
        )

    install_rc = exec_on_target_sync(ssh_info, spec.install_cmd, timeout_seconds=timeout)
    if install_rc != 0:
        return EnsureResult(
            tool=tool,
            already_installed=False,
            installed=False,
            verified=False,
            error=f"Installation failed (exit code {install_rc})",
        )

    # Optional post-install verification.
    ok = True
    if spec.verify_cmd:
        ok = exec_on_target_sync(ssh_info, spec.verify_cmd, timeout_seconds=30) == 0

    return EnsureResult(
        tool=tool,
        already_installed=False,
        installed=True,
        verified=ok,
        error=None if ok else "Installation succeeded but verification failed",
    )
|