wafer-cli 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/targets.py ADDED
@@ -0,0 +1,842 @@
1
+ """Target management for Wafer CLI.
2
+
3
+ CRUD operations for GPU targets stored in ~/.wafer/targets/.
4
+ """
5
+
6
+ import tomllib
7
+ from dataclasses import asdict
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from wafer_core.utils.kernel_utils.targets.config import (
12
+ BaremetalTarget,
13
+ DigitalOceanTarget,
14
+ ModalTarget,
15
+ RunPodTarget,
16
+ TargetConfig,
17
+ VMTarget,
18
+ WorkspaceTarget,
19
+ )
20
+
21
# Default paths
WAFER_DIR = Path.home() / ".wafer"  # Root directory for all Wafer CLI state
TARGETS_DIR = WAFER_DIR / "targets"  # One {name}.toml file per configured target
CONFIG_FILE = WAFER_DIR / "config.toml"  # Global config: default target, [pools.*] sections
25
+
26
+
27
def _ensure_dirs() -> None:
    """Create ~/.wafer/targets/ (and any missing parents), idempotently."""
    TARGETS_DIR.mkdir(parents=True, exist_ok=True)
30
+
31
+
32
def _target_path(name: str) -> Path:
    """Return the on-disk config path for the target called *name*."""
    filename = f"{name}.toml"
    return TARGETS_DIR / filename
35
+
36
+
37
def _parse_target(data: dict[str, Any]) -> TargetConfig:
    """Build a target dataclass from a parsed-TOML dict.

    Args:
        data: Parsed TOML data

    Returns:
        TargetConfig (BaremetalTarget, VMTarget, ModalTarget, or WorkspaceTarget)

    Raises:
        ValueError: If target type is unknown or required fields missing
    """
    target_type = data.get("type")
    if not target_type:
        raise ValueError(
            "Target must have 'type' field (baremetal, vm, modal, workspace, runpod, or digitalocean)"
        )

    # Drop the discriminator; the dataclass constructors don't accept 'type'.
    fields = {key: value for key, value in data.items() if key != "type"}

    # TOML arrays parse as lists, but the dataclasses store tuples.
    for key in ("pip_packages", "gpu_ids"):
        if isinstance(fields.get(key), list):
            fields[key] = tuple(fields[key])

    if target_type == "baremetal":
        return BaremetalTarget(**fields)
    if target_type == "vm":
        return VMTarget(**fields)
    if target_type == "modal":
        return ModalTarget(**fields)
    if target_type == "workspace":
        return WorkspaceTarget(**fields)
    if target_type == "runpod":
        return RunPodTarget(**fields)
    if target_type == "digitalocean":
        return DigitalOceanTarget(**fields)
    raise ValueError(
        f"Unknown target type: {target_type}. Must be baremetal, vm, modal, workspace, runpod, or digitalocean"
    )
82
+
83
+
84
def _serialize_target(target: TargetConfig) -> dict[str, Any]:
    """Turn a target dataclass into a TOML-ready dict.

    Args:
        target: Target config

    Returns:
        Dict with 'type' field added
    """
    data = asdict(target)

    # Tag the dict with the discriminator string for this dataclass; each
    # target instance matches exactly one entry.
    type_tags = (
        (BaremetalTarget, "baremetal"),
        (VMTarget, "vm"),
        (ModalTarget, "modal"),
        (WorkspaceTarget, "workspace"),
        (RunPodTarget, "runpod"),
        (DigitalOceanTarget, "digitalocean"),
    )
    for cls, tag in type_tags:
        if isinstance(target, cls):
            data["type"] = tag
            break

    # TOML can't represent tuples; convert the tuple-typed fields to lists.
    for key in ("pip_packages", "gpu_ids"):
        if isinstance(data.get(key), tuple):
            data[key] = list(data[key])

    # An empty pip_packages entry is just noise in the config file.
    if "pip_packages" in data and not data["pip_packages"]:
        del data["pip_packages"]

    return data
