wafer-cli 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/targets_ops.py ADDED
@@ -0,0 +1,717 @@
1
+ """Target operations for exec/ssh/sync commands.
2
+
3
+ This module provides the business logic for running commands on targets,
4
+ getting SSH credentials, and syncing files. It handles:
5
+ - RunPod: Auto-provision pod, get SSH credentials
6
+ - DigitalOcean: Auto-provision droplet, get SSH credentials
7
+ - Baremetal/VM: Direct SSH with configured credentials
8
+ - Workspace: Delegate to workspace API
9
+ - Modal/Local: Not supported (no SSH access)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import subprocess
16
+ from collections.abc import Callable
17
+ from dataclasses import dataclass, replace
18
+ from pathlib import Path
19
+ from typing import TYPE_CHECKING
20
+
21
+ if TYPE_CHECKING:
22
+ from wafer_core.utils.kernel_utils.targets.config import (
23
+ BaremetalTarget,
24
+ DigitalOceanTarget,
25
+ RunPodTarget,
26
+ TargetConfig,
27
+ VMTarget,
28
+ )
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
@dataclass(frozen=True)
class TargetSSHInfo:
    """SSH connection info for a target.

    Immutable value object consumed by the exec/sync/scp helpers in this
    module. All fields are required.
    """

    host: str  # Hostname or public IP of the target
    port: int  # SSH port (22 for DigitalOcean; RunPod assigns per-pod ports)
    user: str  # Remote login user
    key_path: Path  # Local path to the private SSH key (already ~-expanded)
41
+
42
+
43
class TargetExecError(Exception):
    """Error during target operation (exec/ssh/sync).

    Raised for unsupported target types, missing SSH keys, ssh/rsync/scp
    failures, and command timeouts.
    """
47
+
48
+
49
+ def _expand_key_path(ssh_key: str) -> Path:
50
+ """Expand SSH key path (synchronous, fast operation)."""
51
+ return Path(ssh_key).expanduser()
52
+
53
+
54
+ def _parse_ssh_target(ssh_target: str) -> tuple[str, str, int]:
55
+ """Parse ssh_target string into (user, host, port).
56
+
57
+ Format: user@host:port
58
+ """
59
+ # Split user@host:port
60
+ if "@" not in ssh_target:
61
+ raise ValueError(f"Invalid ssh_target format: {ssh_target} (expected user@host:port)")
62
+
63
+ user, rest = ssh_target.split("@", 1)
64
+
65
+ if ":" not in rest:
66
+ raise ValueError(f"Invalid ssh_target format: {ssh_target} (expected user@host:port)")
67
+
68
+ host, port_str = rest.rsplit(":", 1)
69
+
70
+ try:
71
+ port = int(port_str)
72
+ except ValueError as e:
73
+ raise ValueError(f"Invalid port in ssh_target: {port_str}") from e
74
+
75
+ return user, host, port
76
+
77
+
78
async def get_target_ssh_info(target: TargetConfig) -> TargetSSHInfo:
    """Resolve SSH connection details for *target*.

    RunPod/DigitalOcean targets are provisioned on demand; Baremetal/VM
    targets use their configured credentials directly. Workspace, Modal,
    and Local targets have no SSH path and raise.

    Args:
        target: Target configuration

    Returns:
        TargetSSHInfo with host, port, user, key_path

    Raises:
        TargetExecError: If the target type doesn't support SSH
    """
    from wafer_core.utils.kernel_utils.targets.config import (
        BaremetalTarget,
        DigitalOceanTarget,
        LocalTarget,
        ModalTarget,
        RunPodTarget,
        VMTarget,
        WorkspaceTarget,
    )

    # Guard-style dispatch; every branch returns or raises.
    if isinstance(target, RunPodTarget):
        return await _get_runpod_ssh_info(target)
    if isinstance(target, DigitalOceanTarget):
        return await _get_digitalocean_ssh_info(target)
    if isinstance(target, (BaremetalTarget, VMTarget)):
        return _get_direct_ssh_info(target)
    if isinstance(target, WorkspaceTarget):
        raise TargetExecError(
            f"WorkspaceTarget '{target.name}' uses API-based access.\n"
            "Use 'wafer workspaces exec/ssh/sync' instead."
        )
    if isinstance(target, ModalTarget):
        raise TargetExecError(
            f"ModalTarget '{target.name}' is serverless and has no SSH access.\n"
            "Use 'wafer evaluate' to run code on Modal targets."
        )
    if isinstance(target, LocalTarget):
        raise TargetExecError(
            f"LocalTarget '{target.name}' runs locally and has no SSH.\n"
            "Run commands directly on this machine."
        )
    raise TargetExecError(f"Unknown target type: {type(target).__name__}")
127
+
128
+
129
async def _get_runpod_ssh_info(target: RunPodTarget) -> TargetSSHInfo:
    """Get SSH info for RunPod target, provisioning if needed.

    First consults locally recorded pod state and reuses a pod that is
    still running; otherwise provisions a fresh pod via
    ``runpod_ssh_context`` with ``keep_alive=True`` so the pod survives
    this call.
    """
    from wafer_core.targets.runpod import check_pod_running, get_pod_state, runpod_ssh_context

    key_path = _expand_key_path(target.ssh_key)

    # Check if pod already exists and is running
    existing = get_pod_state(target.name)
    if existing and await check_pod_running(existing.pod_id):
        # Reuse existing pod
        return TargetSSHInfo(
            host=existing.public_ip,
            port=existing.ssh_port,
            user=existing.ssh_username,
            key_path=key_path,
        )

    # Need to provision - use the context manager but don't terminate
    # We'll provision and keep the pod running for the exec/ssh/sync operation
    # The user can run `wafer config targets cleanup` to terminate later

    # Temporarily override keep_alive to True so we don't terminate after getting info
    target_keep_alive = replace(target, keep_alive=True)

    # NOTE(review): returning from inside the block still runs the context's
    # __aexit__; presumably keep_alive=True makes that exit a no-op (no pod
    # termination) — confirm against runpod_ssh_context's implementation.
    async with runpod_ssh_context(target_keep_alive) as ssh_info:
        return TargetSSHInfo(
            host=ssh_info.host,
            port=ssh_info.port,
            user=ssh_info.user,
            key_path=key_path,
        )
160
+
161
+
162
async def _get_digitalocean_ssh_info(target: DigitalOceanTarget) -> TargetSSHInfo:
    """Get SSH info for DigitalOcean target, provisioning if needed.

    Reuses a running droplet recorded in local state when possible;
    otherwise provisions one via ``digitalocean_ssh_context`` with
    ``keep_alive=True`` so the droplet survives this call.
    """
    from wafer_core.targets.digitalocean import (
        check_droplet_running,
        digitalocean_ssh_context,
        get_droplet_state,
    )

    key_path = _expand_key_path(target.ssh_key)

    # Check if droplet already exists and is running
    existing = get_droplet_state(target.name)
    if existing and await check_droplet_running(existing.droplet_id):
        # Reuse existing droplet
        return TargetSSHInfo(
            host=existing.public_ip,
            port=22,  # DigitalOcean uses standard SSH port
            user=existing.ssh_username,
            key_path=key_path,
        )

    # Need to provision - use the context manager but don't terminate
    target_keep_alive = replace(target, keep_alive=True)

    # NOTE(review): returning from inside the block still runs the context's
    # __aexit__; presumably keep_alive=True prevents droplet termination on
    # exit — confirm against digitalocean_ssh_context's implementation.
    async with digitalocean_ssh_context(target_keep_alive) as ssh_info:
        return TargetSSHInfo(
            host=ssh_info.host,
            port=ssh_info.port,
            user=ssh_info.user,
            key_path=key_path,
        )
193
+
194
+
195
def _get_direct_ssh_info(target: BaremetalTarget | VMTarget) -> TargetSSHInfo:
    """Build SSH info straight from a Baremetal/VM target's configuration.

    These targets are pre-provisioned, so no API calls are made; the
    configured key must already exist on disk.

    Raises:
        TargetExecError: If the configured SSH key file is missing.
    """
    login_user, hostname, ssh_port = _parse_ssh_target(target.ssh_target)
    private_key = _expand_key_path(target.ssh_key)

    if not private_key.exists():
        raise TargetExecError(f"SSH key not found: {private_key}")

    return TargetSSHInfo(
        host=hostname,
        port=ssh_port,
        user=login_user,
        key_path=private_key,
    )
209
+
210
+
211
def exec_on_target_sync(
    ssh_info: TargetSSHInfo,
    command: str,
    timeout_seconds: int | None = None,
) -> int:
    """Run *command* on the target over SSH (blocking) and return its exit code.

    stdout/stderr are inherited from the current process, so remote output
    streams straight to the user's terminal.

    Args:
        ssh_info: SSH connection info
        command: Command to execute
        timeout_seconds: Optional timeout

    Returns:
        Exit code from the remote command

    Raises:
        TargetExecError: If the command does not finish in time.
    """
    argv = ["ssh", "-i", str(ssh_info.key_path), "-p", str(ssh_info.port)]
    # Ephemeral hosts get fresh host keys, so skip known_hosts bookkeeping.
    for option in ("StrictHostKeyChecking=no", "UserKnownHostsFile=/dev/null", "LogLevel=ERROR"):
        argv += ["-o", option]
    argv += [f"{ssh_info.user}@{ssh_info.host}", command]

    try:
        completed = subprocess.run(argv, timeout=timeout_seconds)
    except subprocess.TimeoutExpired as exc:
        raise TargetExecError(f"Command timed out after {timeout_seconds}s") from exc
    return completed.returncode
250
+
251
+
252
def sync_to_target(
    ssh_info: TargetSSHInfo,
    local_path: Path,
    remote_path: str | None = None,
    on_progress: Callable[[str], None] | None = None,
) -> int:
    """Sync files to target via rsync over SSH.

    Args:
        ssh_info: SSH connection info
        local_path: Local file or directory to sync
        remote_path: Remote destination (default: /tmp/{basename})
        on_progress: Optional callback for progress messages

    Returns:
        Number of files synced (heuristic count parsed from rsync output)

    Raises:
        TargetExecError: If rsync exits non-zero.
    """
    import shlex

    if remote_path is None:
        remote_path = f"/tmp/{local_path.name}"

    # Build the remote-shell command for rsync's -e flag. rsync word-splits
    # this string, so the user-configured key path must be shell-quoted —
    # a path containing spaces would otherwise break into multiple args.
    ssh_cmd = (
        f"ssh -i {shlex.quote(str(ssh_info.key_path))} -p {ssh_info.port} "
        f"-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR"
    )

    # Add trailing slash to sync directory contents (rsync semantics:
    # "dir/" copies contents, "dir" copies the directory itself).
    source = str(local_path.resolve())
    if local_path.is_dir():
        source = source.rstrip("/") + "/"

    rsync_args = [
        "rsync",
        "-avz",
        "--progress",
        "-e",
        ssh_cmd,
        source,
        f"{ssh_info.user}@{ssh_info.host}:{remote_path}",
    ]

    if on_progress:
        on_progress(f"Syncing {local_path} to {ssh_info.host}:{remote_path}")

    result = subprocess.run(
        rsync_args,
        capture_output=True,
        text=True,
    )

    if result.returncode != 0:
        raise TargetExecError(f"rsync failed: {result.stderr}")

    # Count files from rsync output (lines that don't start with special chars)
    # NOTE: this is a heuristic over rsync's verbose listing, not an exact
    # transfer count.
    file_count = 0
    for line in result.stdout.splitlines():
        # rsync shows transferred files without leading special chars
        if line and not line.startswith((" ", ".", "sent", "total", "building")):
            file_count += 1

    if on_progress:
        on_progress(f"Synced {file_count} files")

    return file_count
316
+
317
+
318
def parse_scp_path(path: str) -> tuple[str | None, str]:
    """Split an scp-style path into ``(target_name, path)``.

    Returns (None, path) for local paths, (target_name, remote_path) for remote.

    Examples:
        "./local/file" -> (None, "./local/file")
        "target:/remote/path" -> ("target", "/remote/path")
        "my-target:/tmp/foo" -> ("my-target", "/tmp/foo")
    """
    if ":" not in path:
        return (None, path)

    # A drive-letter prefix (e.g. "C:\...") is a local Windows path, not a
    # remote target spec.
    if len(path) >= 2 and path[0].isalpha() and path[1] == ":":
        return (None, path)

    target_name, _, rest = path.partition(":")
    return (target_name, rest)
335
+
336
+
337
+ def _has_glob_chars(path: str) -> bool:
338
+ """Check if path contains glob characters."""
339
+ return any(c in path for c in "*?[]")
340
+
341
+
342
+ def _sanitize_glob_pattern(pattern: str) -> str:
343
+ """Sanitize a glob pattern for safe shell execution.
344
+
345
+ Escapes dangerous shell metacharacters while preserving glob characters (* ? [ ]).
346
+ This prevents command injection while allowing glob expansion.
347
+ """
348
+ # Characters that could enable command injection
349
+ dangerous_chars = {
350
+ ";": r"\;",
351
+ "$": r"\$",
352
+ "`": r"\`",
353
+ "|": r"\|",
354
+ "&": r"\&",
355
+ "(": r"\(",
356
+ ")": r"\)",
357
+ "{": r"\{",
358
+ "}": r"\}",
359
+ "<": r"\<",
360
+ ">": r"\>",
361
+ "\n": "", # Remove newlines entirely
362
+ "\r": "",
363
+ }
364
+ result = pattern
365
+ for char, escaped in dangerous_chars.items():
366
+ result = result.replace(char, escaped)
367
+ return result
368
+
369
+
370
def _expand_remote_glob(ssh_info: TargetSSHInfo, pattern: str) -> list[str]:
    """Expand a glob pattern on the remote host.

    Returns the matching paths (one per output line), empty if no matches.
    """
    # Sanitize pattern to prevent command injection while preserving glob chars
    safe_pattern = _sanitize_glob_pattern(pattern)

    # `ls -1d` prints each match on its own line; -d keeps matching
    # directories from being listed out into their contents.
    remote_cmd = f"ls -1d {safe_pattern} 2>/dev/null"

    argv = ["ssh", "-i", str(ssh_info.key_path), "-p", str(ssh_info.port)]
    for option in ("StrictHostKeyChecking=no", "UserKnownHostsFile=/dev/null", "LogLevel=ERROR"):
        argv += ["-o", option]
    argv += [f"{ssh_info.user}@{ssh_info.host}", remote_cmd]

    listing = subprocess.run(argv, capture_output=True, text=True)

    output = listing.stdout.strip()
    if listing.returncode != 0 or not output:
        return []
    return output.split("\n")
402
+
403
+
404
def _scp_single_file(
    ssh_info: TargetSSHInfo,
    remote_path: str,
    local_dest: str,
    recursive: bool,
) -> None:
    """Download one remote file or directory via scp.

    Raises:
        TargetExecError: If scp exits non-zero.
    """
    # scp uses capital -P for the port (unlike ssh's -p).
    argv = ["scp", "-i", str(ssh_info.key_path), "-P", str(ssh_info.port)]
    for option in ("StrictHostKeyChecking=no", "UserKnownHostsFile=/dev/null", "LogLevel=ERROR"):
        argv += ["-o", option]

    if recursive:
        argv.append("-r")

    argv += [f"{ssh_info.user}@{ssh_info.host}:{remote_path}", local_dest]

    outcome = subprocess.run(argv, capture_output=True, text=True)
    if outcome.returncode != 0:
        raise TargetExecError(f"scp failed for {remote_path}: {outcome.stderr}")
436
+
437
+
438
def _scp_glob_download(
    ssh_info: TargetSSHInfo,
    remote_pattern: str,
    local_dest: str,
    recursive: bool,
) -> None:
    """Download every remote file matching *remote_pattern*.

    The glob is expanded on the remote host first; a pattern with no
    matches logs a warning and downloads nothing.
    """
    matches = _expand_remote_glob(ssh_info, remote_pattern)

    if not matches:
        logger.warning(f"No files matched pattern: {remote_pattern}")
        return

    for match in matches:
        _scp_single_file(ssh_info, match, local_dest, recursive)
456
+
457
+
458
def scp_transfer(
    ssh_info: TargetSSHInfo,
    source: str,
    dest: str,
    is_download: bool,
    recursive: bool = False,
) -> None:
    """Transfer files via scp. Supports glob patterns for downloads.

    Args:
        ssh_info: SSH connection info
        source: Source path (local for upload, remote for download)
        dest: Destination path (remote for upload, local for download)
        is_download: True if downloading from remote, False if uploading
        recursive: Whether to copy directories recursively

    Raises:
        TargetExecError: If scp fails
    """
    # Glob downloads take a different path: expand on the remote host and
    # fetch each match individually.
    if is_download and _has_glob_chars(source):
        return _scp_glob_download(ssh_info, source, dest, recursive)

    argv = ["scp", "-i", str(ssh_info.key_path), "-P", str(ssh_info.port)]
    for option in ("StrictHostKeyChecking=no", "UserKnownHostsFile=/dev/null", "LogLevel=ERROR"):
        argv += ["-o", option]

    if recursive:
        argv.append("-r")

    remote_spec = f"{ssh_info.user}@{ssh_info.host}"
    if is_download:
        # remote -> local
        argv += [f"{remote_spec}:{source}", dest]
    else:
        # local -> remote
        argv += [source, f"{remote_spec}:{dest}"]

    outcome = subprocess.run(argv, capture_output=True, text=True)
    if outcome.returncode != 0:
        raise TargetExecError(f"scp failed: {outcome.stderr}")
514
+
515
+
516
+ # =============================================================================
517
+ # Tool Registry for `wafer targets ensure`
518
+ # =============================================================================
519
+
520
+
521
@dataclass(frozen=True)
class ToolSpec:
    """Specification for a tool that can be installed on a target.

    Consumed by ``ensure_tool``: ``check_cmd`` decides whether an install
    is needed, ``install_cmd`` performs it, and ``verify_cmd`` confirms
    success afterwards.
    """

    name: str
    check_cmd: str  # Command to check if installed (exit 0 = installed)
    install_cmd: str | None  # Command to install (None = can't auto-install)
    verify_cmd: str | None = None  # Command to verify after install
    platform: str = "any"  # "amd", "nvidia", or "any"
    description: str = ""  # Human-readable summary shown to users
531
+
532
+
533
# Tools that `wafer targets ensure` can check for and (where possible)
# install on a target. Keyed by the user-facing tool name; entries with
# install_cmd=None can only be verified, not auto-installed.
TOOL_REGISTRY: dict[str, ToolSpec] = {
    # AMD Tools
    "rocprof-compute": ToolSpec(
        name="rocprof-compute",
        check_cmd="which rocprof-compute",
        # rocprofiler-compute requires ROCm >= 6.3 and apt install (not pip)
        # For older ROCm, users need to upgrade or install manually
        install_cmd="apt-get update && apt-get install -y rocprofiler-compute && python3 -m pip install -r /opt/rocm/libexec/rocprofiler-compute/requirements.txt",
        verify_cmd="rocprof-compute --version",
        platform="amd",
        description="AMD GPU profiling (roofline, memory, etc.) - requires ROCm >= 6.3",
    ),
    "rocprof-systems": ToolSpec(
        name="rocprof-systems",
        check_cmd="which rocprof-systems",
        # rocprofiler-systems also requires apt install on ROCm >= 6.3
        install_cmd="apt-get update && apt-get install -y rocprofiler-systems && python3 -m pip install -r /opt/rocm/libexec/rocprofiler-systems/requirements.txt",
        verify_cmd="rocprof-systems --version",
        platform="amd",
        description="AMD system-wide tracing - requires ROCm >= 6.3",
    ),
    "rocprof": ToolSpec(
        name="rocprof",
        check_cmd="which rocprof",
        install_cmd=None,  # Part of ROCm base install
        platform="amd",
        description="AMD kernel profiling (part of ROCm)",
    ),
    # NVIDIA Tools
    "ncu": ToolSpec(
        name="ncu",
        check_cmd="which ncu",
        install_cmd=None,  # Part of CUDA toolkit
        platform="nvidia",
        description="NVIDIA Nsight Compute (part of CUDA toolkit)",
    ),
    "nsys": ToolSpec(
        name="nsys",
        check_cmd="which nsys",
        install_cmd=None,  # Part of CUDA toolkit
        platform="nvidia",
        description="NVIDIA Nsight Systems (part of CUDA toolkit)",
    ),
    "nvtx": ToolSpec(
        name="nvtx",
        check_cmd='python -c "import nvtx"',
        install_cmd="pip install nvtx",
        verify_cmd='python -c "import nvtx; print(nvtx.__version__)"',
        platform="nvidia",
        description="NVIDIA Tools Extension (Python)",
    ),
    # Cross-platform Python packages
    "triton": ToolSpec(
        name="triton",
        check_cmd='python -c "import triton"',
        install_cmd="pip install triton",
        verify_cmd='python -c "import triton; print(triton.__version__)"',
        platform="any",
        description="OpenAI Triton compiler",
    ),
    "torch": ToolSpec(
        name="torch",
        check_cmd='python -c "import torch"',
        install_cmd="pip install torch",
        verify_cmd='python -c "import torch; print(torch.__version__)"',
        platform="any",
        description="PyTorch",
    ),
}
602
+
603
+
604
def get_target_platform(target: TargetConfig) -> str:
    """Determine platform ("amd" or "nvidia") from target config.

    RunPod/DigitalOcean targets are provisioned as AMD MI300X, LocalTarget
    carries an explicit vendor field, and other targets are classified via
    gpu_type / compute_capability heuristics, defaulting to "nvidia".
    """
    # Import target types for isinstance checks
    from wafer_core.utils.kernel_utils.targets.config import (
        DigitalOceanTarget,
        LocalTarget,
        RunPodTarget,
    )

    # These providers are always AMD MI300X in this tool.
    if isinstance(target, (RunPodTarget, DigitalOceanTarget)):
        return "amd"

    # LocalTarget declares its vendor explicitly.
    if isinstance(target, LocalTarget):
        return target.vendor

    # Baremetal/VM: fall back to gpu_type / compute_capability heuristics.
    if "MI300" in getattr(target, "gpu_type", ""):
        return "amd"

    # gfx942 (MI300X) reports compute capability 9.4.
    if getattr(target, "compute_capability", "") == "9.4":
        return "amd"

    # Default to nvidia for other compute capabilities
    return "nvidia"
632
+
633
+
634
@dataclass
class EnsureResult:
    """Result of ensure_tool operation."""

    tool: str  # Tool name as requested by the caller
    already_installed: bool  # True when check_cmd succeeded before any install
    installed: bool  # True when this call performed the installation
    verified: bool  # True when the tool is confirmed usable
    error: str | None = None  # Human-readable failure reason, if any
643
+
644
+
645
def ensure_tool(
    ssh_info: TargetSSHInfo,
    tool: str,
    force: bool = False,
    timeout: int = 300,
) -> EnsureResult:
    """Make sure *tool* is present on the target, installing it if possible.

    Args:
        ssh_info: SSH connection info
        tool: Tool name from TOOL_REGISTRY
        force: If True, reinstall even if present
        timeout: Timeout for install command

    Returns:
        EnsureResult describing the outcome (already installed, newly
        installed, or the error that prevented installation).
    """
    spec = TOOL_REGISTRY.get(tool)
    if spec is None:
        return EnsureResult(
            tool=tool,
            already_installed=False,
            installed=False,
            verified=False,
            error=f"Unknown tool: {tool}. Available: {', '.join(sorted(TOOL_REGISTRY.keys()))}",
        )

    # Unless forced, a successful check command means nothing to do.
    if not force and exec_on_target_sync(ssh_info, spec.check_cmd, timeout_seconds=30) == 0:
        return EnsureResult(
            tool=tool,
            already_installed=True,
            installed=False,
            verified=True,
        )

    # Some tools ship with the base platform and cannot be installed here.
    if spec.install_cmd is None:
        return EnsureResult(
            tool=tool,
            already_installed=False,
            installed=False,
            verified=False,
            error=f"{tool} cannot be auto-installed. It's part of the base platform (ROCm/CUDA).",
        )

    # Run the installation on the target.
    install_rc = exec_on_target_sync(ssh_info, spec.install_cmd, timeout_seconds=timeout)
    if install_rc != 0:
        return EnsureResult(
            tool=tool,
            already_installed=False,
            installed=False,
            verified=False,
            error=f"Installation failed (exit code {install_rc})",
        )

    # Optional post-install verification.
    verified = True
    if spec.verify_cmd:
        verified = exec_on_target_sync(ssh_info, spec.verify_cmd, timeout_seconds=30) == 0

    return EnsureResult(
        tool=tool,
        already_installed=False,
        installed=True,
        verified=verified,
        error=None if verified else "Installation succeeded but verification failed",
    )
File without changes