wafer-cli 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/targets_ops.py ADDED
@@ -0,0 +1,718 @@
1
+ """Target operations for exec/ssh/sync commands.
2
+
3
+ This module provides the business logic for running commands on targets,
4
+ getting SSH credentials, and syncing files. It handles:
5
+ - RunPod: Auto-provision pod, get SSH credentials
6
+ - DigitalOcean: Auto-provision droplet, get SSH credentials
7
+ - Baremetal/VM: Direct SSH with configured credentials
8
+ - Workspace: Delegate to workspace API
9
+ - Modal/Local: Not supported (no SSH access)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import shlex
16
+ import subprocess
17
+ from collections.abc import Callable
18
+ from dataclasses import dataclass, replace
19
+ from pathlib import Path
20
+ from typing import TYPE_CHECKING
21
+
22
+ if TYPE_CHECKING:
23
+ from wafer_core.utils.kernel_utils.targets.config import (
24
+ BaremetalTarget,
25
+ DigitalOceanTarget,
26
+ RunPodTarget,
27
+ TargetConfig,
28
+ VMTarget,
29
+ )
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
@dataclass(frozen=True)
class TargetSSHInfo:
    """SSH connection info for a target.

    Immutable value object passed to the exec/sync/scp helpers in this
    module; it carries everything needed to build an ssh/scp/rsync argv.
    """

    # Hostname or public IP of the target machine.
    host: str
    # SSH port (RunPod pods use mapped ports; DigitalOcean uses standard 22).
    port: int
    # Remote login user.
    user: str
    # Local filesystem path to the SSH private key file.
    key_path: Path
42
+
43
+
44
class TargetExecError(Exception):
    """Error during target operation (exec/ssh/sync).

    Raised for: target types without SSH access, missing SSH key files,
    command timeouts, and non-zero scp/rsync exit codes.
    """

    pass
48
+
49
+
50
+ def _expand_key_path(ssh_key: str) -> Path:
51
+ """Expand SSH key path (synchronous, fast operation)."""
52
+ return Path(ssh_key).expanduser()
53
+
54
+
55
+ def _parse_ssh_target(ssh_target: str) -> tuple[str, str, int]:
56
+ """Parse ssh_target string into (user, host, port).
57
+
58
+ Format: user@host:port
59
+ """
60
+ # Split user@host:port
61
+ if "@" not in ssh_target:
62
+ raise ValueError(f"Invalid ssh_target format: {ssh_target} (expected user@host:port)")
63
+
64
+ user, rest = ssh_target.split("@", 1)
65
+
66
+ if ":" not in rest:
67
+ raise ValueError(f"Invalid ssh_target format: {ssh_target} (expected user@host:port)")
68
+
69
+ host, port_str = rest.rsplit(":", 1)
70
+
71
+ try:
72
+ port = int(port_str)
73
+ except ValueError as e:
74
+ raise ValueError(f"Invalid port in ssh_target: {port_str}") from e
75
+
76
+ return user, host, port
77
+
78
+
79
async def get_target_ssh_info(target: TargetConfig) -> TargetSSHInfo:
    """Get SSH connection info for a target.

    For RunPod/DigitalOcean: Provisions if needed, returns SSH info.
    For Baremetal/VM: Returns configured SSH info directly.
    For Modal/Local/Workspace: Raises (no SSH access).

    Args:
        target: Target configuration

    Returns:
        TargetSSHInfo with host, port, user, key_path

    Raises:
        TargetExecError: If target type doesn't support SSH
    """
    # Imported lazily to keep module import light.
    from wafer_core.utils.kernel_utils.targets.config import (
        BaremetalTarget,
        DigitalOceanTarget,
        LocalTarget,
        ModalTarget,
        RunPodTarget,
        VMTarget,
        WorkspaceTarget,
    )

    # Cloud targets that can be auto-provisioned.
    if isinstance(target, RunPodTarget):
        return await _get_runpod_ssh_info(target)
    if isinstance(target, DigitalOceanTarget):
        return await _get_digitalocean_ssh_info(target)

    # Statically-configured hosts: no provisioning, just read the config.
    if isinstance(target, (BaremetalTarget, VMTarget)):
        return _get_direct_ssh_info(target)

    # Everything below has no SSH path; each error explains the alternative.
    if isinstance(target, WorkspaceTarget):
        raise TargetExecError(
            f"WorkspaceTarget '{target.name}' uses API-based access.\n"
            "Use 'wafer workspaces exec/ssh/sync' instead."
        )
    if isinstance(target, ModalTarget):
        raise TargetExecError(
            f"ModalTarget '{target.name}' is serverless and has no SSH access.\n"
            "Use 'wafer evaluate' to run code on Modal targets."
        )
    if isinstance(target, LocalTarget):
        raise TargetExecError(
            f"LocalTarget '{target.name}' runs locally and has no SSH.\n"
            "Run commands directly on this machine."
        )
    raise TargetExecError(f"Unknown target type: {type(target).__name__}")
128
+
129
+
130
async def _get_runpod_ssh_info(target: RunPodTarget) -> TargetSSHInfo:
    """Get SSH info for RunPod target, provisioning a pod if needed.

    Reuses an already-running pod when one exists for this target name;
    otherwise provisions a fresh one via runpod_ssh_context.
    """
    from wafer_core.targets.runpod import check_pod_running, get_pod_state, runpod_ssh_context

    key_path = _expand_key_path(target.ssh_key)

    # Check if a pod already exists for this target name and is still running.
    existing = get_pod_state(target.name)
    if existing and await check_pod_running(existing.pod_id):
        # Reuse the existing pod — no provisioning round-trip needed.
        return TargetSSHInfo(
            host=existing.public_ip,
            port=existing.ssh_port,
            user=existing.ssh_username,
            key_path=key_path,
        )

    # Need to provision - use the context manager but don't terminate.
    # We provision and keep the pod running for the exec/ssh/sync operation;
    # the user can run `wafer config targets cleanup` to terminate later.

    # Override keep_alive=True so exiting the context manager (triggered by
    # the `return` below) does not tear the pod down.
    # NOTE(review): this relies on runpod_ssh_context honoring keep_alive in
    # its exit path — confirm against wafer_core.targets.runpod.
    target_keep_alive = replace(target, keep_alive=True)

    async with runpod_ssh_context(target_keep_alive) as ssh_info:
        return TargetSSHInfo(
            host=ssh_info.host,
            port=ssh_info.port,
            user=ssh_info.user,
            key_path=key_path,
        )
161
+
162
+
163
async def _get_digitalocean_ssh_info(target: DigitalOceanTarget) -> TargetSSHInfo:
    """Get SSH info for DigitalOcean target, provisioning a droplet if needed.

    Mirrors _get_runpod_ssh_info: reuse a running droplet when possible,
    otherwise provision one and leave it running.
    """
    from wafer_core.targets.digitalocean import (
        check_droplet_running,
        digitalocean_ssh_context,
        get_droplet_state,
    )

    key_path = _expand_key_path(target.ssh_key)

    # Check if a droplet already exists for this target name and is running.
    existing = get_droplet_state(target.name)
    if existing and await check_droplet_running(existing.droplet_id):
        # Reuse the existing droplet — skip provisioning.
        return TargetSSHInfo(
            host=existing.public_ip,
            port=22,  # DigitalOcean uses standard SSH port
            user=existing.ssh_username,
            key_path=key_path,
        )

    # Need to provision. keep_alive=True prevents teardown when the `return`
    # inside the context block triggers the context manager's exit.
    # NOTE(review): relies on digitalocean_ssh_context honoring keep_alive —
    # confirm against wafer_core.targets.digitalocean.
    target_keep_alive = replace(target, keep_alive=True)

    async with digitalocean_ssh_context(target_keep_alive) as ssh_info:
        return TargetSSHInfo(
            host=ssh_info.host,
            port=ssh_info.port,
            user=ssh_info.user,
            key_path=key_path,
        )
194
+
195
+
196
def _get_direct_ssh_info(target: BaremetalTarget | VMTarget) -> TargetSSHInfo:
    """Build SSH info for a Baremetal/VM target from its static config.

    No provisioning happens here; the host is assumed to be reachable.

    Raises:
        TargetExecError: If the configured SSH key file does not exist.
    """
    user, host, port = _parse_ssh_target(target.ssh_target)
    key_path = _expand_key_path(target.ssh_key)

    # Fail fast with a clear message rather than letting ssh error out later.
    if not key_path.exists():
        raise TargetExecError(f"SSH key not found: {key_path}")

    return TargetSSHInfo(host=host, port=port, user=user, key_path=key_path)
210
+
211
+
212
def exec_on_target_sync(
    ssh_info: TargetSSHInfo,
    command: str,
    timeout_seconds: int | None = None,
) -> int:
    """Execute a command on target via SSH (synchronous).

    stdout/stderr are inherited from the current process, so remote output
    streams straight to the user's terminal.

    Args:
        ssh_info: SSH connection info
        command: Command to execute
        timeout_seconds: Optional timeout

    Returns:
        Exit code from the remote command

    Raises:
        TargetExecError: If the command exceeds timeout_seconds.
    """
    argv = ["ssh", "-i", str(ssh_info.key_path), "-p", str(ssh_info.port)]
    # Disable host-key prompts/noise for ephemeral cloud hosts.
    for option in ("StrictHostKeyChecking=no", "UserKnownHostsFile=/dev/null", "LogLevel=ERROR"):
        argv += ["-o", option]
    argv += [f"{ssh_info.user}@{ssh_info.host}", command]

    try:
        completed = subprocess.run(argv, timeout=timeout_seconds)
    except subprocess.TimeoutExpired as e:
        raise TargetExecError(f"Command timed out after {timeout_seconds}s") from e
    return completed.returncode
251
+
252
+
253
def sync_to_target(
    ssh_info: TargetSSHInfo,
    local_path: Path,
    remote_path: str | None = None,
    on_progress: Callable[[str], None] | None = None,
) -> int:
    """Sync files to target via rsync over SSH.

    Args:
        ssh_info: SSH connection info
        local_path: Local file or directory to sync
        remote_path: Remote destination (default: /tmp/{basename})
        on_progress: Optional callback for progress messages

    Returns:
        Number of files synced (best-effort count parsed from rsync output)

    Raises:
        TargetExecError: If rsync exits non-zero.
    """
    if remote_path is None:
        remote_path = f"/tmp/{local_path.name}"

    # Build the remote-shell command for rsync's -e flag.  The -e value is
    # parsed as a shell string by rsync, so the key path must be quoted or a
    # path containing spaces would be split into separate ssh arguments.
    ssh_cmd = (
        f"ssh -i {shlex.quote(str(ssh_info.key_path))} -p {ssh_info.port} "
        f"-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR"
    )

    # rsync semantics: "dir/" syncs the directory's contents, "dir" syncs
    # the directory itself — we want the contents.
    source = str(local_path.resolve())
    if local_path.is_dir():
        source = source.rstrip("/") + "/"

    rsync_args = [
        "rsync",
        "-avz",
        "--progress",
        "-e",
        ssh_cmd,
        source,
        f"{ssh_info.user}@{ssh_info.host}:{remote_path}",
    ]

    if on_progress:
        on_progress(f"Syncing {local_path} to {ssh_info.host}:{remote_path}")

    result = subprocess.run(
        rsync_args,
        capture_output=True,
        text=True,
    )

    if result.returncode != 0:
        raise TargetExecError(f"rsync failed: {result.stderr}")

    # Best-effort count of transferred entries: skip rsync's header/summary
    # lines.  "sending"/"receiving" cover the modern
    # "sending incremental file list" header that the old filter missed.
    noise_prefixes = (" ", ".", "sent", "total", "building", "sending", "receiving")
    file_count = sum(
        1
        for line in result.stdout.splitlines()
        if line and not line.startswith(noise_prefixes)
    )

    if on_progress:
        on_progress(f"Synced {file_count} files")

    return file_count
317
+
318
+
319
def parse_scp_path(path: str) -> tuple[str | None, str]:
    """Parse scp-style path into (target_name, path).

    Returns (None, path) for local paths, (target_name, remote_path) for remote.

    Examples:
        "./local/file" -> (None, "./local/file")
        "target:/remote/path" -> ("target", "/remote/path")
        "my-target:/tmp/foo" -> ("my-target", "/tmp/foo")
    """
    if ":" not in path:
        return (None, path)

    # Windows drive-letter paths (e.g. C:\Users\x) are local, not remote.
    looks_like_drive = len(path) >= 2 and path[0].isalpha() and path[1] == ":"
    if looks_like_drive:
        return (None, path)

    target_name, _, remainder = path.partition(":")
    return (target_name, remainder)
336
+
337
+
338
+ def _has_glob_chars(path: str) -> bool:
339
+ """Check if path contains glob characters."""
340
+ return any(c in path for c in "*?[]")
341
+
342
+
343
+ def _sanitize_glob_pattern(pattern: str) -> str:
344
+ """Sanitize a glob pattern for safe shell execution.
345
+
346
+ Escapes dangerous shell metacharacters while preserving glob characters (* ? [ ]).
347
+ This prevents command injection while allowing glob expansion.
348
+ """
349
+ # Characters that could enable command injection
350
+ dangerous_chars = {
351
+ ";": r"\;",
352
+ "$": r"\$",
353
+ "`": r"\`",
354
+ "|": r"\|",
355
+ "&": r"\&",
356
+ "(": r"\(",
357
+ ")": r"\)",
358
+ "{": r"\{",
359
+ "}": r"\}",
360
+ "<": r"\<",
361
+ ">": r"\>",
362
+ "\n": "", # Remove newlines entirely
363
+ "\r": "",
364
+ }
365
+ result = pattern
366
+ for char, escaped in dangerous_chars.items():
367
+ result = result.replace(char, escaped)
368
+ return result
369
+
370
+
371
def _expand_remote_glob(ssh_info: TargetSSHInfo, pattern: str) -> list[str]:
    """Expand a glob pattern on the remote host.

    Returns list of matching file paths, empty if no matches.
    """
    # Block shell metacharacters while keeping glob chars live.
    safe_pattern = _sanitize_glob_pattern(pattern)

    # `ls -1d` prints one match per line; -d stops directories from being
    # listed recursively.  stderr is discarded so "no such file" noise from
    # a non-matching pattern doesn't leak through.
    remote_cmd = f"ls -1d {safe_pattern} 2>/dev/null"

    argv = ["ssh", "-i", str(ssh_info.key_path), "-p", str(ssh_info.port)]
    for option in ("StrictHostKeyChecking=no", "UserKnownHostsFile=/dev/null", "LogLevel=ERROR"):
        argv += ["-o", option]
    argv += [f"{ssh_info.user}@{ssh_info.host}", remote_cmd]

    proc = subprocess.run(argv, capture_output=True, text=True)

    listing = proc.stdout.strip()
    if proc.returncode != 0 or not listing:
        return []
    return listing.split("\n")
403
+
404
+
405
def _scp_single_file(
    ssh_info: TargetSSHInfo,
    remote_path: str,
    local_dest: str,
    recursive: bool,
) -> None:
    """Download a single file/dir from remote via scp.

    Raises:
        TargetExecError: If scp exits non-zero.
    """
    # Note: scp takes the port as -P (capital), unlike ssh's -p.
    argv = ["scp", "-i", str(ssh_info.key_path), "-P", str(ssh_info.port)]
    for option in ("StrictHostKeyChecking=no", "UserKnownHostsFile=/dev/null", "LogLevel=ERROR"):
        argv += ["-o", option]
    if recursive:
        argv.append("-r")
    argv += [f"{ssh_info.user}@{ssh_info.host}:{remote_path}", local_dest]

    proc = subprocess.run(argv, capture_output=True, text=True)
    if proc.returncode != 0:
        raise TargetExecError(f"scp failed for {remote_path}: {proc.stderr}")
437
+
438
+
439
def _scp_glob_download(
    ssh_info: TargetSSHInfo,
    remote_pattern: str,
    local_dest: str,
    recursive: bool,
) -> None:
    """Download every remote file matching a glob pattern.

    The glob is expanded on the remote host first, then each match is
    fetched individually with scp.
    """
    matches = _expand_remote_glob(ssh_info, remote_pattern)
    if not matches:
        # Not an error: an empty match set just means nothing to fetch.
        logger.warning(f"No files matched pattern: {remote_pattern}")
        return

    for match in matches:
        _scp_single_file(ssh_info, match, local_dest, recursive)
457
+
458
+
459
def scp_transfer(
    ssh_info: TargetSSHInfo,
    source: str,
    dest: str,
    is_download: bool,
    recursive: bool = False,
) -> None:
    """Transfer files via scp. Supports glob patterns for downloads.

    Args:
        ssh_info: SSH connection info
        source: Source path (local for upload, remote for download)
        dest: Destination path (remote for upload, local for download)
        is_download: True if downloading from remote, False if uploading
        recursive: Whether to copy directories recursively

    Raises:
        TargetExecError: If scp fails
    """
    # Glob downloads take a separate path: expand remotely, fetch each match.
    if is_download and _has_glob_chars(source):
        return _scp_glob_download(ssh_info, source, dest, recursive)

    argv = ["scp", "-i", str(ssh_info.key_path), "-P", str(ssh_info.port)]
    for option in ("StrictHostKeyChecking=no", "UserKnownHostsFile=/dev/null", "LogLevel=ERROR"):
        argv += ["-o", option]
    if recursive:
        argv.append("-r")

    remote = f"{ssh_info.user}@{ssh_info.host}"
    if is_download:
        argv += [f"{remote}:{source}", dest]  # remote -> local
    else:
        argv += [source, f"{remote}:{dest}"]  # local -> remote

    proc = subprocess.run(argv, capture_output=True, text=True)
    if proc.returncode != 0:
        raise TargetExecError(f"scp failed: {proc.stderr}")
515
+
516
+
517
+ # =============================================================================
518
+ # Tool Registry for `wafer targets ensure`
519
+ # =============================================================================
520
+
521
+
522
@dataclass(frozen=True)
class ToolSpec:
    """Specification for a tool that can be installed on a target.

    Entries live in TOOL_REGISTRY; commands are executed on the target
    over SSH.
    """

    name: str
    check_cmd: str  # Command to check if installed (exit 0 = installed)
    install_cmd: str | None  # Command to install (None = can't auto-install)
    verify_cmd: str | None = None  # Command to verify after install (exit 0 = ok)
    platform: str = "any"  # "amd", "nvidia", or "any"
    description: str = ""  # Human-readable summary for CLI display
532
+
533
+
534
# Registry of tools that `ensure_tool` can check for (and, where possible,
# install) on a target, keyed by tool name.  `platform` tags which GPU vendor
# the tool applies to ("amd", "nvidia", or "any").
TOOL_REGISTRY: dict[str, ToolSpec] = {
    # AMD Tools
    "rocprof-compute": ToolSpec(
        name="rocprof-compute",
        check_cmd="which rocprof-compute",
        # rocprofiler-compute requires ROCm >= 6.3 and apt install (not pip).
        # For older ROCm, users need to upgrade or install manually.
        install_cmd="apt-get update && apt-get install -y rocprofiler-compute && python3 -m pip install -r /opt/rocm/libexec/rocprofiler-compute/requirements.txt",
        verify_cmd="rocprof-compute --version",
        platform="amd",
        description="AMD GPU profiling (roofline, memory, etc.) - requires ROCm >= 6.3",
    ),
    "rocprof-systems": ToolSpec(
        name="rocprof-systems",
        check_cmd="which rocprof-systems",
        # rocprofiler-systems also requires apt install on ROCm >= 6.3.
        install_cmd="apt-get update && apt-get install -y rocprofiler-systems && python3 -m pip install -r /opt/rocm/libexec/rocprofiler-systems/requirements.txt",
        verify_cmd="rocprof-systems --version",
        platform="amd",
        description="AMD system-wide tracing - requires ROCm >= 6.3",
    ),
    "rocprof": ToolSpec(
        name="rocprof",
        check_cmd="which rocprof",
        install_cmd=None,  # Part of ROCm base install
        platform="amd",
        description="AMD kernel profiling (part of ROCm)",
    ),
    # NVIDIA Tools
    "ncu": ToolSpec(
        name="ncu",
        check_cmd="which ncu",
        install_cmd=None,  # Part of CUDA toolkit
        platform="nvidia",
        description="NVIDIA Nsight Compute (part of CUDA toolkit)",
    ),
    "nsys": ToolSpec(
        name="nsys",
        check_cmd="which nsys",
        install_cmd=None,  # Part of CUDA toolkit
        platform="nvidia",
        description="NVIDIA Nsight Systems (part of CUDA toolkit)",
    ),
    "nvtx": ToolSpec(
        name="nvtx",
        check_cmd='python -c "import nvtx"',
        install_cmd="pip install nvtx",
        verify_cmd='python -c "import nvtx; print(nvtx.__version__)"',
        platform="nvidia",
        description="NVIDIA Tools Extension (Python)",
    ),
    # Cross-platform Python packages
    "triton": ToolSpec(
        name="triton",
        check_cmd='python -c "import triton"',
        install_cmd="pip install triton",
        verify_cmd='python -c "import triton; print(triton.__version__)"',
        platform="any",
        description="OpenAI Triton compiler",
    ),
    "torch": ToolSpec(
        name="torch",
        check_cmd='python -c "import torch"',
        install_cmd="pip install torch",
        verify_cmd='python -c "import torch; print(torch.__version__)"',
        platform="any",
        description="PyTorch",
    ),
}
603
+
604
+
605
def get_target_platform(target: TargetConfig) -> str:
    """Determine platform ("amd" or "nvidia") from target config."""
    # Imported lazily for the isinstance checks below.
    from wafer_core.utils.kernel_utils.targets.config import (
        DigitalOceanTarget,
        LocalTarget,
        RunPodTarget,
    )

    # RunPod and DigitalOcean are always AMD MI300X.
    if isinstance(target, (RunPodTarget, DigitalOceanTarget)):
        return "amd"

    # LocalTarget declares its vendor explicitly.
    if isinstance(target, LocalTarget):
        return target.vendor

    # Baremetal/VM: infer AMD from gpu_type or compute_capability when set.
    if "MI300" in getattr(target, "gpu_type", ""):
        return "amd"
    if getattr(target, "compute_capability", "") == "9.4":  # gfx942 = MI300X
        return "amd"

    # Default to nvidia for other compute capabilities.
    return "nvidia"
633
+
634
+
635
@dataclass
class EnsureResult:
    """Result of ensure_tool operation."""

    # Name of the tool that was requested.
    tool: str
    # True if the check command found the tool before any install attempt.
    already_installed: bool
    # True if the install command was run and exited 0.
    installed: bool
    # True if the tool passed the check/verify command.
    verified: bool
    # Human-readable failure reason, or None on success.
    error: str | None = None
644
+
645
+
646
def ensure_tool(
    ssh_info: TargetSSHInfo,
    tool: str,
    force: bool = False,
    timeout: int = 300,
) -> EnsureResult:
    """Ensure a tool is installed on target.

    Args:
        ssh_info: SSH connection info
        tool: Tool name from TOOL_REGISTRY
        force: If True, reinstall even if present
        timeout: Timeout for install command

    Returns:
        EnsureResult with status
    """
    spec = TOOL_REGISTRY.get(tool)
    if spec is None:
        return EnsureResult(
            tool=tool,
            already_installed=False,
            installed=False,
            verified=False,
            error=f"Unknown tool: {tool}. Available: {', '.join(sorted(TOOL_REGISTRY.keys()))}",
        )

    # Skip the install entirely if the tool is already present.
    if not force and exec_on_target_sync(ssh_info, spec.check_cmd, timeout_seconds=30) == 0:
        return EnsureResult(
            tool=tool,
            already_installed=True,
            installed=False,
            verified=True,
        )

    # Some tools ship with the base platform and cannot be installed here.
    if spec.install_cmd is None:
        return EnsureResult(
            tool=tool,
            already_installed=False,
            installed=False,
            verified=False,
            error=f"{tool} cannot be auto-installed. It's part of the base platform (ROCm/CUDA).",
        )

    # Run the install command on the target.
    exit_code = exec_on_target_sync(ssh_info, spec.install_cmd, timeout_seconds=timeout)
    if exit_code != 0:
        return EnsureResult(
            tool=tool,
            already_installed=False,
            installed=False,
            verified=False,
            error=f"Installation failed (exit code {exit_code})",
        )

    # Optional post-install verification.
    verified = True
    if spec.verify_cmd:
        verified = exec_on_target_sync(ssh_info, spec.verify_cmd, timeout_seconds=30) == 0

    return EnsureResult(
        tool=tool,
        already_installed=False,
        installed=True,
        verified=verified,
        error=None if verified else "Installation succeeded but verification failed",
    )