122
+
123
+
124
+ def _write_toml(path: Path, data: dict[str, Any]) -> None:
125
+ """Write dict as TOML file.
126
+
127
+ Simple TOML writer - handles flat dicts and lists.
128
+ """
129
+ lines = []
130
+ for key, value in data.items():
131
+ if value is None:
132
+ continue # Skip None values
133
+ if isinstance(value, bool):
134
+ lines.append(f"{key} = {str(value).lower()}")
135
+ elif isinstance(value, int | float):
136
+ lines.append(f"{key} = {value}")
137
+ elif isinstance(value, str):
138
+ lines.append(f'{key} = "{value}"')
139
+ elif isinstance(value, list):
140
+ # Format list
141
+ if all(isinstance(v, int) for v in value):
142
+ lines.append(f"{key} = {value}")
143
+ else:
144
+ formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
145
+ lines.append(f"{key} = [{formatted}]")
146
+
147
+ path.write_text("\n".join(lines) + "\n")
148
+
149
+
150
def load_target(name: str) -> TargetConfig:
    """Load target config by name.

    Args:
        name: Target name (filename without .toml)

    Returns:
        Target config

    Raises:
        FileNotFoundError: If target doesn't exist
        ValueError: If target config is invalid
    """
    path = _target_path(name)
    if not path.exists():
        raise FileNotFoundError(f"Target not found: {name} (looked in {path})")

    # TOML files are UTF-8 by spec; decode explicitly rather than relying
    # on the locale default.
    parsed = tomllib.loads(path.read_bytes().decode("utf-8"))
    return _parse_target(parsed)
171
+
172
+
173
def save_target(target: TargetConfig | dict[str, Any]) -> TargetConfig:
    """Save target config.

    Accepts either a ready TargetConfig or a raw dict carrying a 'type'
    field, which is parsed first.

    Args:
        target: Target config (TargetConfig object or dict with 'type' field)

    Returns:
        The saved TargetConfig object

    Creates ~/.wafer/targets/{name}.toml
    """
    _ensure_dirs()

    parsed = _parse_target(target) if isinstance(target, dict) else target
    _write_toml(_target_path(parsed.name), _serialize_target(parsed))
    return parsed
194
+
195
+
196
def add_target_from_file(file_path: Path) -> TargetConfig:
    """Add target from TOML file.

    Args:
        file_path: Path to TOML file

    Returns:
        Parsed and saved target

    Raises:
        FileNotFoundError: If file doesn't exist
        ValueError: If file is invalid
    """
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    # TOML is UTF-8 by spec; decode explicitly.
    parsed = tomllib.loads(file_path.read_bytes().decode("utf-8"))
    target = _parse_target(parsed)
    save_target(target)
    return target
218
+
219
+
220
def list_targets() -> list[str]:
    """List all configured target names.

    Returns:
        Sorted list of target names
    """
    _ensure_dirs()
    names = [toml_file.stem for toml_file in TARGETS_DIR.glob("*.toml")]
    names.sort()
    return names
228
+
229
+
230
def remove_target(name: str) -> None:
    """Remove target config.

    Args:
        name: Target name to remove

    Raises:
        FileNotFoundError: If target doesn't exist
    """
    # EAFP: unlink directly instead of exists()-then-unlink(), which both
    # doubles the stat syscalls and races with concurrent removal (TOCTOU).
    try:
        _target_path(name).unlink()
    except FileNotFoundError as exc:
        # Re-raise with the CLI's user-facing message.
        raise FileNotFoundError(f"Target not found: {name}") from exc
243
+
244
+
245
+ def get_default_target() -> str | None:
246
+ """Get default target name from config.
247
+
248
+ Returns:
249
+ Default target name, or None if not set
250
+ """
251
+ if not CONFIG_FILE.exists():
252
+ return None
253
+
254
+ with open(CONFIG_FILE, "rb") as f:
255
+ data = tomllib.load(f)
256
+
257
+ return data.get("default_target")
258
+
259
+
260
+ # ── Pool Management ─────────────────────────────────────────────────────────
261
+
262
+
263
+ def get_pool(name: str) -> list[str]:
264
+ """Get list of targets in a named pool.
265
+
266
+ Pools are defined in ~/.wafer/config.toml:
267
+ [pools.my-pool]
268
+ targets = ["target-1", "target-2", "target-3"]
269
+
270
+ Args:
271
+ name: Pool name
272
+
273
+ Returns:
274
+ List of target names in the pool
275
+
276
+ Raises:
277
+ FileNotFoundError: If pool doesn't exist
278
+ """
279
+ if not CONFIG_FILE.exists():
280
+ raise FileNotFoundError(f"Pool not found: {name} (no config file)")
281
+
282
+ with open(CONFIG_FILE, "rb") as f:
283
+ data = tomllib.load(f)
284
+
285
+ pools = data.get("pools", {})
286
+ if name not in pools:
287
+ raise FileNotFoundError(
288
+ f"Pool not found: {name}\n"
289
+ f" Define pools in ~/.wafer/config.toml:\n"
290
+ f" [pools.{name}]\n"
291
+ f' targets = ["target-1", "target-2"]'
292
+ )
293
+
294
+ pool_config = pools[name]
295
+ targets = pool_config.get("targets", [])
296
+
297
+ if not targets:
298
+ raise ValueError(f"Pool '{name}' has no targets defined")
299
+
300
+ return targets
301
+
302
+
303
def get_target_type(name: str) -> str | None:
    """Get the type of a target without fully loading it.

    Args:
        name: Target name

    Returns:
        Target type string (runpod, digitalocean, baremetal, etc.) or None if not found
    """
    path = _target_path(name)
    if not path.exists():
        return None

    raw = tomllib.loads(path.read_bytes().decode("utf-8"))
    return raw.get("type")
320
+
321
+
322
def filter_pool_by_auth(target_names: list[str]) -> tuple[list[str], list[str]]:
    """Filter pool targets to only those with valid authentication.

    Args:
        target_names: List of target names to filter

    Returns:
        Tuple of (usable_targets, skipped_targets)
    """
    from wafer_core.auth import get_api_key

    # Cloud targets need a provider API key at runtime; other types
    # (baremetal, vm, workspace, modal) don't need runtime API keys.
    providers_needing_keys = ("runpod", "digitalocean")

    usable: list[str] = []
    skipped: list[str] = []

    for name in target_names:
        target_type = get_target_type(name)
        if target_type is None:
            # No config on disk for this name — skip it.
            skipped.append(name)
            continue

        if target_type in providers_needing_keys and not get_api_key(target_type):
            skipped.append(name)
            continue

        usable.append(name)

    return usable, skipped
357
+
358
+
359
+ def list_pools() -> list[str]:
360
+ """List all configured pool names.
361
+
362
+ Returns:
363
+ Sorted list of pool names
364
+ """
365
+ if not CONFIG_FILE.exists():
366
+ return []
367
+
368
+ with open(CONFIG_FILE, "rb") as f:
369
+ data = tomllib.load(f)
370
+
371
+ return sorted(data.get("pools", {}).keys())
372
+
373
+
374
def save_pool(name: str, targets: list[str]) -> None:
    """Save or update a pool configuration.

    Args:
        name: Pool name
        targets: List of target names (must all exist)

    Raises:
        FileNotFoundError: If any target doesn't exist
    """
    # Every pool member must refer to a real target config.
    known = set(list_targets())
    missing = [t for t in targets if t not in known]
    if missing:
        raise FileNotFoundError(f"Targets not found: {', '.join(missing)}")

    _ensure_dirs()

    # Start from the existing config so other settings are preserved.
    data: dict[str, Any] = {}
    if CONFIG_FILE.exists():
        with open(CONFIG_FILE, "rb") as fh:
            data = tomllib.load(fh)

    data.setdefault("pools", {})[name] = {"targets": targets}

    # Pools are nested tables, which the flat TOML writer can't express.
    _write_config_with_pools(data)
407
+
408
+
409
+ def _write_config_with_pools(data: dict) -> None:
410
+ """Write config file with pools support.
411
+
412
+ Handles the nested [pools.name] TOML structure and preserves
413
+ existing nested sections like [default], [api], [environments.*].
414
+ """
415
+ lines = []
416
+
417
+ # Collect nested sections to write after top-level keys
418
+ nested_sections: dict[str, dict] = {}
419
+
420
+ # Write top-level keys first (except pools and nested dicts)
421
+ for key, value in data.items():
422
+ if key == "pools":
423
+ continue
424
+ if value is None:
425
+ continue
426
+ if isinstance(value, dict):
427
+ # Save nested sections for later
428
+ nested_sections[key] = value
429
+ elif isinstance(value, str):
430
+ lines.append(f'{key} = "{value}"')
431
+ elif isinstance(value, bool):
432
+ lines.append(f"{key} = {str(value).lower()}")
433
+ elif isinstance(value, int | float):
434
+ lines.append(f"{key} = {value}")
435
+ elif isinstance(value, list):
436
+ if all(isinstance(v, int) for v in value):
437
+ lines.append(f"{key} = {value}")
438
+ else:
439
+ formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
440
+ lines.append(f"{key} = [{formatted}]")
441
+
442
+ # Write nested sections (e.g., [default], [api], [environments.foo])
443
+ for section_name, section_data in nested_sections.items():
444
+ lines.append("")
445
+ lines.append(f"[{section_name}]")
446
+ for key, value in section_data.items():
447
+ if value is None:
448
+ continue
449
+ if isinstance(value, str):
450
+ lines.append(f'{key} = "{value}"')
451
+ elif isinstance(value, bool):
452
+ lines.append(f"{key} = {str(value).lower()}")
453
+ elif isinstance(value, int | float):
454
+ lines.append(f"{key} = {value}")
455
+ elif isinstance(value, list):
456
+ if all(isinstance(v, int) for v in value):
457
+ lines.append(f"{key} = {value}")
458
+ else:
459
+ formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
460
+ lines.append(f"{key} = [{formatted}]")
461
+
462
+ # Write pools
463
+ pools = data.get("pools", {})
464
+ for pool_name, pool_config in pools.items():
465
+ lines.append("")
466
+ lines.append(f"[pools.{pool_name}]")
467
+ targets = pool_config.get("targets", [])
468
+ formatted = ", ".join(f'"{t}"' for t in targets)
469
+ lines.append(f"targets = [{formatted}]")
470
+
471
+ CONFIG_FILE.write_text("\n".join(lines) + "\n")
472
+
473
+
474
def set_default_target(name: str) -> None:
    """Set default target.

    Args:
        name: Target name (must exist)

    Raises:
        FileNotFoundError: If target doesn't exist
    """
    # Verify target exists
    if name not in list_targets():
        raise FileNotFoundError(f"Target not found: {name}")

    _ensure_dirs()

    # Load existing config or create new
    if CONFIG_FILE.exists():
        with open(CONFIG_FILE, "rb") as f:
            data = tomllib.load(f)
    else:
        data = {}

    data["default_target"] = name

    # Write back with the pools-aware writer. BUG FIX: the previous code used
    # _write_toml here, which silently drops nested dict values — so setting a
    # default target wiped every [pools.*] and other nested section from
    # config.toml.
    _write_config_with_pools(data)
500
+
501
+
502
def get_target_info(target: TargetConfig) -> dict[str, str]:
    """Get human-readable info about target.

    Args:
        target: Target config

    Returns:
        Dict of field name -> display value
    """
    info: dict[str, str] = {}

    def describe_ssh_host(kind: str) -> None:
        # Shared display fields for the two SSH-reachable target kinds.
        info["Type"] = kind
        info["SSH"] = target.ssh_target
        info["GPUs"] = ", ".join(str(g) for g in target.gpu_ids)
        info["NCU"] = "Yes" if target.ncu_available else "No"
        # Docker info
        if target.docker_image:
            info["Docker"] = target.docker_image
        if target.pip_packages:
            info["Packages"] = ", ".join(target.pip_packages)
        if target.torch_package:
            info["Torch"] = target.torch_package

    if isinstance(target, BaremetalTarget):
        describe_ssh_host("Baremetal")
    elif isinstance(target, VMTarget):
        describe_ssh_host("VM")
    elif isinstance(target, ModalTarget):
        info["Type"] = "Modal"
        info["App"] = target.modal_app_name
        info["GPU"] = target.gpu_type
        info["Timeout"] = f"{target.timeout_seconds}s"
        info["NCU"] = "No (Modal)"
    elif isinstance(target, WorkspaceTarget):
        info["Type"] = "Workspace"
        info["Workspace ID"] = target.workspace_id
        info["GPU"] = target.gpu_type
        info["Timeout"] = f"{target.timeout_seconds}s"
        info["NCU"] = "No (Workspace)"
    elif isinstance(target, RunPodTarget):
        info["Type"] = "RunPod"
        info["GPU Type"] = target.gpu_type_id
        info["GPU Count"] = str(target.gpu_count)
        info["Image"] = target.image
        info["Keep Alive"] = "Yes" if target.keep_alive else "No"
        info["NCU"] = "No (RunPod)"
    elif isinstance(target, DigitalOceanTarget):
        info["Type"] = "DigitalOcean"
        info["Region"] = target.region
        info["Size"] = target.size_slug
        info["Image"] = target.image
        info["Keep Alive"] = "Yes" if target.keep_alive else "No"
        info["NCU"] = "No (DigitalOcean)"

    # Every target type exposes a compute capability.
    info["Compute"] = target.compute_capability

    return info
567
+
568
+
569
# Probe script to run on target - checks available backends.
# Executed remotely via `python3 -c '...'` by probe_target_capabilities, so it
# must stay self-contained (stdlib-only imports, single JSON line on stdout).
# Note: "\\n" below is escaped so the *remote* script splits on newlines.
_PROBE_SCRIPT = """
import json
import shutil
import sys

def probe():
    result = {
        "python_version": sys.version.split()[0],
        "backends": {},
        "packages": {},
    }

    # Check Triton
    try:
        import triton
        result["backends"]["triton"] = triton.__version__
    except ImportError:
        result["backends"]["triton"] = None

    # Check torch
    try:
        import torch
        result["packages"]["torch"] = torch.__version__
        result["backends"]["torch"] = torch.__version__
        result["cuda_available"] = torch.cuda.is_available()
        if torch.cuda.is_available():
            result["gpu_name"] = torch.cuda.get_device_name(0)
            props = torch.cuda.get_device_properties(0)
            result["compute_capability"] = f"{props.major}.{props.minor}"
    except ImportError:
        result["packages"]["torch"] = None

    # Check hipcc (AMD)
    hipcc = shutil.which("hipcc")
    result["backends"]["hipcc"] = hipcc

    # Check nvcc (NVIDIA)
    nvcc = shutil.which("nvcc")
    result["backends"]["nvcc"] = nvcc

    # Check ROCm version
    try:
        with open("/opt/rocm/.info/version", "r") as f:
            result["rocm_version"] = f.read().strip()
    except Exception:
        result["rocm_version"] = None

    # Check CUDA version from nvcc
    if nvcc:
        import subprocess
        try:
            out = subprocess.check_output([nvcc, "--version"], text=True)
            for line in out.split("\\n"):
                if "release" in line.lower():
                    # Parse "Cuda compilation tools, release 12.1, V12.1.105"
                    parts = line.split("release")
                    if len(parts) > 1:
                        result["cuda_version"] = parts[1].split(",")[0].strip()
                    break
        except Exception:
            pass

    print(json.dumps(result))

if __name__ == "__main__":
    probe()
"""
637
+
638
+
639
class ProbeError(Exception):
    """Error during target probing with actionable context."""
643
+
644
+
645
async def probe_target_capabilities(target: TargetConfig) -> dict[str, Any]:
    """Probe a target to discover available compilation backends.

    Connects to the target and runs a probe script to check:
    - Triton availability
    - torch availability
    - HIP/CUDA compiler
    - ROCm/CUDA version
    - GPU info

    Args:
        target: Target config

    Returns:
        Dict with capabilities info (the JSON emitted by _PROBE_SCRIPT)

    Raises:
        ProbeError: With actionable error message on failure
    """
    import json
    import subprocess

    if isinstance(target, RunPodTarget):
        import trio_asyncio
        from wafer_core.targets.runpod import RunPodError, get_pod_state, runpod_ssh_context

        # Check if pod exists before trying to connect; also used later to
        # build a more specific error message on SSH failure.
        pod_state = get_pod_state(target.name)

        try:
            # Need trio_asyncio.open_loop() for asyncssh bridge used by runpod_ssh_context
            async with trio_asyncio.open_loop():
                async with runpod_ssh_context(target) as ssh_info:
                    ssh_target = f"{ssh_info.user}@{ssh_info.host}"
                    port = ssh_info.port
                    key_path = target.ssh_key

                    # Find Python and run probe using subprocess (simpler than async ssh)
                    def run_ssh_cmd(cmd: str) -> tuple[int, str, str]:
                        # Run one command over ssh; host-key checks disabled
                        # because pods get fresh ephemeral host keys.
                        try:
                            result = subprocess.run(
                                [
                                    "ssh",
                                    "-o",
                                    "StrictHostKeyChecking=no",
                                    "-o",
                                    "UserKnownHostsFile=/dev/null",
                                    "-o",
                                    "ConnectTimeout=30",
                                    "-i",
                                    str(key_path),
                                    "-p",
                                    str(port),
                                    ssh_target,
                                    cmd,
                                ],
                                capture_output=True,
                                text=True,
                                timeout=60,
                            )
                            return result.returncode, result.stdout, result.stderr
                        except subprocess.TimeoutExpired:
                            raise ProbeError(
                                f"SSH connection timed out\n"
                                f" Host: {ssh_target}:{port}\n"
                                f" Hint: The pod may be starting up. Try again in 30 seconds."
                            ) from None

                    # Find Python: prefer known conda interpreters on RunPod
                    # images, falling back to whatever python3 is on PATH.
                    python_exe = "python3"
                    for candidate in [
                        "/opt/conda/envs/py_3.10/bin/python3",
                        "/opt/conda/bin/python3",
                    ]:
                        code, out, _ = run_ssh_cmd(f"{candidate} --version 2>/dev/null && echo OK")
                        if code == 0 and "OK" in out:
                            python_exe = candidate
                            break

                    # Run probe script; single-quote escaping so the script
                    # survives being embedded in a sh -c single-quoted string.
                    escaped_script = _PROBE_SCRIPT.replace("'", "'\"'\"'")
                    code, out, err = run_ssh_cmd(f"{python_exe} -c '{escaped_script}'")
                    if code != 0:
                        raise ProbeError(
                            f"Probe script failed on target\n"
                            f" Exit code: {code}\n"
                            f" Error: {err.strip() if err else 'unknown'}"
                        )

                    # The probe prints exactly one JSON object on stdout.
                    try:
                        return json.loads(out)
                    except json.JSONDecodeError as e:
                        raise ProbeError(
                            f"Failed to parse probe output\n Error: {e}\n Output: {out[:200]}..."
                        ) from None

        except RunPodError as e:
            # RunPod API errors (provisioning, pod not found, etc.)
            raise ProbeError(f"RunPod error for target '{target.name}'\n {e}") from None
        except OSError as e:
            # SSH connection errors
            if pod_state:
                raise ProbeError(
                    f"SSH connection failed to target '{target.name}'\n"
                    f" Host: {pod_state.ssh_username}@{pod_state.public_ip}:{pod_state.ssh_port}\n"
                    f" Error: {e}\n"
                    f" Hint: Check if the pod is still running with 'wafer config targets pods'"
                ) from None
            raise ProbeError(
                f"SSH connection failed to target '{target.name}'\n"
                f" Error: {e}\n"
                f" Hint: No pod found. One will be provisioned on next probe attempt."
            ) from None

    elif isinstance(target, (BaremetalTarget, VMTarget)):
        # NOTE(review): subprocess is already imported at the top of this
        # function; this re-import is redundant but harmless.
        import subprocess

        # Parse ssh_target (user@host:port or user@host)
        ssh_target = target.ssh_target
        if ":" in ssh_target.split("@")[-1]:
            host_port = ssh_target.split("@")[-1]
            host = host_port.rsplit(":", 1)[0]
            port = host_port.rsplit(":", 1)[1]
            user = ssh_target.split("@")[0]
            # Strip the port from the ssh destination; it is passed via -p.
            ssh_target = f"{user}@{host}"
        else:
            host = ssh_target.split("@")[-1]
            port = "22"
            user = ssh_target.split("@")[0]

        key_path = target.ssh_key

        def run_ssh_cmd(cmd: str) -> tuple[int, str, str]:
            # Same ssh wrapper as the RunPod branch, but `port` here is
            # already a string parsed from ssh_target.
            try:
                result = subprocess.run(
                    [
                        "ssh",
                        "-o",
                        "StrictHostKeyChecking=no",
                        "-o",
                        "UserKnownHostsFile=/dev/null",
                        "-o",
                        "ConnectTimeout=30",
                        "-i",
                        str(key_path),
                        "-p",
                        port,
                        ssh_target,
                        cmd,
                    ],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                return result.returncode, result.stdout, result.stderr
            except subprocess.TimeoutExpired:
                raise ProbeError(
                    f"SSH connection timed out\n"
                    f" Host: {ssh_target}:{port}\n"
                    f" Hint: Check if the host is reachable and SSH is running."
                ) from None

        # Test SSH connection first
        code, out, err = run_ssh_cmd("echo OK")
        if code != 0:
            raise ProbeError(
                f"SSH connection failed to target '{target.name}'\n"
                f" Host: {user}@{host}:{port}\n"
                f" Key: {key_path}\n"
                f" Error: {err.strip() if err else 'connection refused or timeout'}\n"
                f" Hint: Verify the host is reachable and the SSH key is authorized."
            )

        # Run probe script (same single-quote escaping as the RunPod branch)
        escaped_script = _PROBE_SCRIPT.replace("'", "'\"'\"'")
        code, out, err = run_ssh_cmd(f"python3 -c '{escaped_script}'")
        if code != 0:
            raise ProbeError(
                f"Probe script failed on target '{target.name}'\n"
                f" Exit code: {code}\n"
                f" Error: {err.strip() if err else 'unknown'}\n"
                f" Hint: Ensure python3 is installed on the target."
            )

        try:
            return json.loads(out)
        except json.JSONDecodeError as e:
            raise ProbeError(
                f"Failed to parse probe output from '{target.name}'\n"
                f" Error: {e}\n"
                f" Output: {out[:200]}..."
            ) from None

    else:
        # Modal/Workspace targets run through managed runtimes, not raw SSH.
        raise ProbeError(
            f"Probing not supported for target type: {type(target).__name__}\n"
            f" Supported types: RunPod, Baremetal, VM"
        )