wafer-cli 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +18 -7
- wafer/api_client.py +4 -0
- wafer/auth.py +85 -0
- wafer/cli.py +2339 -404
- wafer/corpus.py +158 -32
- wafer/evaluate.py +1232 -201
- wafer/gpu_run.py +5 -1
- wafer/kernel_scope.py +554 -0
- wafer/nsys_analyze.py +903 -73
- wafer/nsys_profile.py +511 -0
- wafer/output.py +241 -0
- wafer/problems.py +357 -0
- wafer/skills/wafer-guide/SKILL.md +13 -0
- wafer/ssh_keys.py +261 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +490 -0
- wafer/targets_ops.py +718 -0
- wafer/wevin_cli.py +129 -18
- wafer/workspaces.py +282 -182
- {wafer_cli-0.2.8.dist-info → wafer_cli-0.2.10.dist-info}/METADATA +1 -1
- wafer_cli-0.2.10.dist-info/RECORD +40 -0
- wafer_cli-0.2.8.dist-info/RECORD +0 -33
- {wafer_cli-0.2.8.dist-info → wafer_cli-0.2.10.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.8.dist-info → wafer_cli-0.2.10.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.8.dist-info → wafer_cli-0.2.10.dist-info}/top_level.txt +0 -0
wafer/cli.py
CHANGED
|
@@ -30,6 +30,14 @@ import typer
|
|
|
30
30
|
|
|
31
31
|
from .config import WaferConfig, WaferEnvironment
|
|
32
32
|
from .inference import infer_upload_files, resolve_environment
|
|
33
|
+
from .problems import (
|
|
34
|
+
download_problems,
|
|
35
|
+
get_problem_path,
|
|
36
|
+
get_problems_path,
|
|
37
|
+
)
|
|
38
|
+
from .problems import (
|
|
39
|
+
list_problems as list_problems_fn,
|
|
40
|
+
)
|
|
33
41
|
|
|
34
42
|
app = typer.Typer(
|
|
35
43
|
help="GPU development toolkit for LLM coding agents",
|
|
@@ -91,11 +99,15 @@ def main_callback(ctx: typer.Context) -> None:
|
|
|
91
99
|
# Install exception hook to catch SystemExit and mark failures
|
|
92
100
|
original_excepthook = sys.excepthook
|
|
93
101
|
|
|
94
|
-
def custom_excepthook(
|
|
102
|
+
def custom_excepthook(
|
|
103
|
+
exc_type: type[BaseException],
|
|
104
|
+
exc_value: BaseException,
|
|
105
|
+
exc_traceback: object,
|
|
106
|
+
) -> None:
|
|
95
107
|
global _command_outcome
|
|
96
108
|
# Mark as failure if SystemExit with non-zero code, or any other exception
|
|
97
109
|
if exc_type is SystemExit:
|
|
98
|
-
exit_code = exc_value.code if hasattr(exc_value,
|
|
110
|
+
exit_code = exc_value.code if hasattr(exc_value, "code") else 1
|
|
99
111
|
if exit_code != 0 and exit_code is not None:
|
|
100
112
|
_command_outcome = "failure"
|
|
101
113
|
else:
|
|
@@ -170,7 +182,12 @@ workspaces_app = typer.Typer(
|
|
|
170
182
|
|
|
171
183
|
Workspaces are on-demand cloud GPU environments. Requires authentication (wafer login).
|
|
172
184
|
|
|
173
|
-
|
|
185
|
+
Available GPUs:
|
|
186
|
+
MI300X AMD Instinct MI300X (192GB HBM3, ROCm)
|
|
187
|
+
B200 NVIDIA Blackwell B200 (180GB HBM3e, CUDA)
|
|
188
|
+
|
|
189
|
+
Commands:
|
|
190
|
+
wafer workspaces create dev --gpu B200 # Create workspace
|
|
174
191
|
wafer workspaces exec dev -- python x.py # Run commands
|
|
175
192
|
wafer workspaces ssh dev # Interactive SSH
|
|
176
193
|
wafer workspaces sync dev ./project # Sync files
|
|
@@ -178,6 +195,36 @@ Workspaces are on-demand cloud GPU environments. Requires authentication (wafer
|
|
|
178
195
|
)
|
|
179
196
|
app.add_typer(workspaces_app, name="workspaces")
|
|
180
197
|
|
|
198
|
+
# SSH Key management (BYOK - Bring Your Own Key)
|
|
199
|
+
ssh_keys_app = typer.Typer(
|
|
200
|
+
help="""Manage SSH public keys for workspace access.
|
|
201
|
+
|
|
202
|
+
Register your SSH public keys here. These keys are installed in all workspaces
|
|
203
|
+
you provision, enabling SSH access from any machine with your private key.
|
|
204
|
+
|
|
205
|
+
wafer ssh-keys list # List registered keys
|
|
206
|
+
wafer ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
|
|
207
|
+
wafer ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
|
|
208
|
+
wafer ssh-keys remove <key-id> # Remove a key"""
|
|
209
|
+
)
|
|
210
|
+
app.add_typer(ssh_keys_app, name="ssh-keys")
|
|
211
|
+
|
|
212
|
+
# Target operations (exec/ssh/sync on configured targets)
|
|
213
|
+
targets_ops_app = typer.Typer(
|
|
214
|
+
help="""Execute commands on configured GPU targets.
|
|
215
|
+
|
|
216
|
+
Run commands, SSH, or sync files to targets without going through evaluate.
|
|
217
|
+
Useful for exploratory work, debugging, or custom scripts.
|
|
218
|
+
|
|
219
|
+
wafer targets exec my-target -- python test.py # Run command
|
|
220
|
+
wafer targets ssh my-target # Interactive SSH
|
|
221
|
+
wafer targets sync my-target ./local_dir # Sync files
|
|
222
|
+
|
|
223
|
+
Supports: RunPod, DigitalOcean (auto-provisions), SSH targets (baremetal/vm).
|
|
224
|
+
Configure targets with: wafer config targets init ..."""
|
|
225
|
+
)
|
|
226
|
+
app.add_typer(targets_ops_app, name="targets")
|
|
227
|
+
|
|
181
228
|
# Billing management
|
|
182
229
|
billing_app = typer.Typer(help="Manage billing, credits, and subscription")
|
|
183
230
|
app.add_typer(billing_app, name="billing")
|
|
@@ -200,6 +247,13 @@ kernelbench_app = typer.Typer(
|
|
|
200
247
|
)
|
|
201
248
|
evaluate_app.add_typer(kernelbench_app, name="kernelbench")
|
|
202
249
|
|
|
250
|
+
# Nested subcommand for gpumode format
|
|
251
|
+
gpumode_app = typer.Typer(
|
|
252
|
+
help="Evaluate kernels in GPUMode format (custom_kernel/ref_kernel functions)",
|
|
253
|
+
invoke_without_command=True,
|
|
254
|
+
)
|
|
255
|
+
evaluate_app.add_typer(gpumode_app, name="gpumode")
|
|
256
|
+
|
|
203
257
|
# =============================================================================
|
|
204
258
|
# Dev commands (internal, used by web app proxy)
|
|
205
259
|
# =============================================================================
|
|
@@ -238,10 +292,101 @@ nvidia_app.add_typer(tracelens_app, name="tracelens")
|
|
|
238
292
|
amd_app = typer.Typer(help="AMD GPU profiling and analysis tools")
|
|
239
293
|
app.add_typer(amd_app, name="amd")
|
|
240
294
|
|
|
241
|
-
# ISA
|
|
242
|
-
isa_app = typer.Typer(help="ISA analysis for AMD GPU
|
|
295
|
+
# Unified ISA Analyzer - supports both .co files and Triton artifacts
|
|
296
|
+
isa_app = typer.Typer(help="ISA analysis for AMD GPU kernels (.co, .s, .ll, .ttgir files)")
|
|
243
297
|
amd_app.add_typer(isa_app, name="isa")
|
|
244
298
|
|
|
299
|
+
# =============================================================================
|
|
300
|
+
# Roofline analysis (wafer roofline)
|
|
301
|
+
# =============================================================================
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@app.command("roofline")
|
|
305
|
+
def roofline_cmd(
|
|
306
|
+
gpu: str | None = typer.Option(
|
|
307
|
+
None, "--gpu", "-g", help="GPU name (e.g., H100, B200, MI300X, A100)"
|
|
308
|
+
),
|
|
309
|
+
bytes_moved: float | None = typer.Option(
|
|
310
|
+
None, "--bytes", "-b", help="Theoretical minimum bytes moved"
|
|
311
|
+
),
|
|
312
|
+
flops: float | None = typer.Option(None, "--flops", "-f", help="Theoretical minimum FLOPs"),
|
|
313
|
+
time_ms: float | None = typer.Option(
|
|
314
|
+
None, "--time-ms", "-t", help="Actual kernel time in milliseconds"
|
|
315
|
+
),
|
|
316
|
+
dtype: str = typer.Option(
|
|
317
|
+
"fp16", "--dtype", "-d", help="Data type for compute ceiling (fp16, fp32, bf16, fp8, int8)"
|
|
318
|
+
),
|
|
319
|
+
list_gpus: bool = typer.Option(False, "--list-gpus", help="List available GPU specs and exit"),
|
|
320
|
+
) -> None:
|
|
321
|
+
"""Analyze kernel performance against roofline model.
|
|
322
|
+
|
|
323
|
+
The roofline model shows the theoretical speed-of-light (SOL) for your kernel
|
|
324
|
+
based on whether it's memory-bound or compute-bound.
|
|
325
|
+
|
|
326
|
+
You need to provide:
|
|
327
|
+
- The GPU you ran on
|
|
328
|
+
- Theoretical minimum bytes moved (not actual - what the algorithm requires)
|
|
329
|
+
- Theoretical minimum FLOPs
|
|
330
|
+
- Actual measured kernel time
|
|
331
|
+
|
|
332
|
+
Example:
|
|
333
|
+
# Analyze a matmul kernel (4096x4096x4096, FP16)
|
|
334
|
+
# Theoretical: 2*M*N*K FLOPs = 137.4 TFLOP
|
|
335
|
+
# Theoretical bytes: (M*K + K*N + M*N) * 2 = 100.7 MB
|
|
336
|
+
wafer roofline --gpu H100 --bytes 100.7e6 --flops 137.4e12 --time-ms 85
|
|
337
|
+
|
|
338
|
+
# Analyze a memory-bound elementwise add (1B elements FP32)
|
|
339
|
+
# Reads 2 tensors, writes 1 = 12 GB total
|
|
340
|
+
# 1B adds = 1 GFLOP
|
|
341
|
+
wafer roofline --gpu H100 --bytes 12e9 --flops 1e9 --time-ms 4 --dtype fp32
|
|
342
|
+
|
|
343
|
+
# List available GPUs
|
|
344
|
+
wafer roofline --list-gpus
|
|
345
|
+
"""
|
|
346
|
+
from wafer_core.roofline import get_gpu_spec, roofline_analysis
|
|
347
|
+
from wafer_core.roofline import list_gpus as get_all_gpus
|
|
348
|
+
|
|
349
|
+
if list_gpus:
|
|
350
|
+
typer.echo("Available GPUs:")
|
|
351
|
+
for name in get_all_gpus():
|
|
352
|
+
spec = get_gpu_spec(name)
|
|
353
|
+
typer.echo(
|
|
354
|
+
f" {name}: {spec.peak_bandwidth_gbps:.0f} GB/s, {spec.peak_tflops_fp16:.0f} TFLOPS FP16"
|
|
355
|
+
)
|
|
356
|
+
return
|
|
357
|
+
|
|
358
|
+
# Validate required args for analysis
|
|
359
|
+
missing = []
|
|
360
|
+
if gpu is None:
|
|
361
|
+
missing.append("--gpu")
|
|
362
|
+
if bytes_moved is None:
|
|
363
|
+
missing.append("--bytes")
|
|
364
|
+
if flops is None:
|
|
365
|
+
missing.append("--flops")
|
|
366
|
+
if time_ms is None:
|
|
367
|
+
missing.append("--time-ms")
|
|
368
|
+
|
|
369
|
+
if missing:
|
|
370
|
+
typer.echo(f"Error: Missing required options: {', '.join(missing)}", err=True)
|
|
371
|
+
typer.echo("", err=True)
|
|
372
|
+
typer.echo("Run 'wafer roofline --help' for usage.", err=True)
|
|
373
|
+
raise typer.Exit(1)
|
|
374
|
+
|
|
375
|
+
try:
|
|
376
|
+
result = roofline_analysis(
|
|
377
|
+
gpu=gpu,
|
|
378
|
+
dtype=dtype,
|
|
379
|
+
bytes_moved=bytes_moved,
|
|
380
|
+
flops=flops,
|
|
381
|
+
time_ms=time_ms,
|
|
382
|
+
)
|
|
383
|
+
except ValueError as e:
|
|
384
|
+
typer.echo(f"Error: {e}", err=True)
|
|
385
|
+
raise typer.Exit(1) from None
|
|
386
|
+
|
|
387
|
+
typer.echo(result.format_report())
|
|
388
|
+
|
|
389
|
+
|
|
245
390
|
# =============================================================================
|
|
246
391
|
# Skill management (wafer skill ...)
|
|
247
392
|
# =============================================================================
|
|
@@ -256,21 +401,22 @@ def skill_install(
|
|
|
256
401
|
"all",
|
|
257
402
|
"--target",
|
|
258
403
|
"-t",
|
|
259
|
-
help="Target tool: claude, codex, or all",
|
|
404
|
+
help="Target tool: claude, codex, cursor, or all",
|
|
260
405
|
),
|
|
261
406
|
force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing skill"),
|
|
262
407
|
) -> None:
|
|
263
408
|
"""Install the wafer-guide skill for AI coding assistants.
|
|
264
409
|
|
|
265
410
|
Installs the bundled skill to make wafer commands discoverable by
|
|
266
|
-
Claude Code
|
|
411
|
+
Claude Code, OpenAI Codex CLI, and/or Cursor.
|
|
267
412
|
|
|
268
413
|
Skills follow the open agent skills specification (agentskills.io).
|
|
269
414
|
|
|
270
415
|
Examples:
|
|
271
|
-
wafer skill install # Install for
|
|
416
|
+
wafer skill install # Install for all tools
|
|
272
417
|
wafer skill install -t claude # Install for Claude Code only
|
|
273
418
|
wafer skill install -t codex # Install for Codex CLI only
|
|
419
|
+
wafer skill install -t cursor # Install for Cursor only
|
|
274
420
|
wafer skill install --force # Overwrite existing installation
|
|
275
421
|
"""
|
|
276
422
|
# Locate bundled skill
|
|
@@ -288,9 +434,13 @@ def skill_install(
|
|
|
288
434
|
))
|
|
289
435
|
if target in ("all", "codex"):
|
|
290
436
|
targets_to_install.append(("Codex CLI", Path.home() / ".codex" / "skills" / "wafer-guide"))
|
|
437
|
+
if target in ("all", "cursor"):
|
|
438
|
+
targets_to_install.append(("Cursor", Path.home() / ".cursor" / "skills" / "wafer-guide"))
|
|
291
439
|
|
|
292
440
|
if not targets_to_install:
|
|
293
|
-
typer.echo(
|
|
441
|
+
typer.echo(
|
|
442
|
+
f"Error: Unknown target '{target}'. Use: claude, codex, cursor, or all", err=True
|
|
443
|
+
)
|
|
294
444
|
raise typer.Exit(1)
|
|
295
445
|
|
|
296
446
|
for tool_name, dest_path in targets_to_install:
|
|
@@ -325,14 +475,15 @@ def skill_uninstall(
|
|
|
325
475
|
"all",
|
|
326
476
|
"--target",
|
|
327
477
|
"-t",
|
|
328
|
-
help="Target tool: claude, codex, or all",
|
|
478
|
+
help="Target tool: claude, codex, cursor, or all",
|
|
329
479
|
),
|
|
330
480
|
) -> None:
|
|
331
481
|
"""Uninstall the wafer-guide skill.
|
|
332
482
|
|
|
333
483
|
Examples:
|
|
334
|
-
wafer skill uninstall # Uninstall from
|
|
484
|
+
wafer skill uninstall # Uninstall from all tools
|
|
335
485
|
wafer skill uninstall -t claude # Uninstall from Claude Code only
|
|
486
|
+
wafer skill uninstall -t cursor # Uninstall from Cursor only
|
|
336
487
|
"""
|
|
337
488
|
targets_to_uninstall: list[tuple[str, Path]] = []
|
|
338
489
|
|
|
@@ -346,9 +497,16 @@ def skill_uninstall(
|
|
|
346
497
|
"Codex CLI",
|
|
347
498
|
Path.home() / ".codex" / "skills" / "wafer-guide",
|
|
348
499
|
))
|
|
500
|
+
if target in ("all", "cursor"):
|
|
501
|
+
targets_to_uninstall.append((
|
|
502
|
+
"Cursor",
|
|
503
|
+
Path.home() / ".cursor" / "skills" / "wafer-guide",
|
|
504
|
+
))
|
|
349
505
|
|
|
350
506
|
if not targets_to_uninstall:
|
|
351
|
-
typer.echo(
|
|
507
|
+
typer.echo(
|
|
508
|
+
f"Error: Unknown target '{target}'. Use: claude, codex, cursor, or all", err=True
|
|
509
|
+
)
|
|
352
510
|
raise typer.Exit(1)
|
|
353
511
|
|
|
354
512
|
for tool_name, dest_path in targets_to_uninstall:
|
|
@@ -383,6 +541,7 @@ def skill_status() -> None:
|
|
|
383
541
|
installations = [
|
|
384
542
|
("Claude Code", Path.home() / ".claude" / "skills" / "wafer-guide"),
|
|
385
543
|
("Codex CLI", Path.home() / ".codex" / "skills" / "wafer-guide"),
|
|
544
|
+
("Cursor", Path.home() / ".cursor" / "skills" / "wafer-guide"),
|
|
386
545
|
]
|
|
387
546
|
|
|
388
547
|
for tool_name, path in installations:
|
|
@@ -396,6 +555,122 @@ def skill_status() -> None:
|
|
|
396
555
|
typer.echo(f"{tool_name}: Not installed")
|
|
397
556
|
|
|
398
557
|
|
|
558
|
+
# =============================================================================
|
|
559
|
+
# Provider auth management (wafer auth ...)
|
|
560
|
+
# =============================================================================
|
|
561
|
+
|
|
562
|
+
provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
|
|
563
|
+
app.add_typer(provider_auth_app, name="auth")
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
@provider_auth_app.command("login")
|
|
567
|
+
def provider_auth_login(
|
|
568
|
+
provider: str = typer.Argument(
|
|
569
|
+
...,
|
|
570
|
+
help="Provider name: runpod, digitalocean, or modal",
|
|
571
|
+
),
|
|
572
|
+
api_key: str | None = typer.Option(
|
|
573
|
+
None,
|
|
574
|
+
"--api-key",
|
|
575
|
+
"-k",
|
|
576
|
+
help="API key (if not provided, reads from stdin)",
|
|
577
|
+
),
|
|
578
|
+
) -> None:
|
|
579
|
+
"""Save API key for a cloud GPU provider.
|
|
580
|
+
|
|
581
|
+
Stores the key in ~/.wafer/auth.json. Environment variables
|
|
582
|
+
(e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.
|
|
583
|
+
|
|
584
|
+
Examples:
|
|
585
|
+
wafer auth login runpod --api-key rp_xxx
|
|
586
|
+
wafer auth login digitalocean --api-key dop_v1_xxx
|
|
587
|
+
echo $API_KEY | wafer auth login runpod
|
|
588
|
+
"""
|
|
589
|
+
import sys
|
|
590
|
+
|
|
591
|
+
from wafer_core.auth import PROVIDERS, save_api_key
|
|
592
|
+
|
|
593
|
+
# Validate provider
|
|
594
|
+
if provider not in PROVIDERS:
|
|
595
|
+
typer.echo(f"Error: Unknown provider '{provider}'", err=True)
|
|
596
|
+
typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
|
|
597
|
+
raise typer.Exit(1)
|
|
598
|
+
|
|
599
|
+
# Get API key from option or stdin
|
|
600
|
+
if api_key is None:
|
|
601
|
+
if sys.stdin.isatty():
|
|
602
|
+
typer.echo(f"Enter API key for {PROVIDERS[provider]['display_name']}:")
|
|
603
|
+
api_key = typer.prompt("API key", hide_input=True)
|
|
604
|
+
else:
|
|
605
|
+
api_key = sys.stdin.read().strip()
|
|
606
|
+
|
|
607
|
+
if not api_key:
|
|
608
|
+
typer.echo("Error: No API key provided", err=True)
|
|
609
|
+
raise typer.Exit(1)
|
|
610
|
+
|
|
611
|
+
# Save the key
|
|
612
|
+
save_api_key(provider, api_key)
|
|
613
|
+
typer.echo(f"API key saved for {PROVIDERS[provider]['display_name']}")
|
|
614
|
+
typer.echo("Stored in: ~/.wafer/auth.json")
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
@provider_auth_app.command("logout")
|
|
618
|
+
def provider_auth_logout(
|
|
619
|
+
provider: str = typer.Argument(
|
|
620
|
+
...,
|
|
621
|
+
help="Provider name: runpod, digitalocean, or modal",
|
|
622
|
+
),
|
|
623
|
+
) -> None:
|
|
624
|
+
"""Remove stored API key for a cloud GPU provider.
|
|
625
|
+
|
|
626
|
+
Examples:
|
|
627
|
+
wafer auth logout runpod
|
|
628
|
+
wafer auth logout digitalocean
|
|
629
|
+
"""
|
|
630
|
+
from wafer_core.auth import PROVIDERS, remove_api_key
|
|
631
|
+
|
|
632
|
+
# Validate provider
|
|
633
|
+
if provider not in PROVIDERS:
|
|
634
|
+
typer.echo(f"Error: Unknown provider '{provider}'", err=True)
|
|
635
|
+
typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
|
|
636
|
+
raise typer.Exit(1)
|
|
637
|
+
|
|
638
|
+
if remove_api_key(provider):
|
|
639
|
+
typer.echo(f"API key removed for {PROVIDERS[provider]['display_name']}")
|
|
640
|
+
else:
|
|
641
|
+
typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
@provider_auth_app.command("status")
|
|
645
|
+
def provider_auth_status() -> None:
|
|
646
|
+
"""Show authentication status for all cloud GPU providers.
|
|
647
|
+
|
|
648
|
+
Displays which providers have API keys configured and where
|
|
649
|
+
the keys are coming from (environment variable or auth.json).
|
|
650
|
+
|
|
651
|
+
Example:
|
|
652
|
+
wafer auth status
|
|
653
|
+
"""
|
|
654
|
+
from wafer_core.auth import get_all_auth_status
|
|
655
|
+
|
|
656
|
+
statuses = get_all_auth_status()
|
|
657
|
+
|
|
658
|
+
typer.echo("Cloud GPU Provider Authentication Status")
|
|
659
|
+
typer.echo("=" * 45)
|
|
660
|
+
|
|
661
|
+
for status in statuses:
|
|
662
|
+
if status.is_authenticated:
|
|
663
|
+
source_str = f"({status.source})" if status.source else ""
|
|
664
|
+
typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
|
|
665
|
+
else:
|
|
666
|
+
typer.echo(f" {status.display_name}: ✗ Not configured")
|
|
667
|
+
typer.echo(f" Run: wafer auth login {status.provider}")
|
|
668
|
+
typer.echo(f" Or set: {status.key_url}")
|
|
669
|
+
|
|
670
|
+
typer.echo("")
|
|
671
|
+
typer.echo("Note: Environment variables take precedence over stored keys.")
|
|
672
|
+
|
|
673
|
+
|
|
399
674
|
@app.command(hidden=True)
|
|
400
675
|
def run(
|
|
401
676
|
command: str = typer.Argument(..., help="Command to run in Docker container"),
|
|
@@ -975,6 +1250,11 @@ def agent( # noqa: PLR0913
|
|
|
975
1250
|
"--list-sessions",
|
|
976
1251
|
help="List recent sessions and exit",
|
|
977
1252
|
),
|
|
1253
|
+
get_session: str | None = typer.Option(
|
|
1254
|
+
None,
|
|
1255
|
+
"--get-session",
|
|
1256
|
+
help="Get session by ID and print messages (use with --json)",
|
|
1257
|
+
),
|
|
978
1258
|
tools: str | None = typer.Option(
|
|
979
1259
|
None,
|
|
980
1260
|
"--tools",
|
|
@@ -1021,47 +1301,7 @@ def agent( # noqa: PLR0913
|
|
|
1021
1301
|
None,
|
|
1022
1302
|
"--corpus",
|
|
1023
1303
|
"-c",
|
|
1024
|
-
help="Documentation corpus to use (cuda, cutlass, hip). Must be downloaded first.",
|
|
1025
|
-
),
|
|
1026
|
-
# Legacy kernel optimization options (hidden, for backwards compat)
|
|
1027
|
-
problem: Path | None = typer.Option(
|
|
1028
|
-
None,
|
|
1029
|
-
"--problem",
|
|
1030
|
-
hidden=True,
|
|
1031
|
-
help="[Legacy] Path to problem YAML config file",
|
|
1032
|
-
),
|
|
1033
|
-
reference: Path | None = typer.Option(
|
|
1034
|
-
None,
|
|
1035
|
-
"--reference",
|
|
1036
|
-
"--ref",
|
|
1037
|
-
hidden=True,
|
|
1038
|
-
help="[Legacy] Path to reference kernel file",
|
|
1039
|
-
),
|
|
1040
|
-
description: str | None = typer.Option(
|
|
1041
|
-
None,
|
|
1042
|
-
"--description",
|
|
1043
|
-
"--desc",
|
|
1044
|
-
hidden=True,
|
|
1045
|
-
help="[Legacy] Problem description",
|
|
1046
|
-
),
|
|
1047
|
-
test: list[str] | None = typer.Option(
|
|
1048
|
-
None,
|
|
1049
|
-
"--test",
|
|
1050
|
-
hidden=True,
|
|
1051
|
-
help="[Legacy] Test case",
|
|
1052
|
-
),
|
|
1053
|
-
benchmark: list[str] | None = typer.Option(
|
|
1054
|
-
None,
|
|
1055
|
-
"--benchmark",
|
|
1056
|
-
"-b",
|
|
1057
|
-
hidden=True,
|
|
1058
|
-
help="[Legacy] Benchmark case",
|
|
1059
|
-
),
|
|
1060
|
-
speedup_target: float | None = typer.Option(
|
|
1061
|
-
None,
|
|
1062
|
-
"--speedup",
|
|
1063
|
-
hidden=True,
|
|
1064
|
-
help="[Legacy] Speedup target",
|
|
1304
|
+
help="Documentation corpus to use (cuda, cutlass, hip, amd). Must be downloaded first.",
|
|
1065
1305
|
),
|
|
1066
1306
|
) -> None:
|
|
1067
1307
|
"""AI assistant for GPU kernel development.
|
|
@@ -1148,20 +1388,15 @@ def agent( # noqa: PLR0913
|
|
|
1148
1388
|
prompt=actual_prompt,
|
|
1149
1389
|
interactive=use_tui,
|
|
1150
1390
|
single_turn=single_turn,
|
|
1151
|
-
problem=str(problem) if problem else None,
|
|
1152
|
-
reference=str(reference) if reference else None,
|
|
1153
|
-
description=description,
|
|
1154
|
-
tests=list(test) if test else None,
|
|
1155
|
-
benchmarks=list(benchmark) if benchmark else None,
|
|
1156
1391
|
model=model,
|
|
1157
|
-
max_turns=max_turns,
|
|
1158
|
-
speedup_target=speedup_target,
|
|
1159
1392
|
resume=resume,
|
|
1160
1393
|
from_turn=from_turn,
|
|
1161
1394
|
list_sessions=list_sessions,
|
|
1395
|
+
get_session=get_session,
|
|
1162
1396
|
tools=tools.split(",") if tools else None,
|
|
1163
1397
|
allow_spawn=allow_spawn,
|
|
1164
1398
|
max_tool_fails=max_tool_fails,
|
|
1399
|
+
max_turns=max_turns,
|
|
1165
1400
|
json_output=json_output,
|
|
1166
1401
|
template=template,
|
|
1167
1402
|
template_args=parsed_template_args,
|
|
@@ -1171,7 +1406,7 @@ def agent( # noqa: PLR0913
|
|
|
1171
1406
|
|
|
1172
1407
|
# =============================================================================
|
|
1173
1408
|
# Evaluate command
|
|
1174
|
-
# Hidden aliases for
|
|
1409
|
+
# Hidden aliases for agent command
|
|
1175
1410
|
def _make_agent_alias(name: str, doc: str) -> None:
|
|
1176
1411
|
"""Create a hidden alias that delegates to agent()."""
|
|
1177
1412
|
|
|
@@ -1186,6 +1421,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
|
|
|
1186
1421
|
resume: str | None = typer.Option(None, "--resume", "-r"),
|
|
1187
1422
|
from_turn: int | None = typer.Option(None, "--from-turn"),
|
|
1188
1423
|
list_sessions: bool = typer.Option(False, "--list-sessions"),
|
|
1424
|
+
get_session: str | None = typer.Option(None, "--get-session"),
|
|
1189
1425
|
tools: str | None = typer.Option(None, "--tools"),
|
|
1190
1426
|
allow_spawn: bool = typer.Option(False, "--allow-spawn"),
|
|
1191
1427
|
max_tool_fails: int | None = typer.Option(None, "--max-tool-fails"),
|
|
@@ -1195,12 +1431,6 @@ def _make_agent_alias(name: str, doc: str) -> None:
|
|
|
1195
1431
|
template: str | None = typer.Option(None, "--template", "-t"),
|
|
1196
1432
|
template_args: list[str] | None = typer.Option(None, "--args"),
|
|
1197
1433
|
corpus: str | None = typer.Option(None, "--corpus"),
|
|
1198
|
-
problem: Path | None = typer.Option(None, "--problem", hidden=True),
|
|
1199
|
-
reference: Path | None = typer.Option(None, "--reference", hidden=True),
|
|
1200
|
-
description: str | None = typer.Option(None, "--description", hidden=True),
|
|
1201
|
-
test: list[Path] | None = typer.Option(None, "--test", hidden=True),
|
|
1202
|
-
benchmark: list[Path] | None = typer.Option(None, "--benchmark", hidden=True),
|
|
1203
|
-
speedup_target: float | None = typer.Option(None, "--speedup-target", hidden=True),
|
|
1204
1434
|
) -> None:
|
|
1205
1435
|
agent(
|
|
1206
1436
|
prompt=prompt,
|
|
@@ -1210,6 +1440,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
|
|
|
1210
1440
|
resume=resume,
|
|
1211
1441
|
from_turn=from_turn,
|
|
1212
1442
|
list_sessions=list_sessions,
|
|
1443
|
+
get_session=get_session,
|
|
1213
1444
|
tools=tools,
|
|
1214
1445
|
allow_spawn=allow_spawn,
|
|
1215
1446
|
max_tool_fails=max_tool_fails,
|
|
@@ -1219,12 +1450,6 @@ def _make_agent_alias(name: str, doc: str) -> None:
|
|
|
1219
1450
|
template=template,
|
|
1220
1451
|
template_args=template_args,
|
|
1221
1452
|
corpus=corpus,
|
|
1222
|
-
problem=problem,
|
|
1223
|
-
reference=reference,
|
|
1224
|
-
description=description,
|
|
1225
|
-
test=test,
|
|
1226
|
-
benchmark=benchmark,
|
|
1227
|
-
speedup_target=speedup_target,
|
|
1228
1453
|
)
|
|
1229
1454
|
|
|
1230
1455
|
alias_cmd.__doc__ = doc
|
|
@@ -1289,86 +1514,37 @@ def evaluate( # noqa: PLR0913
|
|
|
1289
1514
|
--benchmark --defensive
|
|
1290
1515
|
|
|
1291
1516
|
Subcommands:
|
|
1292
|
-
|
|
1517
|
+
gpumode Use GPUMode format (functional) - RECOMMENDED
|
|
1293
1518
|
kernelbench Use KernelBench format (ModelNew class)
|
|
1519
|
+
make-template Generate template files for this format (deprecated)
|
|
1294
1520
|
"""
|
|
1295
1521
|
# If a subcommand is being invoked, skip the main evaluation logic
|
|
1296
1522
|
if ctx.invoked_subcommand is not None:
|
|
1297
1523
|
return
|
|
1298
1524
|
|
|
1299
|
-
#
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
if test_cases is None:
|
|
1306
|
-
missing_args.append("--test-cases")
|
|
1307
|
-
|
|
1308
|
-
if missing_args:
|
|
1309
|
-
typer.echo("Error: Missing required arguments", err=True)
|
|
1310
|
-
typer.echo(f" Required: {', '.join(missing_args)}", err=True)
|
|
1311
|
-
typer.echo("", err=True)
|
|
1312
|
-
typer.echo(
|
|
1313
|
-
"Usage: wafer evaluate --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
|
|
1314
|
-
err=True,
|
|
1315
|
-
)
|
|
1316
|
-
typer.echo("", err=True)
|
|
1317
|
-
typer.echo("Run 'wafer evaluate --help' for full options.", err=True)
|
|
1318
|
-
typer.echo("Run 'wafer evaluate make-template DIR' to generate starter files.", err=True)
|
|
1319
|
-
raise typer.Exit(1)
|
|
1320
|
-
|
|
1321
|
-
from .evaluate import EvaluateArgs, run_evaluate
|
|
1322
|
-
|
|
1323
|
-
args = EvaluateArgs(
|
|
1324
|
-
implementation=implementation,
|
|
1325
|
-
reference=reference,
|
|
1326
|
-
test_cases=test_cases,
|
|
1327
|
-
target_name=target or "",
|
|
1328
|
-
benchmark=benchmark,
|
|
1329
|
-
profile=profile,
|
|
1330
|
-
defensive=defensive,
|
|
1331
|
-
sync_artifacts=sync_artifacts,
|
|
1332
|
-
gpu_id=gpu_id,
|
|
1525
|
+
# Bare 'wafer evaluate' is no longer supported - must use subcommand
|
|
1526
|
+
typer.echo("Error: 'wafer evaluate' requires a subcommand.", err=True)
|
|
1527
|
+
typer.echo("", err=True)
|
|
1528
|
+
typer.echo("Available subcommands:", err=True)
|
|
1529
|
+
typer.echo(
|
|
1530
|
+
" gpumode Evaluate GPUMode format (custom_kernel/ref_kernel functions)", err=True
|
|
1333
1531
|
)
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
typer.echo(f"Error: {e}", err=True)
|
|
1351
|
-
raise typer.Exit(1) from None
|
|
1352
|
-
|
|
1353
|
-
# Print results
|
|
1354
|
-
if result.success:
|
|
1355
|
-
typer.echo("")
|
|
1356
|
-
typer.echo("=" * 60)
|
|
1357
|
-
status = "PASS" if result.all_correct else "FAIL"
|
|
1358
|
-
typer.echo(f"Result: {status}")
|
|
1359
|
-
score_pct = f"{result.correctness_score:.1%}"
|
|
1360
|
-
typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
|
|
1361
|
-
if result.geomean_speedup > 0:
|
|
1362
|
-
typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
|
|
1363
|
-
if result.artifact_path:
|
|
1364
|
-
typer.echo(f"Artifacts: {result.artifact_path}")
|
|
1365
|
-
typer.echo("=" * 60)
|
|
1366
|
-
|
|
1367
|
-
if not result.all_correct:
|
|
1368
|
-
raise typer.Exit(1)
|
|
1369
|
-
else:
|
|
1370
|
-
typer.echo(f"Error: {result.error_message}", err=True)
|
|
1371
|
-
raise typer.Exit(1)
|
|
1532
|
+
typer.echo(" kernelbench Evaluate KernelBench format (ModelNew class)", err=True)
|
|
1533
|
+
typer.echo("", err=True)
|
|
1534
|
+
typer.echo("Examples:", err=True)
|
|
1535
|
+
typer.echo(
|
|
1536
|
+
" wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json",
|
|
1537
|
+
err=True,
|
|
1538
|
+
)
|
|
1539
|
+
typer.echo(
|
|
1540
|
+
" wafer evaluate kernelbench --impl impl.py --reference ref.py --benchmark", err=True
|
|
1541
|
+
)
|
|
1542
|
+
typer.echo("", err=True)
|
|
1543
|
+
typer.echo(
|
|
1544
|
+
"Run 'wafer evaluate gpumode --help' or 'wafer evaluate kernelbench --help' for options.",
|
|
1545
|
+
err=True,
|
|
1546
|
+
)
|
|
1547
|
+
raise typer.Exit(1)
|
|
1372
1548
|
|
|
1373
1549
|
|
|
1374
1550
|
TEMPLATE_KERNEL = '''\
|
|
@@ -1503,12 +1679,63 @@ def evaluate_make_template(
|
|
|
1503
1679
|
# KernelBench format evaluation
|
|
1504
1680
|
# =============================================================================
|
|
1505
1681
|
|
|
1506
|
-
|
|
1507
|
-
|
|
1682
|
+
|
|
1683
|
+
def _get_kernelbench_root() -> Path | None:
|
|
1684
|
+
"""Get KernelBench problems root, preferring downloaded location."""
|
|
1685
|
+
# First check downloaded location
|
|
1686
|
+
downloaded = get_problems_path("kernelbench")
|
|
1687
|
+
if downloaded is not None:
|
|
1688
|
+
kb_root = downloaded / "KernelBench"
|
|
1689
|
+
if kb_root.exists():
|
|
1690
|
+
return kb_root
|
|
1691
|
+
return downloaded
|
|
1692
|
+
|
|
1693
|
+
# Fall back to legacy location (for development)
|
|
1694
|
+
legacy = Path(__file__).parent.parent.parent.parent / "research" / "KernelBench" / "KernelBench"
|
|
1695
|
+
if legacy.exists():
|
|
1696
|
+
return legacy
|
|
1697
|
+
|
|
1698
|
+
return None
|
|
1699
|
+
|
|
1700
|
+
|
|
1701
|
+
@kernelbench_app.command("download")
|
|
1702
|
+
def kernelbench_download(
|
|
1703
|
+
force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
|
|
1704
|
+
) -> None:
|
|
1705
|
+
"""Download KernelBench problems from GitHub.
|
|
1706
|
+
|
|
1707
|
+
Downloads the problem set to ~/.cache/wafer/problems/kernelbench/
|
|
1708
|
+
|
|
1709
|
+
Examples:
|
|
1710
|
+
wafer evaluate kernelbench download
|
|
1711
|
+
wafer evaluate kernelbench download --force # Re-download
|
|
1712
|
+
"""
|
|
1713
|
+
try:
|
|
1714
|
+
path = download_problems("kernelbench", force=force, verbose=True)
|
|
1715
|
+
typer.echo("")
|
|
1716
|
+
typer.echo(f"Problems available at: {path}")
|
|
1717
|
+
typer.echo("Run 'wafer evaluate kernelbench list-problems' to see available problems.")
|
|
1718
|
+
except Exception as e:
|
|
1719
|
+
typer.echo(f"Error downloading problems: {e}", err=True)
|
|
1720
|
+
raise typer.Exit(1) from None
|
|
1721
|
+
|
|
1722
|
+
|
|
1723
|
+
@kernelbench_app.command("list-problems")
|
|
1724
|
+
def kernelbench_list_problems() -> None:
|
|
1725
|
+
"""List available KernelBench problems.
|
|
1726
|
+
|
|
1727
|
+
Examples:
|
|
1728
|
+
wafer evaluate kernelbench list-problems
|
|
1729
|
+
"""
|
|
1730
|
+
try:
|
|
1731
|
+
list_problems_fn("kernelbench", verbose=True)
|
|
1732
|
+
except ValueError as e:
|
|
1733
|
+
typer.echo(str(e), err=True)
|
|
1734
|
+
raise typer.Exit(1) from None
|
|
1508
1735
|
|
|
1509
1736
|
|
|
1510
1737
|
@kernelbench_app.callback(invoke_without_command=True)
|
|
1511
|
-
def kernelbench_evaluate( # noqa: PLR0913
|
|
1738
|
+
def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
1512
1739
|
ctx: typer.Context,
|
|
1513
1740
|
implementation: Path | None = typer.Option(
|
|
1514
1741
|
None,
|
|
@@ -1528,17 +1755,38 @@ def kernelbench_evaluate( # noqa: PLR0913
|
|
|
1528
1755
|
help="GPU target name. See 'wafer config targets list' for available targets.",
|
|
1529
1756
|
autocompletion=complete_target_name,
|
|
1530
1757
|
),
|
|
1758
|
+
pool: str | None = typer.Option(
|
|
1759
|
+
None,
|
|
1760
|
+
"--pool",
|
|
1761
|
+
"-p",
|
|
1762
|
+
help="Target pool name. Acquires first available target from the pool. "
|
|
1763
|
+
"Define pools in ~/.wafer/config.toml under [pools.<name>].",
|
|
1764
|
+
),
|
|
1531
1765
|
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
1532
1766
|
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
1533
|
-
inputs: Path | None = typer.Option(
|
|
1767
|
+
inputs: Path | None = typer.Option(
|
|
1768
|
+
None, "--inputs", help="Custom inputs file to override get_inputs()"
|
|
1769
|
+
),
|
|
1534
1770
|
seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
|
|
1535
1771
|
defensive: bool = typer.Option(
|
|
1536
1772
|
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
1537
1773
|
),
|
|
1774
|
+
backend: str | None = typer.Option(
|
|
1775
|
+
None,
|
|
1776
|
+
"--backend",
|
|
1777
|
+
help="Kernel backend for static validation (hip, cuda, triton, cute, tilelang, thunderkittens). "
|
|
1778
|
+
"When specified, validates that the implementation uses the correct backend primitives.",
|
|
1779
|
+
),
|
|
1538
1780
|
sync_artifacts: bool = typer.Option(
|
|
1539
1781
|
True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
|
|
1540
1782
|
),
|
|
1541
1783
|
gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
|
|
1784
|
+
json_output: bool = typer.Option(
|
|
1785
|
+
False, "--json", help="Output as single JSON object (machine-readable)"
|
|
1786
|
+
),
|
|
1787
|
+
jsonl_output: bool = typer.Option(
|
|
1788
|
+
False, "--jsonl", help="Output as streaming JSON Lines (one object per event)"
|
|
1789
|
+
),
|
|
1542
1790
|
) -> None:
|
|
1543
1791
|
"""Run kernel evaluation in KernelBench format (ModelNew class).
|
|
1544
1792
|
|
|
@@ -1588,48 +1836,106 @@ def kernelbench_evaluate( # noqa: PLR0913
|
|
|
1588
1836
|
)
|
|
1589
1837
|
raise typer.Exit(1)
|
|
1590
1838
|
|
|
1839
|
+
# Validate --target and --pool are mutually exclusive
|
|
1840
|
+
if target and pool:
|
|
1841
|
+
typer.echo("Error: Cannot specify both --target and --pool", err=True)
|
|
1842
|
+
raise typer.Exit(1)
|
|
1843
|
+
|
|
1591
1844
|
from .evaluate import KernelBenchEvaluateArgs, run_evaluate_kernelbench
|
|
1845
|
+
from .output import OutputCollector, format_evaluate_result, get_output_format
|
|
1846
|
+
|
|
1847
|
+
output_format = get_output_format(json_output, jsonl_output)
|
|
1848
|
+
collector = OutputCollector(format=output_format)
|
|
1849
|
+
|
|
1850
|
+
# If pool specified, acquire a target from the pool
|
|
1851
|
+
resolved_target = target or ""
|
|
1852
|
+
pool_lock_context = None
|
|
1853
|
+
|
|
1854
|
+
if pool:
|
|
1855
|
+
from .target_lock import acquire_from_pool
|
|
1856
|
+
from .targets import filter_pool_by_auth, get_pool
|
|
1857
|
+
|
|
1858
|
+
try:
|
|
1859
|
+
pool_targets = get_pool(pool)
|
|
1860
|
+
except FileNotFoundError as e:
|
|
1861
|
+
collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
|
|
1862
|
+
collector.finalize()
|
|
1863
|
+
raise typer.Exit(1) from None
|
|
1864
|
+
|
|
1865
|
+
# Filter to only targets with valid auth
|
|
1866
|
+
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
1867
|
+
if skipped:
|
|
1868
|
+
collector.emit("pool_auth_skip", targets=skipped)
|
|
1869
|
+
|
|
1870
|
+
if not usable_targets:
|
|
1871
|
+
collector.set_error("pool", "NoUsableTargets", pool=pool)
|
|
1872
|
+
collector.finalize()
|
|
1873
|
+
raise typer.Exit(1) from None
|
|
1874
|
+
|
|
1875
|
+
collector.emit("pool_acquire", pool=pool, count=len(usable_targets))
|
|
1876
|
+
pool_lock_context = acquire_from_pool(usable_targets)
|
|
1877
|
+
acquired_target = pool_lock_context.__enter__()
|
|
1878
|
+
|
|
1879
|
+
if acquired_target is None:
|
|
1880
|
+
# Exit context manager before raising to avoid resource leak
|
|
1881
|
+
pool_lock_context.__exit__(None, None, None)
|
|
1882
|
+
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
|
|
1883
|
+
collector.finalize()
|
|
1884
|
+
raise typer.Exit(1)
|
|
1885
|
+
|
|
1886
|
+
collector.emit("pool_acquired", target=acquired_target)
|
|
1887
|
+
resolved_target = acquired_target
|
|
1888
|
+
|
|
1889
|
+
collector.target = resolved_target
|
|
1592
1890
|
|
|
1593
1891
|
args = KernelBenchEvaluateArgs(
|
|
1594
1892
|
implementation=implementation,
|
|
1595
1893
|
reference=reference,
|
|
1596
|
-
target_name=
|
|
1894
|
+
target_name=resolved_target,
|
|
1597
1895
|
benchmark=benchmark,
|
|
1598
1896
|
profile=profile,
|
|
1599
1897
|
inputs=inputs,
|
|
1600
1898
|
seed=seed,
|
|
1601
1899
|
defensive=defensive,
|
|
1900
|
+
backend=backend,
|
|
1602
1901
|
sync_artifacts=sync_artifacts,
|
|
1603
1902
|
gpu_id=gpu_id,
|
|
1604
1903
|
)
|
|
1605
1904
|
|
|
1905
|
+
collector.emit("started", target=resolved_target)
|
|
1906
|
+
|
|
1606
1907
|
try:
|
|
1607
1908
|
import trio_asyncio
|
|
1608
1909
|
|
|
1910
|
+
collector.emit("evaluation", status="running")
|
|
1609
1911
|
result = trio_asyncio.run(run_evaluate_kernelbench, args)
|
|
1610
1912
|
except KeyboardInterrupt:
|
|
1611
|
-
|
|
1913
|
+
collector.set_error("evaluation", "Interrupted", message="Interrupted by user")
|
|
1914
|
+
collector.finalize()
|
|
1612
1915
|
raise typer.Exit(130) from None
|
|
1613
1916
|
except Exception as e:
|
|
1614
|
-
|
|
1917
|
+
collector.set_error("evaluation", "Exception", message=str(e))
|
|
1918
|
+
collector.finalize()
|
|
1615
1919
|
raise typer.Exit(1) from None
|
|
1920
|
+
finally:
|
|
1921
|
+
# Release pool lock if we acquired one
|
|
1922
|
+
if pool_lock_context is not None:
|
|
1923
|
+
pool_lock_context.__exit__(None, None, None)
|
|
1616
1924
|
|
|
1617
|
-
#
|
|
1925
|
+
# Build structured output
|
|
1926
|
+
eval_output = format_evaluate_result(result, target=resolved_target)
|
|
1927
|
+
collector._result = eval_output
|
|
1928
|
+
|
|
1929
|
+
# Print results based on output format
|
|
1618
1930
|
if result.success:
|
|
1619
|
-
|
|
1620
|
-
|
|
1621
|
-
status = "PASS" if result.all_correct else "FAIL"
|
|
1622
|
-
typer.echo(f"Result: {status}")
|
|
1623
|
-
score_pct = f"{result.correctness_score:.1%}"
|
|
1624
|
-
typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
|
|
1625
|
-
if result.geomean_speedup > 0:
|
|
1626
|
-
typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
|
|
1627
|
-
typer.echo("=" * 60)
|
|
1931
|
+
collector.output_text_result(result)
|
|
1932
|
+
collector.finalize()
|
|
1628
1933
|
|
|
1629
1934
|
if not result.all_correct:
|
|
1630
1935
|
raise typer.Exit(1)
|
|
1631
1936
|
else:
|
|
1632
|
-
|
|
1937
|
+
collector.output_text_error(result.error_message or "Unknown error")
|
|
1938
|
+
collector.finalize()
|
|
1633
1939
|
raise typer.Exit(1)
|
|
1634
1940
|
|
|
1635
1941
|
|
|
@@ -1659,7 +1965,14 @@ def kernelbench_make_template(
|
|
|
1659
1965
|
# Overwrite existing
|
|
1660
1966
|
wafer evaluate kernelbench make-template level1/1 --force
|
|
1661
1967
|
"""
|
|
1662
|
-
#
|
|
1968
|
+
# Get problems root (downloaded or legacy)
|
|
1969
|
+
kb_root = _get_kernelbench_root()
|
|
1970
|
+
if kb_root is None:
|
|
1971
|
+
typer.echo("Error: KernelBench problems not found.", err=True)
|
|
1972
|
+
typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
|
|
1973
|
+
raise typer.Exit(1)
|
|
1974
|
+
|
|
1975
|
+
# Parse problem ID
|
|
1663
1976
|
parts = problem.split("/")
|
|
1664
1977
|
if len(parts) != 2:
|
|
1665
1978
|
typer.echo(f"Error: Invalid problem ID '{problem}'. Expected format: level1/1", err=True)
|
|
@@ -1670,10 +1983,10 @@ def kernelbench_make_template(
|
|
|
1670
1983
|
level_str = f"level{level_str}"
|
|
1671
1984
|
|
|
1672
1985
|
# Find the problem file
|
|
1673
|
-
problem_dir =
|
|
1986
|
+
problem_dir = kb_root / level_str
|
|
1674
1987
|
if not problem_dir.exists():
|
|
1675
1988
|
typer.echo(f"Error: KernelBench level directory not found: {problem_dir}", err=True)
|
|
1676
|
-
typer.echo(
|
|
1989
|
+
typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
|
|
1677
1990
|
raise typer.Exit(1)
|
|
1678
1991
|
|
|
1679
1992
|
# Find matching problem file
|
|
@@ -1740,6 +2053,306 @@ def kernelbench_make_template(
|
|
|
1740
2053
|
typer.echo(f" wafer evaluate kernelbench --impl my_kernel.py --reference {output}")
|
|
1741
2054
|
|
|
1742
2055
|
|
|
2056
|
+
# =============================================================================
|
|
2057
|
+
# GPUMode format evaluation
|
|
2058
|
+
# =============================================================================
|
|
2059
|
+
|
|
2060
|
+
|
|
2061
|
+
@gpumode_app.command("download")
|
|
2062
|
+
def gpumode_download(
|
|
2063
|
+
force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
|
|
2064
|
+
) -> None:
|
|
2065
|
+
"""Download GPUMode reference kernels from GitHub.
|
|
2066
|
+
|
|
2067
|
+
Downloads the problem set to ~/.cache/wafer/problems/gpumode/
|
|
2068
|
+
|
|
2069
|
+
Examples:
|
|
2070
|
+
wafer evaluate gpumode download
|
|
2071
|
+
wafer evaluate gpumode download --force # Re-download
|
|
2072
|
+
"""
|
|
2073
|
+
try:
|
|
2074
|
+
path = download_problems("gpumode", force=force, verbose=True)
|
|
2075
|
+
typer.echo("")
|
|
2076
|
+
typer.echo(f"Problems available at: {path}")
|
|
2077
|
+
typer.echo("Run 'wafer evaluate gpumode list-problems' to see available problems.")
|
|
2078
|
+
except Exception as e:
|
|
2079
|
+
typer.echo(f"Error downloading problems: {e}", err=True)
|
|
2080
|
+
raise typer.Exit(1) from None
|
|
2081
|
+
|
|
2082
|
+
|
|
2083
|
+
@gpumode_app.command("list-problems")
|
|
2084
|
+
def gpumode_list_problems() -> None:
|
|
2085
|
+
"""List available GPUMode problems.
|
|
2086
|
+
|
|
2087
|
+
Examples:
|
|
2088
|
+
wafer evaluate gpumode list-problems
|
|
2089
|
+
"""
|
|
2090
|
+
try:
|
|
2091
|
+
list_problems_fn("gpumode", verbose=True)
|
|
2092
|
+
except ValueError as e:
|
|
2093
|
+
typer.echo(str(e), err=True)
|
|
2094
|
+
raise typer.Exit(1) from None
|
|
2095
|
+
|
|
2096
|
+
|
|
2097
|
+
@gpumode_app.command("make-template")
|
|
2098
|
+
def gpumode_make_template(
|
|
2099
|
+
problem: str = typer.Option(
|
|
2100
|
+
...,
|
|
2101
|
+
"--problem",
|
|
2102
|
+
"-p",
|
|
2103
|
+
help="Problem ID (e.g., 'pmpp/vectoradd_py' or 'amd/fp8-mm')",
|
|
2104
|
+
),
|
|
2105
|
+
output: Path = typer.Option(
|
|
2106
|
+
None, "--output", "-o", help="Output directory (default: ./<problem_name>/)"
|
|
2107
|
+
),
|
|
2108
|
+
force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"),
|
|
2109
|
+
) -> None:
|
|
2110
|
+
"""Extract a GPUMode problem as template files.
|
|
2111
|
+
|
|
2112
|
+
Creates a directory with reference.py, task.yml, and other problem files.
|
|
2113
|
+
You then create kernel.py with your custom_kernel implementation.
|
|
2114
|
+
|
|
2115
|
+
Examples:
|
|
2116
|
+
# Extract pmpp vectoradd problem
|
|
2117
|
+
wafer evaluate gpumode make-template --problem pmpp/vectoradd_py
|
|
2118
|
+
|
|
2119
|
+
# Extract to specific directory
|
|
2120
|
+
wafer evaluate gpumode make-template --problem pmpp/vectoradd_py --output ./my-kernel/
|
|
2121
|
+
"""
|
|
2122
|
+
import shutil
|
|
2123
|
+
|
|
2124
|
+
# Get problem path
|
|
2125
|
+
problem_path = get_problem_path("gpumode", problem)
|
|
2126
|
+
if problem_path is None:
|
|
2127
|
+
# Check if problems are downloaded
|
|
2128
|
+
if get_problems_path("gpumode") is None:
|
|
2129
|
+
typer.echo("Error: GPUMode problems not downloaded.", err=True)
|
|
2130
|
+
typer.echo("Run 'wafer evaluate gpumode download' first.", err=True)
|
|
2131
|
+
else:
|
|
2132
|
+
typer.echo(f"Error: Problem '{problem}' not found.", err=True)
|
|
2133
|
+
typer.echo(
|
|
2134
|
+
"Run 'wafer evaluate gpumode list-problems' to see available problems.", err=True
|
|
2135
|
+
)
|
|
2136
|
+
raise typer.Exit(1)
|
|
2137
|
+
|
|
2138
|
+
# Determine output path
|
|
2139
|
+
if output is None:
|
|
2140
|
+
output = Path.cwd() / problem.replace("/", "_")
|
|
2141
|
+
|
|
2142
|
+
output = output.resolve()
|
|
2143
|
+
|
|
2144
|
+
# Check if exists
|
|
2145
|
+
if output.exists() and not force:
|
|
2146
|
+
typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
|
|
2147
|
+
raise typer.Exit(1)
|
|
2148
|
+
|
|
2149
|
+
# Copy the problem directory
|
|
2150
|
+
if output.exists():
|
|
2151
|
+
shutil.rmtree(output)
|
|
2152
|
+
shutil.copytree(problem_path, output)
|
|
2153
|
+
|
|
2154
|
+
typer.echo(f"Created {output}/")
|
|
2155
|
+
typer.echo("")
|
|
2156
|
+
typer.echo("Contents:")
|
|
2157
|
+
for f in sorted(output.iterdir()):
|
|
2158
|
+
if not f.name.startswith("."):
|
|
2159
|
+
typer.echo(f" {f.name}")
|
|
2160
|
+
typer.echo("")
|
|
2161
|
+
typer.echo("Next steps:")
|
|
2162
|
+
typer.echo(" 1. Read reference.py to understand the kernel interface")
|
|
2163
|
+
typer.echo(" 2. Create kernel.py with your custom_kernel implementation:")
|
|
2164
|
+
typer.echo("")
|
|
2165
|
+
typer.echo(" def custom_kernel(data):")
|
|
2166
|
+
typer.echo(" # Your optimized implementation")
|
|
2167
|
+
typer.echo(" ...")
|
|
2168
|
+
typer.echo("")
|
|
2169
|
+
typer.echo(" 3. Run evaluation:")
|
|
2170
|
+
typer.echo(
|
|
2171
|
+
f" wafer evaluate gpumode --impl {output}/kernel.py --reference {output}/reference.py \\"
|
|
2172
|
+
)
|
|
2173
|
+
typer.echo(f" --test-cases {output}/test_cases.json --target <target>")
|
|
2174
|
+
|
|
2175
|
+
|
|
2176
|
+
@gpumode_app.callback(invoke_without_command=True)
|
|
2177
|
+
def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
2178
|
+
ctx: typer.Context,
|
|
2179
|
+
implementation: Path | None = typer.Option(
|
|
2180
|
+
None, "--impl", "-i", help="Path to implementation kernel file"
|
|
2181
|
+
),
|
|
2182
|
+
reference: Path | None = typer.Option(
|
|
2183
|
+
None, "--reference", help="Path to reference kernel file"
|
|
2184
|
+
),
|
|
2185
|
+
test_cases: Path | None = typer.Option(
|
|
2186
|
+
None, "--test-cases", help="Path to test cases JSON file"
|
|
2187
|
+
),
|
|
2188
|
+
target: str | None = typer.Option(
|
|
2189
|
+
None,
|
|
2190
|
+
"--target",
|
|
2191
|
+
"-t",
|
|
2192
|
+
help="GPU target name. See 'wafer config targets list' for available targets.",
|
|
2193
|
+
autocompletion=complete_target_name,
|
|
2194
|
+
),
|
|
2195
|
+
pool: str | None = typer.Option(
|
|
2196
|
+
None,
|
|
2197
|
+
"--pool",
|
|
2198
|
+
"-p",
|
|
2199
|
+
help="Target pool name. Acquires first available target from the pool. "
|
|
2200
|
+
"Define pools in ~/.wafer/config.toml under [pools.<name>].",
|
|
2201
|
+
),
|
|
2202
|
+
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
2203
|
+
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
2204
|
+
defensive: bool = typer.Option(
|
|
2205
|
+
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
2206
|
+
),
|
|
2207
|
+
sync_artifacts: bool = typer.Option(
|
|
2208
|
+
True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
|
|
2209
|
+
),
|
|
2210
|
+
gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
|
|
2211
|
+
) -> None:
|
|
2212
|
+
"""Run kernel evaluation in GPUMode format (functional).
|
|
2213
|
+
|
|
2214
|
+
This format expects:
|
|
2215
|
+
- Implementation: Python file with `custom_kernel(inputs)` function
|
|
2216
|
+
- Reference: Python file with `ref_kernel(inputs)` and `generate_input(**kwargs)` functions
|
|
2217
|
+
- Test cases: JSON file with test parameters
|
|
2218
|
+
|
|
2219
|
+
Examples:
|
|
2220
|
+
# Basic correctness check
|
|
2221
|
+
wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
|
|
2222
|
+
|
|
2223
|
+
# With benchmarking
|
|
2224
|
+
wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
2225
|
+
--target vultr-b200 --benchmark
|
|
2226
|
+
|
|
2227
|
+
Subcommands:
|
|
2228
|
+
download Download GPUMode problems from GitHub
|
|
2229
|
+
list-problems List available problems
|
|
2230
|
+
make-template Extract a problem as template files
|
|
2231
|
+
"""
|
|
2232
|
+
# If a subcommand is being invoked, skip the main evaluation logic
|
|
2233
|
+
if ctx.invoked_subcommand is not None:
|
|
2234
|
+
return
|
|
2235
|
+
|
|
2236
|
+
# Validate required args when running evaluation (not subcommands)
|
|
2237
|
+
missing_args = []
|
|
2238
|
+
if implementation is None:
|
|
2239
|
+
missing_args.append("--impl/-i")
|
|
2240
|
+
if reference is None:
|
|
2241
|
+
missing_args.append("--reference")
|
|
2242
|
+
if test_cases is None:
|
|
2243
|
+
missing_args.append("--test-cases")
|
|
2244
|
+
|
|
2245
|
+
if missing_args:
|
|
2246
|
+
typer.echo("Error: Missing required arguments", err=True)
|
|
2247
|
+
typer.echo(f" Required: {', '.join(missing_args)}", err=True)
|
|
2248
|
+
typer.echo("", err=True)
|
|
2249
|
+
typer.echo(
|
|
2250
|
+
"Usage: wafer evaluate gpumode --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
|
|
2251
|
+
err=True,
|
|
2252
|
+
)
|
|
2253
|
+
typer.echo("", err=True)
|
|
2254
|
+
typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
|
|
2255
|
+
typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
|
|
2256
|
+
raise typer.Exit(1)
|
|
2257
|
+
|
|
2258
|
+
# Validate --target and --pool are mutually exclusive
|
|
2259
|
+
if target and pool:
|
|
2260
|
+
typer.echo("Error: Cannot specify both --target and --pool", err=True)
|
|
2261
|
+
raise typer.Exit(1)
|
|
2262
|
+
|
|
2263
|
+
from .evaluate import EvaluateArgs, run_evaluate
|
|
2264
|
+
|
|
2265
|
+
# If pool specified, acquire a target from the pool
|
|
2266
|
+
resolved_target = target or ""
|
|
2267
|
+
pool_lock_context = None
|
|
2268
|
+
|
|
2269
|
+
if pool:
|
|
2270
|
+
from .target_lock import acquire_from_pool
|
|
2271
|
+
from .targets import filter_pool_by_auth, get_pool
|
|
2272
|
+
|
|
2273
|
+
try:
|
|
2274
|
+
pool_targets = get_pool(pool)
|
|
2275
|
+
except FileNotFoundError as e:
|
|
2276
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2277
|
+
raise typer.Exit(1) from None
|
|
2278
|
+
|
|
2279
|
+
# Filter to only targets with valid auth
|
|
2280
|
+
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
2281
|
+
if skipped:
|
|
2282
|
+
typer.echo(f"Skipping targets without auth: {', '.join(skipped)}", err=True)
|
|
2283
|
+
|
|
2284
|
+
if not usable_targets:
|
|
2285
|
+
typer.echo(f"Error: No usable targets in pool '{pool}'", err=True)
|
|
2286
|
+
typer.echo(" All targets require authentication that is not configured.", err=True)
|
|
2287
|
+
typer.echo(" Run 'wafer auth status' to see which providers need setup.", err=True)
|
|
2288
|
+
raise typer.Exit(1) from None
|
|
2289
|
+
|
|
2290
|
+
typer.echo(f"Acquiring target from pool '{pool}' ({len(usable_targets)} targets)...")
|
|
2291
|
+
pool_lock_context = acquire_from_pool(usable_targets)
|
|
2292
|
+
acquired_target = pool_lock_context.__enter__()
|
|
2293
|
+
|
|
2294
|
+
if acquired_target is None:
|
|
2295
|
+
# Exit context manager before raising to avoid resource leak
|
|
2296
|
+
pool_lock_context.__exit__(None, None, None)
|
|
2297
|
+
typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
|
|
2298
|
+
typer.echo(f" Targets: {', '.join(usable_targets)}", err=True)
|
|
2299
|
+
raise typer.Exit(1)
|
|
2300
|
+
|
|
2301
|
+
typer.echo(f"Acquired target: {acquired_target}")
|
|
2302
|
+
resolved_target = acquired_target
|
|
2303
|
+
|
|
2304
|
+
args = EvaluateArgs(
|
|
2305
|
+
implementation=implementation,
|
|
2306
|
+
reference=reference,
|
|
2307
|
+
test_cases=test_cases,
|
|
2308
|
+
target_name=resolved_target,
|
|
2309
|
+
benchmark=benchmark,
|
|
2310
|
+
profile=profile,
|
|
2311
|
+
defensive=defensive,
|
|
2312
|
+
sync_artifacts=sync_artifacts,
|
|
2313
|
+
gpu_id=gpu_id,
|
|
2314
|
+
)
|
|
2315
|
+
|
|
2316
|
+
try:
|
|
2317
|
+
import trio_asyncio
|
|
2318
|
+
|
|
2319
|
+
result = trio_asyncio.run(run_evaluate, args)
|
|
2320
|
+
except KeyboardInterrupt:
|
|
2321
|
+
typer.echo("\nInterrupted by user", err=True)
|
|
2322
|
+
raise typer.Exit(130) from None
|
|
2323
|
+
except Exception as e:
|
|
2324
|
+
if hasattr(e, "exceptions") and e.exceptions:
|
|
2325
|
+
for exc in e.exceptions:
|
|
2326
|
+
typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
|
|
2327
|
+
else:
|
|
2328
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2329
|
+
raise typer.Exit(1) from None
|
|
2330
|
+
finally:
|
|
2331
|
+
# Release pool lock if we acquired one
|
|
2332
|
+
if pool_lock_context is not None:
|
|
2333
|
+
pool_lock_context.__exit__(None, None, None)
|
|
2334
|
+
|
|
2335
|
+
# Print results
|
|
2336
|
+
if result.success:
|
|
2337
|
+
typer.echo("")
|
|
2338
|
+
typer.echo("=" * 60)
|
|
2339
|
+
status = "PASS" if result.all_correct else "FAIL"
|
|
2340
|
+
typer.echo(f"Result: {status}")
|
|
2341
|
+
score_pct = f"{result.correctness_score:.1%}"
|
|
2342
|
+
typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
|
|
2343
|
+
if result.geomean_speedup > 0:
|
|
2344
|
+
typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
|
|
2345
|
+
if result.artifact_path:
|
|
2346
|
+
typer.echo(f"Artifacts: {result.artifact_path}")
|
|
2347
|
+
typer.echo("=" * 60)
|
|
2348
|
+
|
|
2349
|
+
if not result.all_correct:
|
|
2350
|
+
raise typer.Exit(1)
|
|
2351
|
+
else:
|
|
2352
|
+
typer.echo(f"Error: {result.error_message}", err=True)
|
|
2353
|
+
raise typer.Exit(1)
|
|
2354
|
+
|
|
2355
|
+
|
|
1743
2356
|
# =============================================================================
|
|
1744
2357
|
# Push and Remote-Run commands
|
|
1745
2358
|
# =============================================================================
|
|
@@ -1871,7 +2484,7 @@ def _run_direct_mode(
|
|
|
1871
2484
|
typer.echo(f"Uploading {upload_dir.name}...")
|
|
1872
2485
|
try:
|
|
1873
2486
|
push_result = push_direct(upload_dir, target)
|
|
1874
|
-
workspace_name = push_result.
|
|
2487
|
+
workspace_name = push_result.workspace_name
|
|
1875
2488
|
typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
|
|
1876
2489
|
except Exception as e:
|
|
1877
2490
|
typer.echo(f"Error uploading: {e}", err=True)
|
|
@@ -1901,6 +2514,7 @@ def _run_api_mode( # noqa: PLR0913
|
|
|
1901
2514
|
upload_dir: Path | None,
|
|
1902
2515
|
workspace_id: str | None,
|
|
1903
2516
|
gpu_id: int | None,
|
|
2517
|
+
gpu_count: int,
|
|
1904
2518
|
docker_image: str | None,
|
|
1905
2519
|
docker_entrypoint: str | None,
|
|
1906
2520
|
pull_image: bool,
|
|
@@ -1915,6 +2529,8 @@ def _run_api_mode( # noqa: PLR0913
|
|
|
1915
2529
|
typer.echo(f"Workspace: {workspace_id}")
|
|
1916
2530
|
if gpu_id is not None:
|
|
1917
2531
|
typer.echo(f"GPU: {gpu_id}")
|
|
2532
|
+
if gpu_count > 1:
|
|
2533
|
+
typer.echo(f"GPU count: {gpu_count}")
|
|
1918
2534
|
if docker_image:
|
|
1919
2535
|
typer.echo(f"Image: {docker_image}")
|
|
1920
2536
|
if docker_entrypoint:
|
|
@@ -1932,6 +2548,7 @@ def _run_api_mode( # noqa: PLR0913
|
|
|
1932
2548
|
upload_dir=upload_dir,
|
|
1933
2549
|
workspace_id=workspace_id,
|
|
1934
2550
|
gpu_id=gpu_id,
|
|
2551
|
+
gpu_count=gpu_count,
|
|
1935
2552
|
docker_image=docker_image,
|
|
1936
2553
|
docker_entrypoint=docker_entrypoint,
|
|
1937
2554
|
pull_image=pull_image,
|
|
@@ -1955,6 +2572,7 @@ def remote_run( # noqa: PLR0913
|
|
|
1955
2572
|
None, "--workspace-id", "-w", help="Workspace ID (from wafer push)"
|
|
1956
2573
|
),
|
|
1957
2574
|
gpu_id: int | None = typer.Option(None, "--gpu", "-g", help="GPU ID"),
|
|
2575
|
+
gpu_count: int = typer.Option(1, "--gpu-count", "-n", help="Number of GPUs (1-8)"),
|
|
1958
2576
|
docker_image: str | None = typer.Option(None, "--image", "-i", help="Docker image override"),
|
|
1959
2577
|
docker_entrypoint: str | None = typer.Option(
|
|
1960
2578
|
None, "--docker-entrypoint", help="Override Docker entrypoint (e.g., 'bash')"
|
|
@@ -2024,6 +2642,7 @@ def remote_run( # noqa: PLR0913
|
|
|
2024
2642
|
upload_dir,
|
|
2025
2643
|
workspace_id,
|
|
2026
2644
|
gpu_id,
|
|
2645
|
+
gpu_count,
|
|
2027
2646
|
docker_image,
|
|
2028
2647
|
docker_entrypoint,
|
|
2029
2648
|
pull_image,
|
|
@@ -2044,27 +2663,41 @@ def login(
|
|
|
2044
2663
|
None, "--token", "-t", help="Access token (skip browser OAuth)"
|
|
2045
2664
|
),
|
|
2046
2665
|
port: int | None = typer.Option(
|
|
2047
|
-
None,
|
|
2666
|
+
None,
|
|
2667
|
+
"--port",
|
|
2668
|
+
"-p",
|
|
2669
|
+
help="Port for OAuth callback server (local only, ignored for SSH)",
|
|
2670
|
+
),
|
|
2671
|
+
no_device_code: bool = typer.Option(
|
|
2672
|
+
False,
|
|
2673
|
+
"--no-device-code",
|
|
2674
|
+
help="Force browser OAuth even on SSH (requires port forwarding)",
|
|
2048
2675
|
),
|
|
2049
2676
|
) -> None:
|
|
2050
2677
|
"""Authenticate CLI with wafer-api via GitHub OAuth.
|
|
2051
2678
|
|
|
2052
|
-
Opens browser for GitHub authentication.
|
|
2679
|
+
Local: Opens browser for GitHub authentication.
|
|
2680
|
+
SSH: Uses device code flow (no port forwarding needed).
|
|
2681
|
+
|
|
2053
2682
|
Uses the API environment from config (see 'wafer config show').
|
|
2054
2683
|
|
|
2055
|
-
SSH Users:
|
|
2056
|
-
-
|
|
2057
|
-
-
|
|
2058
|
-
-
|
|
2059
|
-
|
|
2684
|
+
SSH Users (Easiest):
|
|
2685
|
+
- Just run: wafer login
|
|
2686
|
+
- Visit the URL and enter the code shown
|
|
2687
|
+
- No port forwarding needed!
|
|
2688
|
+
|
|
2689
|
+
SSH with browser (Advanced):
|
|
2690
|
+
- Use --no-device-code to force browser flow
|
|
2691
|
+
- Requires: ssh -L 8765:localhost:8765 user@host
|
|
2060
2692
|
|
|
2061
2693
|
Manual token option:
|
|
2062
2694
|
- Visit auth.wafer.ai, authenticate, copy token from URL
|
|
2063
2695
|
- Run: wafer login --token <paste-token>
|
|
2064
2696
|
|
|
2065
2697
|
Examples:
|
|
2066
|
-
wafer login #
|
|
2067
|
-
wafer login --port
|
|
2698
|
+
wafer login # device code on SSH, browser on local
|
|
2699
|
+
wafer login --no-device-code # force browser (needs port forwarding on SSH)
|
|
2700
|
+
wafer login --port 9000 # custom port for browser flow
|
|
2068
2701
|
wafer login --token xyz # manual token (no browser)
|
|
2069
2702
|
|
|
2070
2703
|
# Change environment:
|
|
@@ -2073,7 +2706,7 @@ def login(
|
|
|
2073
2706
|
"""
|
|
2074
2707
|
import httpx
|
|
2075
2708
|
|
|
2076
|
-
from .auth import browser_login, save_credentials, verify_token
|
|
2709
|
+
from .auth import browser_login, device_code_login, save_credentials, verify_token
|
|
2077
2710
|
from .global_config import get_api_url, get_supabase_url, load_global_config
|
|
2078
2711
|
|
|
2079
2712
|
# Show which environment we're logging into
|
|
@@ -2083,21 +2716,31 @@ def login(
|
|
|
2083
2716
|
typer.echo(f"Auth: {get_supabase_url()}")
|
|
2084
2717
|
typer.echo("")
|
|
2085
2718
|
|
|
2086
|
-
# Auto-detect SSH
|
|
2087
|
-
|
|
2088
|
-
is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
|
|
2089
|
-
if is_ssh:
|
|
2090
|
-
port = 8765
|
|
2091
|
-
typer.echo("🔒 SSH session detected - using port 8765 for OAuth callback")
|
|
2092
|
-
typer.echo(" Make sure you have port forwarding set up:")
|
|
2093
|
-
typer.echo(" ssh -L 8765:localhost:8765 user@host")
|
|
2094
|
-
typer.echo("")
|
|
2719
|
+
# Auto-detect SSH
|
|
2720
|
+
is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
|
|
2095
2721
|
|
|
2096
|
-
#
|
|
2722
|
+
# Choose auth method
|
|
2097
2723
|
refresh_token = None
|
|
2098
2724
|
if token is None:
|
|
2099
2725
|
try:
|
|
2100
|
-
|
|
2726
|
+
if is_ssh and not no_device_code:
|
|
2727
|
+
# Use device code flow for SSH (no port forwarding needed)
|
|
2728
|
+
typer.echo("🔒 SSH session detected - using device code authentication")
|
|
2729
|
+
typer.echo(" (No port forwarding required!)")
|
|
2730
|
+
typer.echo("")
|
|
2731
|
+
token, refresh_token = device_code_login()
|
|
2732
|
+
else:
|
|
2733
|
+
# Use browser OAuth for local or if explicitly requested
|
|
2734
|
+
if is_ssh:
|
|
2735
|
+
typer.echo("🔒 SSH session detected - using browser authentication")
|
|
2736
|
+
typer.echo(" Make sure you have port forwarding set up:")
|
|
2737
|
+
if port is None:
|
|
2738
|
+
port = 8765
|
|
2739
|
+
typer.echo(f" ssh -L {port}:localhost:{port} user@host")
|
|
2740
|
+
else:
|
|
2741
|
+
typer.echo(f" ssh -L {port}:localhost:{port} user@host")
|
|
2742
|
+
typer.echo("")
|
|
2743
|
+
token, refresh_token = browser_login(port=port)
|
|
2101
2744
|
except TimeoutError as e:
|
|
2102
2745
|
typer.echo(f"Error: {e}", err=True)
|
|
2103
2746
|
raise typer.Exit(1) from None
|
|
@@ -2146,9 +2789,8 @@ def login(
|
|
|
2146
2789
|
@app.command("logout")
|
|
2147
2790
|
def logout() -> None:
|
|
2148
2791
|
"""Remove stored credentials."""
|
|
2149
|
-
from .auth import clear_credentials
|
|
2150
|
-
|
|
2151
2792
|
from . import analytics
|
|
2793
|
+
from .auth import clear_credentials
|
|
2152
2794
|
|
|
2153
2795
|
# Track logout event first (while credentials still exist for user identification)
|
|
2154
2796
|
# Note: track_logout() handles the case where user is not logged in
|
|
@@ -2625,6 +3267,7 @@ init_app = typer.Typer(
|
|
|
2625
3267
|
|
|
2626
3268
|
Choose based on your GPU access:
|
|
2627
3269
|
|
|
3270
|
+
local GPU on current machine (no SSH)
|
|
2628
3271
|
ssh Your own hardware via SSH
|
|
2629
3272
|
runpod RunPod cloud GPUs (needs WAFER_RUNPOD_API_KEY)
|
|
2630
3273
|
digitalocean DigitalOcean AMD MI300X (needs WAFER_AMD_DIGITALOCEAN_API_KEY)"""
|
|
@@ -2632,57 +3275,143 @@ Choose based on your GPU access:
|
|
|
2632
3275
|
targets_app.add_typer(init_app, name="init")
|
|
2633
3276
|
|
|
2634
3277
|
|
|
2635
|
-
@init_app.command("
|
|
2636
|
-
def
|
|
2637
|
-
name: str = typer.Option("
|
|
2638
|
-
|
|
2639
|
-
ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
|
|
2640
|
-
keep_alive: bool = typer.Option(
|
|
2641
|
-
True, "--keep-alive/--no-keep-alive", help="Keep pod running after eval"
|
|
2642
|
-
),
|
|
3278
|
+
@init_app.command("local")
|
|
3279
|
+
def init_local(
|
|
3280
|
+
name: str = typer.Option("local", "--name", "-n", help="Target name"),
|
|
3281
|
+
gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
|
|
2643
3282
|
) -> None:
|
|
2644
|
-
"""Initialize a
|
|
3283
|
+
"""Initialize a local target for GPU on current machine.
|
|
2645
3284
|
|
|
2646
|
-
|
|
2647
|
-
|
|
3285
|
+
Detects your local GPU and configures a target for direct execution
|
|
3286
|
+
(no SSH). Use this when running wafer on the same machine as the GPU.
|
|
2648
3287
|
|
|
2649
3288
|
Examples:
|
|
2650
|
-
wafer config targets init
|
|
2651
|
-
wafer config targets init
|
|
3289
|
+
wafer config targets init local
|
|
3290
|
+
wafer config targets init local --name my-5090 --gpu-ids 0,1
|
|
2652
3291
|
"""
|
|
2653
|
-
import os
|
|
2654
|
-
|
|
2655
3292
|
from .targets import save_target
|
|
2656
3293
|
|
|
2657
|
-
#
|
|
2658
|
-
|
|
2659
|
-
|
|
2660
|
-
|
|
2661
|
-
typer.echo("", err=True)
|
|
2662
|
-
typer.
|
|
2663
|
-
typer.echo("Then run: export WAFER_RUNPOD_API_KEY=your_key_here", err=True)
|
|
2664
|
-
raise typer.Exit(1)
|
|
3294
|
+
# Parse GPU IDs
|
|
3295
|
+
try:
|
|
3296
|
+
parsed_gpu_ids = [int(g.strip()) for g in gpu_ids.split(",")]
|
|
3297
|
+
except ValueError:
|
|
3298
|
+
typer.echo(f"Error: Invalid GPU IDs '{gpu_ids}'. Use comma-separated integers.", err=True)
|
|
3299
|
+
raise typer.Exit(1) from None
|
|
2665
3300
|
|
|
2666
|
-
|
|
2667
|
-
gpu_configs = {
|
|
2668
|
-
"MI300X": {
|
|
2669
|
-
"gpu_type_id": "AMD Instinct MI300X OAM",
|
|
2670
|
-
"image": "runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04",
|
|
2671
|
-
"compute_capability": "9.4",
|
|
2672
|
-
},
|
|
2673
|
-
"H100": {
|
|
2674
|
-
"gpu_type_id": "NVIDIA H100 80GB HBM3",
|
|
2675
|
-
"image": "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04",
|
|
2676
|
-
"compute_capability": "9.0",
|
|
2677
|
-
},
|
|
2678
|
-
"A100": {
|
|
2679
|
-
"gpu_type_id": "NVIDIA A100 80GB PCIe",
|
|
2680
|
-
"image": "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04",
|
|
2681
|
-
"compute_capability": "8.0",
|
|
2682
|
-
},
|
|
2683
|
-
}
|
|
3301
|
+
typer.echo("Detecting local GPU...")
|
|
2684
3302
|
|
|
2685
|
-
|
|
3303
|
+
try:
|
|
3304
|
+
from wafer_core.gpu_detect import (
|
|
3305
|
+
detect_local_gpu,
|
|
3306
|
+
get_compute_capability,
|
|
3307
|
+
get_torch_requirements,
|
|
3308
|
+
)
|
|
3309
|
+
|
|
3310
|
+
detected_gpu = detect_local_gpu()
|
|
3311
|
+
|
|
3312
|
+
if detected_gpu:
|
|
3313
|
+
typer.echo(f" Found: {detected_gpu.gpu_name}")
|
|
3314
|
+
if detected_gpu.vendor == "nvidia":
|
|
3315
|
+
typer.echo(f" CUDA: {detected_gpu.driver_version}")
|
|
3316
|
+
else:
|
|
3317
|
+
typer.echo(f" ROCm: {detected_gpu.driver_version}")
|
|
3318
|
+
typer.echo(f" GPU count: {detected_gpu.gpu_count}")
|
|
3319
|
+
|
|
3320
|
+
# Get torch requirements and compute capability
|
|
3321
|
+
torch_reqs = get_torch_requirements(detected_gpu)
|
|
3322
|
+
compute_capability = get_compute_capability(detected_gpu)
|
|
3323
|
+
gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
|
|
3324
|
+
|
|
3325
|
+
typer.echo(f" PyTorch: {torch_reqs.packages[0]}")
|
|
3326
|
+
else:
|
|
3327
|
+
typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)", err=True)
|
|
3328
|
+
raise typer.Exit(1)
|
|
3329
|
+
|
|
3330
|
+
except ImportError as e:
|
|
3331
|
+
typer.echo(f"Error: Missing dependency: {e}", err=True)
|
|
3332
|
+
raise typer.Exit(1) from None
|
|
3333
|
+
|
|
3334
|
+
# Build target data
|
|
3335
|
+
target_data = {
|
|
3336
|
+
"name": name,
|
|
3337
|
+
"type": "local",
|
|
3338
|
+
"gpu_ids": parsed_gpu_ids,
|
|
3339
|
+
"gpu_type": gpu_type,
|
|
3340
|
+
"compute_capability": compute_capability,
|
|
3341
|
+
"torch_package": torch_reqs.packages[0],
|
|
3342
|
+
"torch_index_url": torch_reqs.index_url,
|
|
3343
|
+
"vendor": detected_gpu.vendor,
|
|
3344
|
+
"driver_version": detected_gpu.driver_version,
|
|
3345
|
+
}
|
|
3346
|
+
|
|
3347
|
+
try:
|
|
3348
|
+
target = save_target(target_data)
|
|
3349
|
+
typer.echo(f"✓ Created target: {target.name}")
|
|
3350
|
+
typer.echo(" Type: Local (no SSH)")
|
|
3351
|
+
typer.echo(f" GPU IDs: {parsed_gpu_ids}")
|
|
3352
|
+
typer.echo(f" GPU Type: {gpu_type}")
|
|
3353
|
+
typer.echo(f" Compute: {compute_capability}")
|
|
3354
|
+
typer.echo(f" Torch: {torch_reqs.packages[0]}")
|
|
3355
|
+
typer.echo("")
|
|
3356
|
+
typer.echo(
|
|
3357
|
+
f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
|
|
3358
|
+
)
|
|
3359
|
+
except (ValueError, AssertionError) as e:
|
|
3360
|
+
typer.echo(f"Error: {e}", err=True)
|
|
3361
|
+
raise typer.Exit(1) from None
|
|
3362
|
+
|
|
3363
|
+
|
|
3364
|
+
@init_app.command("runpod")
|
|
3365
|
+
def init_runpod(
|
|
3366
|
+
name: str = typer.Option("runpod-mi300x", "--name", "-n", help="Target name"),
|
|
3367
|
+
gpu_type: str = typer.Option("MI300X", "--gpu", "-g", help="GPU type (MI300X, H100, A100)"),
|
|
3368
|
+
ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
|
|
3369
|
+
keep_alive: bool = typer.Option(
|
|
3370
|
+
True, "--keep-alive/--no-keep-alive", help="Keep pod running after eval"
|
|
3371
|
+
),
|
|
3372
|
+
) -> None:
|
|
3373
|
+
"""Initialize a RunPod target.
|
|
3374
|
+
|
|
3375
|
+
Creates a target config for auto-provisioned RunPod GPUs.
|
|
3376
|
+
Requires WAFER_RUNPOD_API_KEY environment variable.
|
|
3377
|
+
|
|
3378
|
+
Examples:
|
|
3379
|
+
wafer config targets init runpod
|
|
3380
|
+
wafer config targets init runpod --name my-runpod --gpu H100
|
|
3381
|
+
"""
|
|
3382
|
+
import os
|
|
3383
|
+
|
|
3384
|
+
from .targets import save_target
|
|
3385
|
+
|
|
3386
|
+
# Check for API key
|
|
3387
|
+
api_key = os.environ.get("WAFER_RUNPOD_API_KEY", "")
|
|
3388
|
+
if not api_key:
|
|
3389
|
+
typer.echo("Error: WAFER_RUNPOD_API_KEY environment variable not set.", err=True)
|
|
3390
|
+
typer.echo("", err=True)
|
|
3391
|
+
typer.echo("Get your API key from: https://runpod.io/console/user/settings", err=True)
|
|
3392
|
+
typer.echo("Then run: export WAFER_RUNPOD_API_KEY=your_key_here", err=True)
|
|
3393
|
+
raise typer.Exit(1)
|
|
3394
|
+
|
|
3395
|
+
# GPU type mappings
|
|
3396
|
+
gpu_configs = {
|
|
3397
|
+
"MI300X": {
|
|
3398
|
+
"gpu_type_id": "AMD Instinct MI300X OAM",
|
|
3399
|
+
"image": "runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04",
|
|
3400
|
+
"compute_capability": "9.4",
|
|
3401
|
+
},
|
|
3402
|
+
"H100": {
|
|
3403
|
+
"gpu_type_id": "NVIDIA H100 80GB HBM3",
|
|
3404
|
+
"image": "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04",
|
|
3405
|
+
"compute_capability": "9.0",
|
|
3406
|
+
},
|
|
3407
|
+
"A100": {
|
|
3408
|
+
"gpu_type_id": "NVIDIA A100 80GB PCIe",
|
|
3409
|
+
"image": "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04",
|
|
3410
|
+
"compute_capability": "8.0",
|
|
3411
|
+
},
|
|
3412
|
+
}
|
|
3413
|
+
|
|
3414
|
+
if gpu_type not in gpu_configs:
|
|
2686
3415
|
typer.echo(
|
|
2687
3416
|
f"Error: Unknown GPU type '{gpu_type}'. Available: {', '.join(gpu_configs.keys())}",
|
|
2688
3417
|
err=True,
|
|
@@ -2795,23 +3524,29 @@ def init_ssh(
|
|
|
2795
3524
|
host: str = typer.Option(..., "--host", "-H", help="SSH host (user@hostname:port)"),
|
|
2796
3525
|
ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
|
|
2797
3526
|
gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
|
|
2798
|
-
gpu_type: str = typer.Option(
|
|
2799
|
-
|
|
3527
|
+
gpu_type: str | None = typer.Option(
|
|
3528
|
+
None, "--gpu-type", help="GPU type (auto-detected if not specified)"
|
|
2800
3529
|
),
|
|
2801
3530
|
docker_image: str | None = typer.Option(
|
|
2802
3531
|
None, "--docker-image", "-d", help="Docker image (optional)"
|
|
2803
3532
|
),
|
|
2804
3533
|
ncu: bool = typer.Option(False, "--ncu/--no-ncu", help="NCU profiling available"),
|
|
3534
|
+
no_detect: bool = typer.Option(False, "--no-detect", help="Skip GPU auto-detection"),
|
|
2805
3535
|
) -> None:
|
|
2806
3536
|
"""Initialize an SSH target for your own GPU hardware.
|
|
2807
3537
|
|
|
2808
3538
|
Creates a target config for direct SSH access to a GPU machine.
|
|
2809
|
-
|
|
3539
|
+
Automatically detects GPU type and selects compatible PyTorch version.
|
|
2810
3540
|
|
|
2811
3541
|
Examples:
|
|
3542
|
+
# Auto-detect GPU (recommended)
|
|
2812
3543
|
wafer config targets init ssh --name my-gpu --host user@192.168.1.100:22
|
|
3544
|
+
|
|
3545
|
+
# Multiple GPUs with NCU profiling
|
|
2813
3546
|
wafer config targets init ssh --name lab-h100 --host ubuntu@gpu.lab.com:22 --gpu-ids 0,1 --ncu
|
|
2814
|
-
|
|
3547
|
+
|
|
3548
|
+
# Skip detection, specify manually
|
|
3549
|
+
wafer config targets init ssh --name my-gpu --host user@host:22 --gpu-type H100 --no-detect
|
|
2815
3550
|
"""
|
|
2816
3551
|
from .targets import save_target
|
|
2817
3552
|
|
|
@@ -2828,17 +3563,86 @@ def init_ssh(
|
|
|
2828
3563
|
typer.echo("Example: user@192.168.1.100:22", err=True)
|
|
2829
3564
|
raise typer.Exit(1)
|
|
2830
3565
|
|
|
3566
|
+
# Auto-detect GPU if not specified
|
|
3567
|
+
detected_gpu = None
|
|
3568
|
+
torch_package = None
|
|
3569
|
+
torch_index_url = None
|
|
3570
|
+
|
|
3571
|
+
if not no_detect:
|
|
3572
|
+
typer.echo(f"Connecting to {host}...")
|
|
3573
|
+
try:
|
|
3574
|
+
import trio
|
|
3575
|
+
import trio_asyncio
|
|
3576
|
+
from wafer_core.async_ssh import AsyncSSHClient
|
|
3577
|
+
from wafer_core.gpu_detect import (
|
|
3578
|
+
detect_remote_gpu,
|
|
3579
|
+
get_compute_capability,
|
|
3580
|
+
get_torch_requirements,
|
|
3581
|
+
)
|
|
3582
|
+
|
|
3583
|
+
expanded_key = str(Path(ssh_key).expanduser())
|
|
3584
|
+
|
|
3585
|
+
async def _detect() -> None:
|
|
3586
|
+
nonlocal detected_gpu, torch_package, torch_index_url
|
|
3587
|
+
# Need trio_asyncio.open_loop() for asyncssh bridge
|
|
3588
|
+
async with trio_asyncio.open_loop():
|
|
3589
|
+
async with AsyncSSHClient(host, expanded_key) as client:
|
|
3590
|
+
detected_gpu = await detect_remote_gpu(client)
|
|
3591
|
+
|
|
3592
|
+
trio.run(_detect)
|
|
3593
|
+
|
|
3594
|
+
if detected_gpu:
|
|
3595
|
+
typer.echo(f" Found: {detected_gpu.gpu_name}")
|
|
3596
|
+
if detected_gpu.vendor == "nvidia":
|
|
3597
|
+
typer.echo(f" CUDA: {detected_gpu.driver_version}")
|
|
3598
|
+
else:
|
|
3599
|
+
typer.echo(f" ROCm: {detected_gpu.driver_version}")
|
|
3600
|
+
|
|
3601
|
+
# Get torch requirements
|
|
3602
|
+
torch_reqs = get_torch_requirements(detected_gpu)
|
|
3603
|
+
torch_package = torch_reqs.packages[0] # Just torch, not all packages
|
|
3604
|
+
torch_index_url = torch_reqs.index_url
|
|
3605
|
+
typer.echo(f" PyTorch: {torch_package}")
|
|
3606
|
+
|
|
3607
|
+
# Use detected GPU type if not specified
|
|
3608
|
+
if not gpu_type:
|
|
3609
|
+
# Extract GPU name (e.g., "H100" from "NVIDIA H100 80GB HBM3")
|
|
3610
|
+
gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
|
|
3611
|
+
else:
|
|
3612
|
+
typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)")
|
|
3613
|
+
if not gpu_type:
|
|
3614
|
+
gpu_type = "H100" # Default fallback
|
|
3615
|
+
typer.echo(f" Using default: {gpu_type}")
|
|
3616
|
+
|
|
3617
|
+
except Exception as e:
|
|
3618
|
+
typer.echo(f" Detection failed: {e}", err=True)
|
|
3619
|
+
if not gpu_type:
|
|
3620
|
+
gpu_type = "H100"
|
|
3621
|
+
typer.echo(f" Using default: {gpu_type}")
|
|
3622
|
+
|
|
3623
|
+
# Fallback if no detection
|
|
3624
|
+
if not gpu_type:
|
|
3625
|
+
gpu_type = "H100"
|
|
3626
|
+
|
|
2831
3627
|
# Compute capability mappings
|
|
2832
|
-
|
|
2833
|
-
|
|
2834
|
-
|
|
2835
|
-
|
|
2836
|
-
|
|
2837
|
-
|
|
2838
|
-
|
|
2839
|
-
|
|
2840
|
-
|
|
2841
|
-
|
|
3628
|
+
if detected_gpu:
|
|
3629
|
+
from wafer_core.gpu_detect import get_compute_capability
|
|
3630
|
+
|
|
3631
|
+
compute_capability = get_compute_capability(detected_gpu)
|
|
3632
|
+
else:
|
|
3633
|
+
compute_caps = {
|
|
3634
|
+
"B200": "10.0",
|
|
3635
|
+
"H100": "9.0",
|
|
3636
|
+
"A100": "8.0",
|
|
3637
|
+
"A10": "8.6",
|
|
3638
|
+
"V100": "7.0",
|
|
3639
|
+
"MI300X": "9.4",
|
|
3640
|
+
"MI250X": "9.0",
|
|
3641
|
+
"RTX 5090": "10.0",
|
|
3642
|
+
"RTX 4090": "8.9",
|
|
3643
|
+
"RTX 3090": "8.6",
|
|
3644
|
+
}
|
|
3645
|
+
compute_capability = compute_caps.get(gpu_type, "8.0")
|
|
2842
3646
|
|
|
2843
3647
|
# Build target data
|
|
2844
3648
|
target_data = {
|
|
@@ -2855,6 +3659,12 @@ def init_ssh(
|
|
|
2855
3659
|
if docker_image:
|
|
2856
3660
|
target_data["docker_image"] = docker_image
|
|
2857
3661
|
|
|
3662
|
+
# Add torch requirements if detected
|
|
3663
|
+
if torch_package:
|
|
3664
|
+
target_data["torch_package"] = torch_package
|
|
3665
|
+
if torch_index_url:
|
|
3666
|
+
target_data["torch_index_url"] = torch_index_url
|
|
3667
|
+
|
|
2858
3668
|
try:
|
|
2859
3669
|
target = save_target(target_data)
|
|
2860
3670
|
typer.echo(f"✓ Created target: {target.name}")
|
|
@@ -2862,9 +3672,12 @@ def init_ssh(
|
|
|
2862
3672
|
typer.echo(f" Host: {host}")
|
|
2863
3673
|
typer.echo(f" GPU IDs: {parsed_gpu_ids}")
|
|
2864
3674
|
typer.echo(f" GPU Type: {gpu_type}")
|
|
3675
|
+
typer.echo(f" Compute: {compute_capability}")
|
|
2865
3676
|
typer.echo(f" NCU: {'Yes' if ncu else 'No'}")
|
|
2866
3677
|
if docker_image:
|
|
2867
3678
|
typer.echo(f" Docker: {docker_image}")
|
|
3679
|
+
if torch_package:
|
|
3680
|
+
typer.echo(f" Torch: {torch_package}")
|
|
2868
3681
|
typer.echo("")
|
|
2869
3682
|
typer.echo(
|
|
2870
3683
|
f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
|
|
@@ -2874,6 +3687,44 @@ def init_ssh(
|
|
|
2874
3687
|
raise typer.Exit(1) from None
|
|
2875
3688
|
|
|
2876
3689
|
|
|
3690
|
+
def _extract_gpu_type(gpu_name: str) -> str:
|
|
3691
|
+
"""Extract GPU type from full GPU name.
|
|
3692
|
+
|
|
3693
|
+
Examples:
|
|
3694
|
+
"NVIDIA H100 80GB HBM3" -> "H100"
|
|
3695
|
+
"NVIDIA GeForce RTX 4090" -> "RTX 4090"
|
|
3696
|
+
"AMD Instinct MI300X OAM" -> "MI300X"
|
|
3697
|
+
"""
|
|
3698
|
+
gpu_name_upper = gpu_name.upper()
|
|
3699
|
+
|
|
3700
|
+
# Check for known GPU types
|
|
3701
|
+
known_types = [
|
|
3702
|
+
"B200",
|
|
3703
|
+
"B100",
|
|
3704
|
+
"H200",
|
|
3705
|
+
"H100",
|
|
3706
|
+
"A100",
|
|
3707
|
+
"A10",
|
|
3708
|
+
"V100",
|
|
3709
|
+
"RTX 5090",
|
|
3710
|
+
"RTX 5080",
|
|
3711
|
+
"RTX 4090",
|
|
3712
|
+
"RTX 4080",
|
|
3713
|
+
"RTX 3090",
|
|
3714
|
+
"RTX 3080",
|
|
3715
|
+
"MI300X",
|
|
3716
|
+
"MI250X",
|
|
3717
|
+
"MI100",
|
|
3718
|
+
]
|
|
3719
|
+
|
|
3720
|
+
for gpu_type in known_types:
|
|
3721
|
+
if gpu_type in gpu_name_upper:
|
|
3722
|
+
return gpu_type
|
|
3723
|
+
|
|
3724
|
+
# Fallback: return cleaned name
|
|
3725
|
+
return gpu_name.replace("NVIDIA ", "").replace("AMD ", "").strip()
|
|
3726
|
+
|
|
3727
|
+
|
|
2877
3728
|
@targets_app.command("add")
|
|
2878
3729
|
def targets_add(
|
|
2879
3730
|
file_path: Path = typer.Argument(..., help="Path to target TOML file"),
|
|
@@ -2956,6 +3807,93 @@ def targets_show(
|
|
|
2956
3807
|
raise typer.Exit(1) from None
|
|
2957
3808
|
|
|
2958
3809
|
|
|
3810
|
+
@targets_app.command("probe")
|
|
3811
|
+
def targets_probe(
|
|
3812
|
+
name: str = typer.Argument(..., help="Target name"),
|
|
3813
|
+
) -> None:
|
|
3814
|
+
"""Probe a target to discover available compilation backends.
|
|
3815
|
+
|
|
3816
|
+
Connects to the target and checks what's available:
|
|
3817
|
+
- Triton
|
|
3818
|
+
- torch.compile/inductor
|
|
3819
|
+
- HIP/hipcc or CUDA/nvcc
|
|
3820
|
+
- ROCm or CUDA version
|
|
3821
|
+
- Python packages (torch, triton, etc.)
|
|
3822
|
+
|
|
3823
|
+
Example:
|
|
3824
|
+
wafer config targets probe runpod-mi300x
|
|
3825
|
+
"""
|
|
3826
|
+
import trio
|
|
3827
|
+
|
|
3828
|
+
from .targets import ProbeError, load_target, probe_target_capabilities
|
|
3829
|
+
|
|
3830
|
+
try:
|
|
3831
|
+
target = load_target(name)
|
|
3832
|
+
except FileNotFoundError as e:
|
|
3833
|
+
typer.echo(f"Error: {e}", err=True)
|
|
3834
|
+
raise typer.Exit(1) from None
|
|
3835
|
+
|
|
3836
|
+
typer.echo(f"Probing target: {name}...")
|
|
3837
|
+
|
|
3838
|
+
try:
|
|
3839
|
+
capabilities = trio.run(probe_target_capabilities, target)
|
|
3840
|
+
except ProbeError as e:
|
|
3841
|
+
# ProbeError already has actionable context
|
|
3842
|
+
typer.echo(f"\nError: {e}", err=True)
|
|
3843
|
+
raise typer.Exit(1) from None
|
|
3844
|
+
except Exception as e:
|
|
3845
|
+
# Unexpected errors - include type for debugging
|
|
3846
|
+
typer.echo(f"\nUnexpected error probing target: {type(e).__name__}: {e}", err=True)
|
|
3847
|
+
raise typer.Exit(1) from None
|
|
3848
|
+
|
|
3849
|
+
# Display results
|
|
3850
|
+
typer.echo(f"\nTarget: {name}")
|
|
3851
|
+
|
|
3852
|
+
if capabilities.get("gpu_name"):
|
|
3853
|
+
typer.echo(f" GPU: {capabilities['gpu_name']}")
|
|
3854
|
+
if capabilities.get("compute_capability"):
|
|
3855
|
+
typer.echo(f" Compute: {capabilities['compute_capability']}")
|
|
3856
|
+
|
|
3857
|
+
typer.echo("\n Compilation Backends:")
|
|
3858
|
+
backends = capabilities.get("backends", {})
|
|
3859
|
+
|
|
3860
|
+
# Triton
|
|
3861
|
+
triton_ver = backends.get("triton")
|
|
3862
|
+
if triton_ver:
|
|
3863
|
+
typer.echo(f" ✓ Triton: {triton_ver}")
|
|
3864
|
+
else:
|
|
3865
|
+
typer.echo(" ✗ Triton: not installed")
|
|
3866
|
+
|
|
3867
|
+
# torch.compile
|
|
3868
|
+
if triton_ver and backends.get("torch"):
|
|
3869
|
+
typer.echo(" ✓ torch.compile/inductor: available")
|
|
3870
|
+
else:
|
|
3871
|
+
typer.echo(" ✗ torch.compile/inductor: requires Triton")
|
|
3872
|
+
|
|
3873
|
+
# HIP/CUDA compiler
|
|
3874
|
+
if backends.get("hipcc"):
|
|
3875
|
+
typer.echo(f" ✓ HIP/hipcc: {backends['hipcc']}")
|
|
3876
|
+
elif backends.get("nvcc"):
|
|
3877
|
+
typer.echo(f" ✓ CUDA/nvcc: {backends['nvcc']}")
|
|
3878
|
+
else:
|
|
3879
|
+
typer.echo(" ✗ No GPU compiler found")
|
|
3880
|
+
|
|
3881
|
+
# ROCm/CUDA version
|
|
3882
|
+
if capabilities.get("rocm_version"):
|
|
3883
|
+
typer.echo(f" ROCm: {capabilities['rocm_version']}")
|
|
3884
|
+
if capabilities.get("cuda_version"):
|
|
3885
|
+
typer.echo(f" CUDA: {capabilities['cuda_version']}")
|
|
3886
|
+
|
|
3887
|
+
typer.echo("\n Python Environment:")
|
|
3888
|
+
typer.echo(f" Python: {capabilities.get('python_version', 'unknown')}")
|
|
3889
|
+
|
|
3890
|
+
packages = capabilities.get("packages", {})
|
|
3891
|
+
if packages.get("torch"):
|
|
3892
|
+
typer.echo(f" PyTorch: {packages['torch']}")
|
|
3893
|
+
if triton_ver:
|
|
3894
|
+
typer.echo(f" Triton: {triton_ver}")
|
|
3895
|
+
|
|
3896
|
+
|
|
2959
3897
|
@targets_app.command("remove")
|
|
2960
3898
|
def targets_remove(
|
|
2961
3899
|
name: str = typer.Argument(..., help="Target name"),
|
|
@@ -3086,6 +4024,92 @@ def targets_pods() -> None:
|
|
|
3086
4024
|
typer.echo()
|
|
3087
4025
|
|
|
3088
4026
|
|
|
4027
|
+
# ── Pool commands ───────────────────────────────────────────────────────────
|
|
4028
|
+
|
|
4029
|
+
|
|
4030
|
+
@targets_app.command("pool-list")
|
|
4031
|
+
def targets_pool_list() -> None:
|
|
4032
|
+
"""List all configured target pools.
|
|
4033
|
+
|
|
4034
|
+
Example:
|
|
4035
|
+
wafer config targets pool-list
|
|
4036
|
+
"""
|
|
4037
|
+
from .targets import get_pool, list_pools
|
|
4038
|
+
|
|
4039
|
+
pools = list_pools()
|
|
4040
|
+
|
|
4041
|
+
if not pools:
|
|
4042
|
+
typer.echo("No pools configured")
|
|
4043
|
+
typer.echo("")
|
|
4044
|
+
typer.echo("Define pools in ~/.wafer/config.toml:")
|
|
4045
|
+
typer.echo(" [pools.my-pool]")
|
|
4046
|
+
typer.echo(' targets = ["target-1", "target-2"]')
|
|
4047
|
+
return
|
|
4048
|
+
|
|
4049
|
+
typer.echo("Configured pools:\n")
|
|
4050
|
+
for pool_name in pools:
|
|
4051
|
+
try:
|
|
4052
|
+
targets = get_pool(pool_name)
|
|
4053
|
+
typer.echo(f" {pool_name}: {', '.join(targets)}")
|
|
4054
|
+
except Exception as e:
|
|
4055
|
+
typer.echo(f" {pool_name}: (error: {e})")
|
|
4056
|
+
|
|
4057
|
+
|
|
4058
|
+
@targets_app.command("pool-create")
|
|
4059
|
+
def targets_pool_create(
|
|
4060
|
+
name: str = typer.Argument(..., help="Pool name"),
|
|
4061
|
+
targets: list[str] = typer.Argument(..., help="Target names to include in pool"),
|
|
4062
|
+
) -> None:
|
|
4063
|
+
"""Create or update a target pool.
|
|
4064
|
+
|
|
4065
|
+
Example:
|
|
4066
|
+
wafer config targets pool-create mi300x-pool mi300x-1 mi300x-2 mi300x-3
|
|
4067
|
+
"""
|
|
4068
|
+
from .targets import save_pool
|
|
4069
|
+
|
|
4070
|
+
try:
|
|
4071
|
+
save_pool(name, targets)
|
|
4072
|
+
typer.echo(f"Pool '{name}' created with {len(targets)} targets")
|
|
4073
|
+
except FileNotFoundError as e:
|
|
4074
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4075
|
+
raise typer.Exit(1) from None
|
|
4076
|
+
|
|
4077
|
+
|
|
4078
|
+
@targets_app.command("pool-status")
|
|
4079
|
+
def targets_pool_status(
|
|
4080
|
+
name: str = typer.Argument(..., help="Pool name"),
|
|
4081
|
+
) -> None:
|
|
4082
|
+
"""Show status of targets in a pool (locked/available).
|
|
4083
|
+
|
|
4084
|
+
Example:
|
|
4085
|
+
wafer config targets pool-status mi300x-pool
|
|
4086
|
+
"""
|
|
4087
|
+
from .target_lock import get_lock_holder, is_target_locked
|
|
4088
|
+
from .targets import get_pool
|
|
4089
|
+
|
|
4090
|
+
try:
|
|
4091
|
+
targets = get_pool(name)
|
|
4092
|
+
except FileNotFoundError as e:
|
|
4093
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4094
|
+
raise typer.Exit(1) from None
|
|
4095
|
+
|
|
4096
|
+
typer.echo(f"Pool '{name}' ({len(targets)} targets):\n")
|
|
4097
|
+
|
|
4098
|
+
available = 0
|
|
4099
|
+
for target_name in targets:
|
|
4100
|
+
locked = is_target_locked(target_name)
|
|
4101
|
+
if locked:
|
|
4102
|
+
pid = get_lock_holder(target_name)
|
|
4103
|
+
pid_str = f" (pid {pid})" if pid else ""
|
|
4104
|
+
typer.echo(f" [busy] {target_name}{pid_str}")
|
|
4105
|
+
else:
|
|
4106
|
+
typer.echo(f" [free] {target_name}")
|
|
4107
|
+
available += 1
|
|
4108
|
+
|
|
4109
|
+
typer.echo("")
|
|
4110
|
+
typer.echo(f"Available: {available}/{len(targets)}")
|
|
4111
|
+
|
|
4112
|
+
|
|
3089
4113
|
# =============================================================================
|
|
3090
4114
|
# Billing commands
|
|
3091
4115
|
# =============================================================================
|
|
@@ -3119,7 +4143,9 @@ def billing_usage(
|
|
|
3119
4143
|
@billing_app.command("topup")
|
|
3120
4144
|
def billing_topup(
|
|
3121
4145
|
amount: int = typer.Argument(25, help="Amount in dollars ($10-$500)"),
|
|
3122
|
-
no_browser: bool = typer.Option(
|
|
4146
|
+
no_browser: bool = typer.Option(
|
|
4147
|
+
False, "--no-browser", help="Print URL instead of opening browser"
|
|
4148
|
+
),
|
|
3123
4149
|
) -> None:
|
|
3124
4150
|
"""Add credits to your account.
|
|
3125
4151
|
|
|
@@ -3165,7 +4191,9 @@ def billing_topup(
|
|
|
3165
4191
|
|
|
3166
4192
|
@billing_app.command("portal")
|
|
3167
4193
|
def billing_portal(
|
|
3168
|
-
no_browser: bool = typer.Option(
|
|
4194
|
+
no_browser: bool = typer.Option(
|
|
4195
|
+
False, "--no-browser", help="Print URL instead of opening browser"
|
|
4196
|
+
),
|
|
3169
4197
|
) -> None:
|
|
3170
4198
|
"""Open Stripe billing portal.
|
|
3171
4199
|
|
|
@@ -3198,6 +4226,81 @@ def billing_portal(
|
|
|
3198
4226
|
raise typer.Exit(1) from None
|
|
3199
4227
|
|
|
3200
4228
|
|
|
4229
|
+
# =============================================================================
|
|
4230
|
+
# SSH Keys commands (BYOK - Bring Your Own Key)
|
|
4231
|
+
# =============================================================================
|
|
4232
|
+
|
|
4233
|
+
|
|
4234
|
+
@ssh_keys_app.command("list")
|
|
4235
|
+
def ssh_keys_list(
|
|
4236
|
+
json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
|
|
4237
|
+
) -> None:
|
|
4238
|
+
"""List all registered SSH public keys.
|
|
4239
|
+
|
|
4240
|
+
Example:
|
|
4241
|
+
wafer ssh-keys list
|
|
4242
|
+
wafer ssh-keys list --json
|
|
4243
|
+
"""
|
|
4244
|
+
from .ssh_keys import list_ssh_keys
|
|
4245
|
+
|
|
4246
|
+
try:
|
|
4247
|
+
result = list_ssh_keys(json_output=json_output)
|
|
4248
|
+
typer.echo(result)
|
|
4249
|
+
except RuntimeError as e:
|
|
4250
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4251
|
+
raise typer.Exit(1) from e
|
|
4252
|
+
|
|
4253
|
+
|
|
4254
|
+
@ssh_keys_app.command("add")
|
|
4255
|
+
def ssh_keys_add(
|
|
4256
|
+
pubkey_path: Path | None = typer.Argument(
|
|
4257
|
+
None, help="Path to public key file (auto-detects ~/.ssh/id_ed25519.pub if not specified)"
|
|
4258
|
+
),
|
|
4259
|
+
name: str | None = typer.Option(None, "--name", "-n", help="Friendly name for the key"),
|
|
4260
|
+
json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
|
|
4261
|
+
) -> None:
|
|
4262
|
+
"""Add an SSH public key.
|
|
4263
|
+
|
|
4264
|
+
If no path is specified, auto-detects keys from ~/.ssh/ in preference order:
|
|
4265
|
+
id_ed25519.pub, id_rsa.pub, id_ecdsa.pub.
|
|
4266
|
+
|
|
4267
|
+
Example:
|
|
4268
|
+
wafer ssh-keys add # Auto-detect
|
|
4269
|
+
wafer ssh-keys add ~/.ssh/id_rsa.pub # Specific file
|
|
4270
|
+
wafer ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
|
|
4271
|
+
"""
|
|
4272
|
+
from .ssh_keys import add_ssh_key
|
|
4273
|
+
|
|
4274
|
+
try:
|
|
4275
|
+
result = add_ssh_key(pubkey_path=pubkey_path, name=name, json_output=json_output)
|
|
4276
|
+
typer.echo(result)
|
|
4277
|
+
except RuntimeError as e:
|
|
4278
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4279
|
+
raise typer.Exit(1) from e
|
|
4280
|
+
|
|
4281
|
+
|
|
4282
|
+
@ssh_keys_app.command("remove")
|
|
4283
|
+
def ssh_keys_remove(
|
|
4284
|
+
key_id: str = typer.Argument(..., help="UUID of the SSH key to remove"),
|
|
4285
|
+
json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
|
|
4286
|
+
) -> None:
|
|
4287
|
+
"""Remove an SSH public key.
|
|
4288
|
+
|
|
4289
|
+
Get the key ID from 'wafer ssh-keys list'.
|
|
4290
|
+
|
|
4291
|
+
Example:
|
|
4292
|
+
wafer ssh-keys remove abc123-def456-...
|
|
4293
|
+
"""
|
|
4294
|
+
from .ssh_keys import remove_ssh_key
|
|
4295
|
+
|
|
4296
|
+
try:
|
|
4297
|
+
result = remove_ssh_key(key_id=key_id, json_output=json_output)
|
|
4298
|
+
typer.echo(result)
|
|
4299
|
+
except RuntimeError as e:
|
|
4300
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4301
|
+
raise typer.Exit(1) from e
|
|
4302
|
+
|
|
4303
|
+
|
|
3201
4304
|
# =============================================================================
|
|
3202
4305
|
# Workspaces commands
|
|
3203
4306
|
# =============================================================================
|
|
@@ -3226,21 +4329,34 @@ def workspaces_list(
|
|
|
3226
4329
|
@workspaces_app.command("create")
|
|
3227
4330
|
def workspaces_create(
|
|
3228
4331
|
name: str = typer.Argument(..., help="Workspace name"),
|
|
3229
|
-
gpu_type: str = typer.Option("B200", "--gpu", "-g", help="GPU type (
|
|
4332
|
+
gpu_type: str = typer.Option("B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"),
|
|
3230
4333
|
image: str | None = typer.Option(None, "--image", "-i", help="Docker image (optional)"),
|
|
4334
|
+
wait: bool = typer.Option(False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"),
|
|
3231
4335
|
json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
|
|
3232
4336
|
) -> None:
|
|
3233
4337
|
"""Create a new workspace.
|
|
3234
4338
|
|
|
4339
|
+
Available GPUs:
|
|
4340
|
+
MI300X AMD Instinct MI300X (192GB HBM3, ROCm)
|
|
4341
|
+
B200 NVIDIA Blackwell B200 (180GB HBM3e, CUDA)
|
|
4342
|
+
|
|
3235
4343
|
Example:
|
|
3236
|
-
wafer workspaces create my-kernel
|
|
3237
|
-
wafer workspaces create my-kernel --gpu
|
|
4344
|
+
wafer workspaces create my-kernel # B200 (default)
|
|
4345
|
+
wafer workspaces create my-kernel --gpu MI300X # AMD MI300X
|
|
4346
|
+
wafer workspaces create my-kernel --gpu B200 # NVIDIA B200
|
|
3238
4347
|
wafer workspaces create my-kernel --image pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
|
|
4348
|
+
wafer workspaces create my-kernel --wait
|
|
3239
4349
|
"""
|
|
3240
4350
|
from .workspaces import create_workspace
|
|
3241
4351
|
|
|
3242
4352
|
try:
|
|
3243
|
-
result = create_workspace(
|
|
4353
|
+
result = create_workspace(
|
|
4354
|
+
name,
|
|
4355
|
+
gpu_type=gpu_type,
|
|
4356
|
+
image=image,
|
|
4357
|
+
wait=wait,
|
|
4358
|
+
json_output=json_output,
|
|
4359
|
+
)
|
|
3244
4360
|
typer.echo(result)
|
|
3245
4361
|
except RuntimeError as e:
|
|
3246
4362
|
typer.echo(f"Error: {e}", err=True)
|
|
@@ -3250,16 +4366,23 @@ def workspaces_create(
|
|
|
3250
4366
|
@workspaces_app.command("delete")
|
|
3251
4367
|
def workspaces_delete(
|
|
3252
4368
|
workspace_id: str = typer.Argument(..., help="Workspace ID to delete"),
|
|
4369
|
+
yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
|
|
3253
4370
|
json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
|
|
3254
4371
|
) -> None:
|
|
3255
4372
|
"""Delete a workspace.
|
|
3256
4373
|
|
|
3257
4374
|
Example:
|
|
3258
4375
|
wafer workspaces delete ws_abc123
|
|
4376
|
+
wafer workspaces delete ws_abc123 -y
|
|
3259
4377
|
"""
|
|
3260
4378
|
from .workspaces import delete_workspace
|
|
3261
4379
|
|
|
3262
4380
|
try:
|
|
4381
|
+
if not yes:
|
|
4382
|
+
confirm = typer.confirm(f"Delete workspace '{workspace_id}'?")
|
|
4383
|
+
if not confirm:
|
|
4384
|
+
typer.echo("Cancelled.")
|
|
4385
|
+
raise typer.Exit(0)
|
|
3263
4386
|
result = delete_workspace(workspace_id, json_output=json_output)
|
|
3264
4387
|
typer.echo(result)
|
|
3265
4388
|
except RuntimeError as e:
|
|
@@ -3267,32 +4390,6 @@ def workspaces_delete(
|
|
|
3267
4390
|
raise typer.Exit(1) from None
|
|
3268
4391
|
|
|
3269
4392
|
|
|
3270
|
-
@workspaces_app.command("attach")
|
|
3271
|
-
def workspaces_attach(
|
|
3272
|
-
workspace_id: str = typer.Argument(..., help="Workspace ID to attach to"),
|
|
3273
|
-
json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
|
|
3274
|
-
) -> None:
|
|
3275
|
-
"""Attach to a workspace (get SSH credentials).
|
|
3276
|
-
|
|
3277
|
-
This will:
|
|
3278
|
-
1. Start the workspace if needed
|
|
3279
|
-
2. Return SSH connection details
|
|
3280
|
-
3. Save the private key to ~/.wafer/keys/
|
|
3281
|
-
|
|
3282
|
-
Example:
|
|
3283
|
-
wafer workspaces attach ws_abc123
|
|
3284
|
-
wafer workspaces attach ws_abc123 --json
|
|
3285
|
-
"""
|
|
3286
|
-
from .workspaces import attach_workspace
|
|
3287
|
-
|
|
3288
|
-
try:
|
|
3289
|
-
result = attach_workspace(workspace_id, json_output=json_output)
|
|
3290
|
-
typer.echo(result)
|
|
3291
|
-
except RuntimeError as e:
|
|
3292
|
-
typer.echo(f"Error: {e}", err=True)
|
|
3293
|
-
raise typer.Exit(1) from None
|
|
3294
|
-
|
|
3295
|
-
|
|
3296
4393
|
@workspaces_app.command("show")
|
|
3297
4394
|
def workspaces_show(
|
|
3298
4395
|
workspace_id: str = typer.Argument(..., help="Workspace ID to show"),
|
|
@@ -3314,12 +4411,19 @@ def workspaces_show(
|
|
|
3314
4411
|
raise typer.Exit(1) from None
|
|
3315
4412
|
|
|
3316
4413
|
|
|
3317
|
-
@workspaces_app.command(
|
|
4414
|
+
@workspaces_app.command(
|
|
4415
|
+
"exec",
|
|
4416
|
+
context_settings={
|
|
4417
|
+
"allow_interspersed_args": False,
|
|
4418
|
+
"ignore_unknown_options": True,
|
|
4419
|
+
"allow_extra_args": True,
|
|
4420
|
+
},
|
|
4421
|
+
)
|
|
3318
4422
|
def workspaces_exec(
|
|
4423
|
+
ctx: typer.Context,
|
|
3319
4424
|
workspace: str | None = typer.Argument(
|
|
3320
4425
|
None, help="Workspace name or ID (optional if default set)"
|
|
3321
4426
|
),
|
|
3322
|
-
command: list[str] = typer.Argument(..., help="Command to execute on GPU"),
|
|
3323
4427
|
timeout: int | None = typer.Option(
|
|
3324
4428
|
None,
|
|
3325
4429
|
"--timeout",
|
|
@@ -3332,17 +4436,30 @@ def workspaces_exec(
|
|
|
3332
4436
|
"-s",
|
|
3333
4437
|
help="Sync local directory to workspace before executing",
|
|
3334
4438
|
),
|
|
4439
|
+
gpu: bool = typer.Option(False, "--gpu", help="Force GPU routing (default behavior)"),
|
|
4440
|
+
cpu: bool = typer.Option(False, "--cpu", help="Run in workspace container (no GPU)"),
|
|
4441
|
+
baremetal: bool = typer.Option(
|
|
4442
|
+
False, "--baremetal", help="Force baremetal target (for hardware counters like ncu/nsys)"
|
|
4443
|
+
),
|
|
4444
|
+
pull_image: bool = typer.Option(False, "--pull-image", help="Pull image on target if missing"),
|
|
3335
4445
|
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
|
|
3336
4446
|
quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
|
|
3337
4447
|
) -> None:
|
|
3338
|
-
"""Execute a command in workspace
|
|
4448
|
+
"""Execute a command in workspace.
|
|
4449
|
+
|
|
4450
|
+
By default, auto-detects whether to route to GPU based on the command.
|
|
4451
|
+
Use --gpu, --cpu, or --baremetal to override.
|
|
3339
4452
|
|
|
3340
|
-
|
|
3341
|
-
|
|
4453
|
+
Routing options:
|
|
4454
|
+
--gpu Force GPU container (Modal or baremetal with GPU)
|
|
4455
|
+
--cpu Run in workspace container directly (no GPU)
|
|
4456
|
+
--baremetal Force baremetal target (for ncu, nsys, hardware counters)
|
|
3342
4457
|
|
|
3343
4458
|
If workspace is not specified, uses the default workspace from config,
|
|
3344
4459
|
or the only workspace if you have exactly one.
|
|
3345
4460
|
|
|
4461
|
+
IMPORTANT: Options must come before the workspace name.
|
|
4462
|
+
|
|
3346
4463
|
Examples:
|
|
3347
4464
|
wafer workspaces exec dev -- python train.py
|
|
3348
4465
|
wafer workspaces exec dev -- python -c "import torch; print(torch.cuda.is_available())"
|
|
@@ -3353,6 +4470,49 @@ def workspaces_exec(
|
|
|
3353
4470
|
from .global_config import get_defaults, get_preferences
|
|
3354
4471
|
from .workspaces import exec_command, resolve_workspace, sync_files
|
|
3355
4472
|
|
|
4473
|
+
# Enforce option ordering to avoid treating CLI flags as remote commands
|
|
4474
|
+
known_options = {
|
|
4475
|
+
"--timeout",
|
|
4476
|
+
"-t",
|
|
4477
|
+
"--sync",
|
|
4478
|
+
"-s",
|
|
4479
|
+
"--gpu",
|
|
4480
|
+
"--cpu",
|
|
4481
|
+
"--baremetal",
|
|
4482
|
+
"--pull-image",
|
|
4483
|
+
"--verbose",
|
|
4484
|
+
"-v",
|
|
4485
|
+
"--quiet",
|
|
4486
|
+
"-q",
|
|
4487
|
+
"--help",
|
|
4488
|
+
"-h",
|
|
4489
|
+
}
|
|
4490
|
+
for arg in ctx.args:
|
|
4491
|
+
if arg == "--":
|
|
4492
|
+
break
|
|
4493
|
+
if arg in known_options:
|
|
4494
|
+
typer.echo(
|
|
4495
|
+
"Error: options must come before the workspace name. "
|
|
4496
|
+
"Example: wafer workspaces exec --pull-image dev -- python -V",
|
|
4497
|
+
err=True,
|
|
4498
|
+
)
|
|
4499
|
+
raise typer.Exit(1)
|
|
4500
|
+
|
|
4501
|
+
# Validate mutually exclusive routing flags
|
|
4502
|
+
routing_flags = sum([gpu, cpu, baremetal])
|
|
4503
|
+
if routing_flags > 1:
|
|
4504
|
+
typer.echo("Error: --gpu, --cpu, and --baremetal are mutually exclusive", err=True)
|
|
4505
|
+
raise typer.Exit(1)
|
|
4506
|
+
|
|
4507
|
+
# Determine routing (None = auto-detect)
|
|
4508
|
+
routing: str | None = None
|
|
4509
|
+
if gpu:
|
|
4510
|
+
routing = "gpu"
|
|
4511
|
+
elif cpu:
|
|
4512
|
+
routing = "cpu"
|
|
4513
|
+
elif baremetal:
|
|
4514
|
+
routing = "baremetal"
|
|
4515
|
+
|
|
3356
4516
|
# Resolve workspace (specified, config default, or single workspace)
|
|
3357
4517
|
try:
|
|
3358
4518
|
resolved_workspace = resolve_workspace(workspace)
|
|
@@ -3377,7 +4537,8 @@ def workspaces_exec(
|
|
|
3377
4537
|
show_status = prefs.mode == "explicit"
|
|
3378
4538
|
|
|
3379
4539
|
if show_status:
|
|
3380
|
-
|
|
4540
|
+
routing_label = routing or "auto"
|
|
4541
|
+
typer.echo(f"[wafer] Workspace: {resolved_workspace} (routing: {routing_label})", err=True)
|
|
3381
4542
|
|
|
3382
4543
|
# Sync files if requested
|
|
3383
4544
|
if sync is not None:
|
|
@@ -3403,114 +4564,617 @@ def workspaces_exec(
|
|
|
3403
4564
|
typer.echo(f"Error: {e}", err=True)
|
|
3404
4565
|
raise typer.Exit(1) from None
|
|
3405
4566
|
|
|
4567
|
+
# Get command from context args (passthrough after --)
|
|
4568
|
+
import shlex
|
|
4569
|
+
|
|
4570
|
+
command = list(ctx.args)
|
|
4571
|
+
if command and command[0] == "--":
|
|
4572
|
+
command = command[1:]
|
|
4573
|
+
|
|
4574
|
+
if not command:
|
|
4575
|
+
typer.echo("Error: No command specified", err=True)
|
|
4576
|
+
raise typer.Exit(1)
|
|
4577
|
+
|
|
3406
4578
|
if show_status:
|
|
3407
4579
|
typer.echo(f"[wafer] Executing (timeout: {effective_timeout}s)...", err=True)
|
|
3408
4580
|
|
|
3409
|
-
#
|
|
4581
|
+
# Build command string
|
|
4582
|
+
# Handle two cases:
|
|
4583
|
+
# 1. Single element: user quoted the whole command (e.g., "echo hello world")
|
|
4584
|
+
# -> use directly, don't re-quote
|
|
4585
|
+
# 2. Multiple elements: user passed separate args (e.g., -- python -c "print(1)")
|
|
4586
|
+
# -> use shlex.join to properly quote args with spaces
|
|
4587
|
+
if len(command) == 1:
|
|
4588
|
+
command_str = command[0]
|
|
4589
|
+
else:
|
|
4590
|
+
command_str = shlex.join(command)
|
|
4591
|
+
|
|
4592
|
+
try:
|
|
4593
|
+
exit_code = exec_command(
|
|
4594
|
+
workspace_id=resolved_workspace,
|
|
4595
|
+
command=command_str,
|
|
4596
|
+
timeout_seconds=effective_timeout,
|
|
4597
|
+
routing=routing,
|
|
4598
|
+
pull_image=pull_image,
|
|
4599
|
+
)
|
|
4600
|
+
except RuntimeError as e:
|
|
4601
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4602
|
+
raise typer.Exit(1) from None
|
|
4603
|
+
|
|
4604
|
+
if show_status:
|
|
4605
|
+
typer.echo(f"[wafer] Exit code: {exit_code}", err=True)
|
|
4606
|
+
|
|
4607
|
+
raise typer.Exit(exit_code)
|
|
4608
|
+
|
|
4609
|
+
|
|
4610
|
+
@workspaces_app.command("ssh")
def workspaces_ssh(
    workspace: str | None = typer.Argument(
        None, help="Workspace name or ID (optional if default set)"
    ),
) -> None:
    """SSH into a workspace.

    Uses workspace SSH credentials once the workspace is running.
    If workspace is not specified, uses the default workspace.

    Examples:
        wafer workspaces ssh dev
        wafer workspaces ssh  # uses default workspace
    """
    import os

    from .workspaces import VALID_STATUSES, get_workspace_raw, resolve_workspace

    # Resolve workspace (explicit argument or configured default)
    try:
        resolved_workspace = resolve_workspace(workspace)
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    typer.echo(f"Connecting to workspace: {resolved_workspace}...", err=True)

    # Fetch the raw workspace record; it carries the status and SSH fields.
    try:
        ws = get_workspace_raw(resolved_workspace)
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    # The status comes from an external response, so validate it explicitly
    # instead of with `assert`: asserts are stripped under `python -O`, and a
    # failing assert shows the user a raw AssertionError traceback.
    workspace_status = ws.get("status")
    if workspace_status not in VALID_STATUSES:
        typer.echo(
            f"Error: Workspace {resolved_workspace} has invalid status '{workspace_status}'. "
            f"Valid statuses: {VALID_STATUSES}",
            err=True,
        )
        raise typer.Exit(1)

    if workspace_status != "running":
        typer.echo(f"Error: Workspace is {workspace_status}. Wait for it to be running.", err=True)
        raise typer.Exit(1)

    ssh_host = ws.get("ssh_host")
    ssh_port = ws.get("ssh_port")
    ssh_user = ws.get("ssh_user")
    if not ssh_host or not ssh_port or not ssh_user:
        typer.echo("Error: SSH credentials not available yet.", err=True)
        raise typer.Exit(1)

    # No "-i" option is passed: ssh falls back to the user's own default key /
    # agent configuration (bring-your-own-key model).
    ssh_args = [
        "ssh",
        "-p",
        str(ssh_port),
        "-o",
        "StrictHostKeyChecking=no",
        "-o",
        "UserKnownHostsFile=/dev/null",
        f"{ssh_user}@{ssh_host}",
    ]

    # Replace the current process with ssh so the session is fully interactive.
    # execvp only returns by raising, e.g. when no `ssh` binary is on PATH.
    try:
        os.execvp("ssh", ssh_args)
    except OSError as e:
        typer.echo(f"Error: failed to launch ssh: {e}", err=True)
        raise typer.Exit(1) from None
|
|
4674
|
+
|
|
4675
|
+
|
|
4676
|
+
@workspaces_app.command("sync")
def workspaces_sync(
    workspace: str | None = typer.Argument(
        None, help="Workspace name or ID (optional if default set)"
    ),
    path: Path | None = typer.Argument(None, help="Local file or directory to sync"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
    quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
) -> None:
    """Sync local files to workspace.

    Uses rsync over SSH to sync files to the workspace's /workspace directory.
    If workspace is not specified, uses the default workspace.

    Examples:
        wafer workspaces sync dev ./my-project
        wafer workspaces sync ./my-project  # uses default workspace
        wafer workspaces sync dev .  # sync current directory
        wafer workspaces sync dev ./script.py  # sync single file
    """
    from .global_config import get_preferences
    from .workspaces import resolve_workspace, sync_files

    # Click assigns positional values left to right, so with a single positional
    # (`wafer workspaces sync ./my-project`) the value lands in WORKSPACE and
    # PATH stays empty. Treat that lone value as the path and fall back to the
    # default workspace, as documented above.
    if path is None:
        if workspace is None:
            typer.echo("Error: Missing argument 'PATH'", err=True)
            typer.echo("Usage: wafer workspaces sync [WORKSPACE] PATH", err=True)
            raise typer.Exit(1)
        path = Path(workspace)
        workspace = None

    # Determine verbosity: --quiet wins, then --verbose, then the configured mode.
    prefs = get_preferences()
    if quiet:
        show_status = False
    elif verbose:
        show_status = True
    else:
        show_status = prefs.mode == "explicit"

    # Validate path
    if not path.exists():
        typer.echo(f"Error: Path not found: {path}", err=True)
        raise typer.Exit(1)

    # Resolve workspace
    try:
        resolved_workspace = resolve_workspace(workspace)
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    if show_status:
        typer.echo(f"[wafer] Syncing {path} to workspace {resolved_workspace}...", err=True)

    def on_progress(msg: str) -> None:
        if show_status:
            typer.echo(f"[wafer] {msg}", err=True)

    try:
        file_count, warning = sync_files(
            resolved_workspace, path.resolve(), on_progress=on_progress
        )
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    # Surface the sync result instead of discarding it. A warning is not a
    # status message, so it is shown even with --quiet.
    # NOTE(review): assumes `warning` is an optional message (None/"" when
    # there is nothing to report) — confirm against sync_files.
    if warning:
        typer.echo(f"Warning: {warning}", err=True)
    if show_status:
        typer.echo(f"[wafer] Done. Synced {file_count} files.", err=True)
|
|
4734
|
+
|
|
4735
|
+
|
|
4736
|
+
# =============================================================================
|
|
4737
|
+
# Target operations commands (exec/ssh/sync/scp/ensure)
|
|
4738
|
+
# =============================================================================
|
|
4739
|
+
|
|
4740
|
+
|
|
4741
|
+
@targets_ops_app.command("exec", context_settings={"allow_interspersed_args": False})
def targets_exec(
    target: str = typer.Argument(
        ...,
        help="Target name",
        autocompletion=complete_target_name,
    ),
    command: list[str] = typer.Argument(..., help="Command to execute"),
    timeout: int | None = typer.Option(
        None,
        "--timeout",
        "-t",
        help="Execution timeout in seconds (default: 300)",
    ),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
    quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
) -> None:
    """Execute a command on a configured target.

    Provisions the target if needed (RunPod, DigitalOcean), then runs the command via SSH.
    For cloud targets, the instance is kept alive after execution - use
    'wafer config targets cleanup <name>' to terminate.

    Supported targets: RunPod, DigitalOcean, SSH (baremetal/vm).
    Not supported: Modal (serverless), Local (no SSH), Workspace (use 'wafer workspaces exec').

    IMPORTANT: Options must come before the target name. Everything after the
    target name is passed through as part of the command.

    Examples:
        wafer targets exec runpod-mi300x -- python -c "import torch; print(torch.cuda.is_available())"
        wafer targets exec runpod-mi300x -- rocm-smi
        wafer targets exec my-ssh-server -- nvidia-smi
        wafer targets exec --timeout 60 runpod-mi300x "echo hello && ls -la"
    """
    import shlex

    from .global_config import get_preferences
    from .targets import load_target
    from .targets_ops import TargetExecError, exec_on_target_sync, get_target_ssh_info

    # Determine verbosity: --quiet wins, then --verbose, then the configured mode.
    prefs = get_preferences()
    if quiet:
        show_status = False
    elif verbose:
        show_status = True
    else:
        show_status = prefs.mode == "explicit"

    # Build the remote command string before touching the target, so a missing
    # command is reported without first provisioning a cloud instance.
    # With interspersed args disabled, a "--" typed after TARGET reaches us verbatim.
    argv = list(command)
    if argv and argv[0] == "--":
        argv = argv[1:]
    if not argv:
        typer.echo("Error: No command specified", err=True)
        raise typer.Exit(1)
    # One element: the user quoted the whole shell command, so send it as-is.
    # Several elements: quote each one so arguments containing spaces survive.
    command_str = argv[0] if len(argv) == 1 else shlex.join(argv)

    # Load target
    try:
        target_config = load_target(target)
    except FileNotFoundError as e:
        typer.echo(f"Error: {e}", err=True)
        typer.echo("List available targets with: wafer config targets list", err=True)
        raise typer.Exit(1) from None
    except ValueError as e:
        typer.echo(f"Error loading target config: {e}", err=True)
        raise typer.Exit(1) from None

    if show_status:
        typer.echo(f"[wafer] Target: {target} ({type(target_config).__name__})", err=True)
        typer.echo("[wafer] Connecting to target...", err=True)

    # Get SSH info (may provision)
    try:
        ssh_info = trio.run(get_target_ssh_info, target_config)
    except TargetExecError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    if show_status:
        typer.echo(f"[wafer] Connected: {ssh_info.user}@{ssh_info.host}:{ssh_info.port}", err=True)

    # Default timeout
    effective_timeout = timeout if timeout is not None else 300

    if show_status:
        typer.echo(f"[wafer] Executing (timeout: {effective_timeout}s)...", err=True)

    # Execute
    try:
        exit_code = exec_on_target_sync(ssh_info, command_str, effective_timeout)
    except TargetExecError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    if show_status:
        typer.echo(f"[wafer] Exit code: {exit_code}", err=True)

    # Propagate the remote command's exit status as our own.
    raise typer.Exit(exit_code)
|
|
4849
|
+
|
|
4850
|
+
|
|
4851
|
+
@targets_ops_app.command("ssh")
def targets_ssh(
    target: str = typer.Argument(
        ...,
        help="Target name",
        autocompletion=complete_target_name,
    ),
) -> None:
    """SSH into a configured target.

    Provisions the target if needed (RunPod, DigitalOcean), then starts an interactive SSH session.
    For cloud targets, the instance is kept alive - use 'wafer config targets cleanup <name>' to terminate.

    Examples:
        wafer targets ssh runpod-mi300x
        wafer targets ssh my-baremetal-server
    """
    # Imported locally (as in workspaces_ssh) so this command does not depend on
    # a module-level `os` import.
    import os

    from .targets import load_target
    from .targets_ops import TargetExecError, get_target_ssh_info

    # Load target
    try:
        target_config = load_target(target)
    except FileNotFoundError as e:
        typer.echo(f"Error: {e}", err=True)
        typer.echo("List available targets with: wafer config targets list", err=True)
        raise typer.Exit(1) from None
    except ValueError as e:
        typer.echo(f"Error loading target config: {e}", err=True)
        raise typer.Exit(1) from None

    typer.echo(f"Connecting to target: {target}...", err=True)

    # Get SSH info (may provision)
    try:
        ssh_info = trio.run(get_target_ssh_info, target_config)
    except TargetExecError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    # Build SSH command.
    # NOTE(review): key_path is presumably always set for provisioned targets;
    # "-i" is skipped when it is missing so ssh is never handed the literal
    # string "None" as an identity file.
    ssh_args = ["ssh"]
    if ssh_info.key_path:
        ssh_args += ["-i", str(ssh_info.key_path)]
    ssh_args += [
        "-p",
        str(ssh_info.port),
        "-o",
        "StrictHostKeyChecking=no",
        "-o",
        "UserKnownHostsFile=/dev/null",
        f"{ssh_info.user}@{ssh_info.host}",
    ]

    # Replace the current process with ssh. execvp only returns by raising,
    # e.g. when no `ssh` binary is on PATH.
    try:
        os.execvp("ssh", ssh_args)
    except OSError as e:
        typer.echo(f"Error: failed to launch ssh: {e}", err=True)
        raise typer.Exit(1) from None
|
|
4907
|
+
|
|
4908
|
+
|
|
4909
|
+
@targets_ops_app.command("sync")
def targets_sync(
    target: str = typer.Argument(
        ...,
        help="Target name",
        autocompletion=complete_target_name,
    ),
    path: Path = typer.Argument(..., help="Local file or directory to sync"),
    dest: str | None = typer.Option(
        None,
        "--dest",
        "-d",
        help="Remote destination path (default: /tmp/<basename>)",
    ),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
    quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
) -> None:
    """Sync local files to a configured target.

    Uses rsync over SSH to copy files to the target. Provisions the target if needed.

    Examples:
        wafer targets sync runpod-mi300x ./my-project
        wafer targets sync runpod-mi300x ./script.py --dest /workspace/script.py
        wafer targets sync my-server ./kernels --dest /tmp/kernels
    """
    from .global_config import get_preferences
    from .targets import load_target
    from .targets_ops import TargetExecError, get_target_ssh_info, sync_to_target

    # --quiet silences everything; otherwise --verbose or "explicit" mode enables
    # the [wafer] status lines.
    prefs = get_preferences()
    show_status = (not quiet) and (verbose or prefs.mode == "explicit")

    def report(message: str) -> None:
        """Print a [wafer]-prefixed status line to stderr when enabled.

        Also passed to sync_to_target as its progress callback.
        """
        if show_status:
            typer.echo(f"[wafer] {message}", err=True)

    # Fail fast on a bad local path before any target work happens.
    if not path.exists():
        typer.echo(f"Error: Path not found: {path}", err=True)
        raise typer.Exit(1)

    try:
        target_config = load_target(target)
    except FileNotFoundError as exc:
        typer.echo(f"Error: {exc}", err=True)
        typer.echo("List available targets with: wafer config targets list", err=True)
        raise typer.Exit(1) from None
    except ValueError as exc:
        typer.echo(f"Error loading target config: {exc}", err=True)
        raise typer.Exit(1) from None

    report(f"Target: {target} ({type(target_config).__name__})")
    report("Connecting to target...")

    # Obtaining SSH details may provision a cloud instance.
    try:
        ssh_info = trio.run(get_target_ssh_info, target_config)
    except TargetExecError as exc:
        typer.echo(f"Error: {exc}", err=True)
        raise typer.Exit(1) from None

    report(f"Connected: {ssh_info.user}@{ssh_info.host}:{ssh_info.port}")

    try:
        synced = sync_to_target(ssh_info, path.resolve(), dest, report)
    except TargetExecError as exc:
        typer.echo(f"Error: {exc}", err=True)
        raise typer.Exit(1) from None

    report(f"Done. Synced {synced} files.")
|
|
3435
4993
|
|
|
3436
4994
|
|
|
3437
|
-
@
|
|
3438
|
-
def
|
|
3439
|
-
|
|
3440
|
-
|
|
3441
|
-
),
|
|
4995
|
+
@targets_ops_app.command("scp")
def targets_scp(
    source: str = typer.Argument(..., help="Source path (prefix with target: for remote)"),
    dest: str = typer.Argument(..., help="Destination path (prefix with target: for remote)"),
    recursive: bool = typer.Option(False, "-r", "--recursive", help="Copy directories recursively"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
    quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
) -> None:
    """Copy files to/from a target using scp-style syntax.

    Use target: prefix to indicate remote paths. Exactly one of source or dest
    must be remote.

    Examples:
        wafer targets scp runpod-mi300x:/tmp/trace.json ./trace.json  # download
        wafer targets scp ./script.py runpod-mi300x:/tmp/script.py  # upload
        wafer targets scp -r ./kernels runpod-mi300x:/tmp/kernels  # upload dir
        wafer targets scp -r runpod-mi300x:/tmp/results ./results  # download dir
    """
    from .global_config import get_preferences
    from .targets import load_target
    from .targets_ops import TargetExecError, get_target_ssh_info, parse_scp_path, scp_transfer

    # Determine verbosity: --quiet wins, then --verbose, then the configured mode.
    prefs = get_preferences()
    if quiet:
        show_status = False
    elif verbose:
        show_status = True
    else:
        show_status = prefs.mode == "explicit"

    # Split "target:path" specs; a local path yields a falsy target part.
    source_target, source_path = parse_scp_path(source)
    dest_target, dest_path = parse_scp_path(dest)

    # Validate: exactly one side must be remote
    if source_target and dest_target:
        typer.echo("Error: Both paths are remote. Use ssh to transfer between remotes.", err=True)
        raise typer.Exit(1)

    if not source_target and not dest_target:
        typer.echo("Error: Both paths are local. Use regular cp command.", err=True)
        raise typer.Exit(1)

    # Use the same truthiness test as the validation above, so the transfer
    # direction can never disagree with it (e.g. if parse_scp_path reports a
    # local path with "" rather than None).
    is_download = bool(source_target)
    target_name = source_target if is_download else dest_target

    # Load target
    try:
        target_config = load_target(target_name)
    except FileNotFoundError:
        typer.echo(f"Error: Target '{target_name}' not found.", err=True)
        typer.echo("Run 'wafer config targets list' to see available targets.", err=True)
        raise typer.Exit(1) from None
    except ValueError as e:
        typer.echo(f"Error loading target config: {e}", err=True)
        raise typer.Exit(1) from None

    # Validate local path exists (for upload)
    if not is_download:
        local_path = Path(source_path)
        if not local_path.exists():
            typer.echo(f"Error: Local path '{source_path}' does not exist.", err=True)
            raise typer.Exit(1)
        if local_path.is_dir() and not recursive:
            typer.echo(
                f"Error: '{source_path}' is a directory. Use -r flag for recursive copy.", err=True
            )
            raise typer.Exit(1)

    if show_status:
        typer.echo(f"[wafer] Target: {target_name} ({type(target_config).__name__})", err=True)
        typer.echo("[wafer] Connecting to target...", err=True)

    # Get SSH info (may provision)
    try:
        ssh_info = trio.run(get_target_ssh_info, target_config)
    except TargetExecError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    if show_status:
        typer.echo(f"[wafer] Connected: {ssh_info.user}@{ssh_info.host}:{ssh_info.port}", err=True)
        direction = "Downloading" if is_download else "Uploading"
        typer.echo(f"[wafer] {direction}...", err=True)

    # Transfer; the two directions differ only in the is_download flag.
    try:
        scp_transfer(
            ssh_info, source_path, dest_path, is_download=is_download, recursive=recursive
        )
    except TargetExecError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    if show_status:
        typer.echo("[wafer] Done.", err=True)
|
|
3489
5095
|
|
|
3490
|
-
|
|
3491
|
-
|
|
3492
|
-
|
|
3493
|
-
|
|
5096
|
+
|
|
5097
|
+
@targets_ops_app.command("ensure")
|
|
5098
|
+
def targets_ensure( # noqa: PLR0915
|
|
5099
|
+
target: str = typer.Argument(
|
|
5100
|
+
None,
|
|
5101
|
+
help="Target name",
|
|
5102
|
+
autocompletion=complete_target_name,
|
|
3494
5103
|
),
|
|
3495
|
-
|
|
5104
|
+
tool: str = typer.Argument(None, help="Tool to ensure is installed"),
|
|
5105
|
+
check_only: bool = typer.Option(False, "--check-only", "-c", help="Only check, don't install"),
|
|
5106
|
+
force: bool = typer.Option(False, "--force", "-f", help="Reinstall even if present"),
|
|
5107
|
+
list_tools: bool = typer.Option(False, "--list", "-l", help="List available tools"),
|
|
5108
|
+
timeout: int = typer.Option(300, "--timeout", "-t", help="Installation timeout in seconds"),
|
|
3496
5109
|
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
|
|
3497
5110
|
quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
|
|
3498
5111
|
) -> None:
|
|
3499
|
-
"""
|
|
5112
|
+
"""Ensure a tool is installed on a target.
|
|
3500
5113
|
|
|
3501
|
-
|
|
3502
|
-
|
|
5114
|
+
Checks if a tool exists on the target and installs it if missing.
|
|
5115
|
+
Useful for profiling tools like rocprof-compute that aren't pre-installed.
|
|
3503
5116
|
|
|
3504
5117
|
Examples:
|
|
3505
|
-
wafer
|
|
3506
|
-
wafer
|
|
3507
|
-
wafer
|
|
3508
|
-
wafer
|
|
5118
|
+
wafer targets ensure runpod-mi300x rocprof-compute
|
|
5119
|
+
wafer targets ensure runpod-mi300x rocprof-compute --check-only
|
|
5120
|
+
wafer targets ensure runpod-mi300x rocprof-compute --force
|
|
5121
|
+
wafer targets ensure --list
|
|
3509
5122
|
"""
|
|
3510
5123
|
from .global_config import get_preferences
|
|
3511
|
-
from .
|
|
5124
|
+
from .targets import load_target
|
|
5125
|
+
from .targets_ops import (
|
|
5126
|
+
TOOL_REGISTRY,
|
|
5127
|
+
TargetExecError,
|
|
5128
|
+
ensure_tool,
|
|
5129
|
+
get_target_platform,
|
|
5130
|
+
get_target_ssh_info,
|
|
5131
|
+
)
|
|
3512
5132
|
|
|
3513
|
-
#
|
|
5133
|
+
# Handle --list flag
|
|
5134
|
+
if list_tools:
|
|
5135
|
+
typer.echo("Available tools:\n")
|
|
5136
|
+
typer.echo("AMD tools:")
|
|
5137
|
+
for name, spec in sorted(TOOL_REGISTRY.items()):
|
|
5138
|
+
if spec.platform == "amd":
|
|
5139
|
+
auto = "auto-install" if spec.install_cmd else "manual"
|
|
5140
|
+
typer.echo(f" {name:20} ({auto}) - {spec.description}")
|
|
5141
|
+
|
|
5142
|
+
typer.echo("\nNVIDIA tools:")
|
|
5143
|
+
for name, spec in sorted(TOOL_REGISTRY.items()):
|
|
5144
|
+
if spec.platform == "nvidia":
|
|
5145
|
+
auto = "auto-install" if spec.install_cmd else "manual"
|
|
5146
|
+
typer.echo(f" {name:20} ({auto}) - {spec.description}")
|
|
5147
|
+
|
|
5148
|
+
typer.echo("\nCross-platform:")
|
|
5149
|
+
for name, spec in sorted(TOOL_REGISTRY.items()):
|
|
5150
|
+
if spec.platform == "any":
|
|
5151
|
+
auto = "auto-install" if spec.install_cmd else "manual"
|
|
5152
|
+
typer.echo(f" {name:20} ({auto}) - {spec.description}")
|
|
5153
|
+
return
|
|
5154
|
+
|
|
5155
|
+
# Require target and tool if not listing
|
|
5156
|
+
if not target:
|
|
5157
|
+
typer.echo("Error: Missing argument 'TARGET'", err=True)
|
|
5158
|
+
typer.echo("Usage: wafer targets ensure TARGET TOOL", err=True)
|
|
5159
|
+
typer.echo(" or: wafer targets ensure --list", err=True)
|
|
5160
|
+
raise typer.Exit(1)
|
|
5161
|
+
|
|
5162
|
+
if not tool:
|
|
5163
|
+
typer.echo("Error: Missing argument 'TOOL'", err=True)
|
|
5164
|
+
typer.echo("Usage: wafer targets ensure TARGET TOOL", err=True)
|
|
5165
|
+
typer.echo(" or: wafer targets ensure --list", err=True)
|
|
5166
|
+
raise typer.Exit(1)
|
|
5167
|
+
|
|
5168
|
+
# Check tool exists
|
|
5169
|
+
if tool not in TOOL_REGISTRY:
|
|
5170
|
+
typer.echo(f"Error: Unknown tool '{tool}'", err=True)
|
|
5171
|
+
typer.echo(f"Available tools: {', '.join(sorted(TOOL_REGISTRY.keys()))}", err=True)
|
|
5172
|
+
typer.echo("Run 'wafer targets ensure --list' for details.", err=True)
|
|
5173
|
+
raise typer.Exit(1)
|
|
5174
|
+
|
|
5175
|
+
spec = TOOL_REGISTRY[tool]
|
|
5176
|
+
|
|
5177
|
+
# Determine verbosity
|
|
3514
5178
|
prefs = get_preferences()
|
|
3515
5179
|
if quiet:
|
|
3516
5180
|
show_status = False
|
|
@@ -3519,33 +5183,72 @@ def workspaces_sync(
|
|
|
3519
5183
|
else:
|
|
3520
5184
|
show_status = prefs.mode == "explicit"
|
|
3521
5185
|
|
|
3522
|
-
#
|
|
3523
|
-
if not path.exists():
|
|
3524
|
-
typer.echo(f"Error: Path not found: {path}", err=True)
|
|
3525
|
-
raise typer.Exit(1)
|
|
3526
|
-
|
|
3527
|
-
# Resolve workspace
|
|
5186
|
+
# Load target
|
|
3528
5187
|
try:
|
|
3529
|
-
|
|
3530
|
-
except
|
|
5188
|
+
target_config = load_target(target)
|
|
5189
|
+
except FileNotFoundError as e:
|
|
3531
5190
|
typer.echo(f"Error: {e}", err=True)
|
|
5191
|
+
typer.echo("List available targets with: wafer config targets list", err=True)
|
|
5192
|
+
raise typer.Exit(1) from None
|
|
5193
|
+
except ValueError as e:
|
|
5194
|
+
typer.echo(f"Error loading target config: {e}", err=True)
|
|
3532
5195
|
raise typer.Exit(1) from None
|
|
3533
5196
|
|
|
3534
|
-
|
|
3535
|
-
|
|
5197
|
+
# Platform validation
|
|
5198
|
+
platform = get_target_platform(target_config)
|
|
5199
|
+
if spec.platform != "any" and spec.platform != platform:
|
|
5200
|
+
typer.echo(
|
|
5201
|
+
f"Error: {tool} is an {spec.platform.upper()} tool but target '{target}' "
|
|
5202
|
+
f"is {platform.upper()}",
|
|
5203
|
+
err=True,
|
|
5204
|
+
)
|
|
5205
|
+
raise typer.Exit(1)
|
|
3536
5206
|
|
|
3537
|
-
|
|
3538
|
-
|
|
3539
|
-
|
|
5207
|
+
if show_status:
|
|
5208
|
+
typer.echo(f"[wafer] Target: {target} ({platform.upper()})", err=True)
|
|
5209
|
+
typer.echo(f"[wafer] Checking for {tool}...", err=True)
|
|
3540
5210
|
|
|
5211
|
+
# Get SSH info (may provision)
|
|
3541
5212
|
try:
|
|
3542
|
-
|
|
3543
|
-
|
|
3544
|
-
)
|
|
3545
|
-
except RuntimeError as e:
|
|
5213
|
+
ssh_info = trio.run(get_target_ssh_info, target_config)
|
|
5214
|
+
except TargetExecError as e:
|
|
3546
5215
|
typer.echo(f"Error: {e}", err=True)
|
|
3547
5216
|
raise typer.Exit(1) from None
|
|
3548
5217
|
|
|
5218
|
+
if show_status:
|
|
5219
|
+
typer.echo(f"[wafer] Connected: {ssh_info.user}@{ssh_info.host}:{ssh_info.port}", err=True)
|
|
5220
|
+
|
|
5221
|
+
# Check-only mode
|
|
5222
|
+
if check_only:
|
|
5223
|
+
from .targets_ops import TargetExecError, exec_on_target_sync
|
|
5224
|
+
|
|
5225
|
+
try:
|
|
5226
|
+
exit_code = exec_on_target_sync(ssh_info, spec.check_cmd, timeout_seconds=30)
|
|
5227
|
+
except TargetExecError as e:
|
|
5228
|
+
typer.echo(f"Error: {e}", err=True)
|
|
5229
|
+
raise typer.Exit(1) from None
|
|
5230
|
+
if exit_code == 0:
|
|
5231
|
+
typer.echo(f"{tool} is installed")
|
|
5232
|
+
else:
|
|
5233
|
+
typer.echo(f"{tool} is NOT installed", err=True)
|
|
5234
|
+
raise typer.Exit(1)
|
|
5235
|
+
return
|
|
5236
|
+
|
|
5237
|
+
# Ensure tool is installed
|
|
5238
|
+
result = ensure_tool(ssh_info, tool, force=force, timeout=timeout)
|
|
5239
|
+
|
|
5240
|
+
if result.error:
|
|
5241
|
+
typer.echo(f"Error: {result.error}", err=True)
|
|
5242
|
+
raise typer.Exit(1)
|
|
5243
|
+
|
|
5244
|
+
if result.already_installed:
|
|
5245
|
+
typer.echo(f"{tool} is already installed")
|
|
5246
|
+
elif result.installed:
|
|
5247
|
+
if result.verified:
|
|
5248
|
+
typer.echo(f"{tool} installed successfully")
|
|
5249
|
+
else:
|
|
5250
|
+
typer.echo(f"{tool} installed (verification skipped)")
|
|
5251
|
+
|
|
3549
5252
|
|
|
3550
5253
|
# =============================================================================
|
|
3551
5254
|
# Perfetto trace analysis commands
|
|
@@ -3830,13 +5533,39 @@ def ncu_analyze(
|
|
|
3830
5533
|
|
|
3831
5534
|
|
|
3832
5535
|
# =============================================================================
|
|
3833
|
-
# NSYS
|
|
5536
|
+
# NSYS commands
|
|
3834
5537
|
# =============================================================================
|
|
3835
5538
|
|
|
3836
5539
|
|
|
5540
|
+
@nsys_app.command("check")
|
|
5541
|
+
def nsys_check() -> None:
|
|
5542
|
+
"""Check if NSYS (Nsight Systems) is installed and show version.
|
|
5543
|
+
|
|
5544
|
+
NSYS is required for local analysis. If not installed, shows install instructions.
|
|
5545
|
+
|
|
5546
|
+
Examples:
|
|
5547
|
+
wafer nvidia nsys check
|
|
5548
|
+
"""
|
|
5549
|
+
from .nsys_analyze import check_nsys_installation
|
|
5550
|
+
|
|
5551
|
+
result = check_nsys_installation()
|
|
5552
|
+
|
|
5553
|
+
if result.installed:
|
|
5554
|
+
typer.echo(f"✓ NSYS installed: {result.path}")
|
|
5555
|
+
if result.version:
|
|
5556
|
+
typer.echo(f" Version: {result.version}")
|
|
5557
|
+
else:
|
|
5558
|
+
typer.echo("✗ NSYS not installed")
|
|
5559
|
+
if result.install_command:
|
|
5560
|
+
typer.echo(f" Install with: {result.install_command}")
|
|
5561
|
+
|
|
5562
|
+
|
|
3837
5563
|
@nsys_app.command("analyze")
|
|
3838
5564
|
def nsys_analyze(
|
|
3839
5565
|
filepath: Path = typer.Argument(..., help="Path to .nsys-rep profile file"),
|
|
5566
|
+
output_dir: Path | None = typer.Option(
|
|
5567
|
+
None, "--output-dir", "-o", help="Output directory for analysis files"
|
|
5568
|
+
),
|
|
3840
5569
|
json_output: bool = typer.Option(
|
|
3841
5570
|
False, "--json", help="Output raw JSON instead of formatted text"
|
|
3842
5571
|
),
|
|
@@ -3845,6 +5574,12 @@ def nsys_analyze(
|
|
|
3845
5574
|
"--remote/--local",
|
|
3846
5575
|
help="Force remote (via API) or local analysis. Default: auto-detect (remote if nsys not installed locally)",
|
|
3847
5576
|
),
|
|
5577
|
+
target: str | None = typer.Option(
|
|
5578
|
+
None,
|
|
5579
|
+
"--target",
|
|
5580
|
+
"-t",
|
|
5581
|
+
help="Remote target: 'workspace:id' for workspace execution, or target name from ~/.wafer/targets/",
|
|
5582
|
+
),
|
|
3848
5583
|
) -> None:
|
|
3849
5584
|
"""Analyze an NVIDIA Nsight Systems profile (.nsys-rep file).
|
|
3850
5585
|
|
|
@@ -3853,10 +5588,20 @@ def nsys_analyze(
|
|
|
3853
5588
|
By default, uses local nsys if available, otherwise runs analysis
|
|
3854
5589
|
remotely via wafer-api (requires authentication: wafer login).
|
|
3855
5590
|
|
|
5591
|
+
Supports multiple execution modes:
|
|
5592
|
+
- Local: Uses local nsys CLI (no GPU required for analysis)
|
|
5593
|
+
- Remote API: Uploads file and runs analysis on Modal
|
|
5594
|
+
- Workspace: Runs analysis on a Wafer workspace via SSH
|
|
5595
|
+
- Target: Runs analysis on a configured target machine via SSH
|
|
5596
|
+
|
|
3856
5597
|
Examples:
|
|
3857
5598
|
wafer nvidia nsys analyze profile.nsys-rep
|
|
3858
5599
|
wafer nvidia nsys analyze profile.nsys-rep --json
|
|
5600
|
+
wafer nvidia nsys analyze profile.nsys-rep --local
|
|
3859
5601
|
wafer nvidia nsys analyze profile.nsys-rep --remote
|
|
5602
|
+
wafer nvidia nsys analyze profile.nsys-rep --target workspace:abc123
|
|
5603
|
+
wafer nvidia nsys analyze profile.nsys-rep --target vultr-b200
|
|
5604
|
+
wafer nvidia nsys analyze profile.nsys-rep -o ./results/
|
|
3860
5605
|
"""
|
|
3861
5606
|
from .nsys_analyze import analyze_nsys_profile
|
|
3862
5607
|
|
|
@@ -3868,11 +5613,20 @@ def nsys_analyze(
|
|
|
3868
5613
|
typer.echo(f"Error: Expected .nsys-rep file, got: {filepath.suffix}", err=True)
|
|
3869
5614
|
raise typer.Exit(1)
|
|
3870
5615
|
|
|
5616
|
+
# Warn if both remote flag and target are specified
|
|
5617
|
+
if target and remote is not None:
|
|
5618
|
+
typer.echo(
|
|
5619
|
+
"Warning: --target overrides --remote/--local flag",
|
|
5620
|
+
err=True,
|
|
5621
|
+
)
|
|
5622
|
+
|
|
3871
5623
|
try:
|
|
3872
5624
|
result = analyze_nsys_profile(
|
|
3873
5625
|
filepath,
|
|
3874
5626
|
json_output=json_output,
|
|
3875
5627
|
remote=remote,
|
|
5628
|
+
target=target,
|
|
5629
|
+
output_dir=output_dir,
|
|
3876
5630
|
)
|
|
3877
5631
|
typer.echo(result)
|
|
3878
5632
|
except FileNotFoundError as e:
|
|
@@ -3883,6 +5637,150 @@ def nsys_analyze(
|
|
|
3883
5637
|
raise typer.Exit(1) from None
|
|
3884
5638
|
|
|
3885
5639
|
|
|
5640
|
+
@nsys_app.command("profile", context_settings={"allow_interspersed_args": False})
|
|
5641
|
+
def nsys_profile(
|
|
5642
|
+
command: list[str] = typer.Argument(..., help="Command to profile"),
|
|
5643
|
+
output: str = typer.Option(
|
|
5644
|
+
"profile",
|
|
5645
|
+
"--output",
|
|
5646
|
+
"-o",
|
|
5647
|
+
help="Output filename (without .nsys-rep extension)",
|
|
5648
|
+
),
|
|
5649
|
+
trace: str | None = typer.Option(
|
|
5650
|
+
None,
|
|
5651
|
+
"--trace",
|
|
5652
|
+
"-t",
|
|
5653
|
+
help="Trace APIs to capture (comma-separated: cuda,nvtx,osrt,cudnn,cublas). Default: cuda",
|
|
5654
|
+
),
|
|
5655
|
+
duration: int | None = typer.Option(
|
|
5656
|
+
None,
|
|
5657
|
+
"--duration",
|
|
5658
|
+
"-d",
|
|
5659
|
+
help="Maximum profiling duration in seconds",
|
|
5660
|
+
),
|
|
5661
|
+
target: str | None = typer.Option(
|
|
5662
|
+
None,
|
|
5663
|
+
"--target",
|
|
5664
|
+
help="Remote target: 'workspace:id' for workspace execution, or target name from ~/.wafer/targets/",
|
|
5665
|
+
),
|
|
5666
|
+
analyze: bool = typer.Option(
|
|
5667
|
+
False,
|
|
5668
|
+
"--analyze",
|
|
5669
|
+
"-a",
|
|
5670
|
+
help="Automatically analyze the profile after completion",
|
|
5671
|
+
),
|
|
5672
|
+
json_output: bool = typer.Option(
|
|
5673
|
+
False,
|
|
5674
|
+
"--json",
|
|
5675
|
+
help="Output analysis as JSON (only with --analyze)",
|
|
5676
|
+
),
|
|
5677
|
+
verbose: bool = typer.Option(
|
|
5678
|
+
False,
|
|
5679
|
+
"--verbose",
|
|
5680
|
+
"-v",
|
|
5681
|
+
help="Show verbose progress messages",
|
|
5682
|
+
),
|
|
5683
|
+
extra_args: str | None = typer.Option(
|
|
5684
|
+
None,
|
|
5685
|
+
"--extra",
|
|
5686
|
+
help="Extra arguments to pass to nsys profile",
|
|
5687
|
+
),
|
|
5688
|
+
) -> None:
|
|
5689
|
+
"""Profile a command with NVIDIA Nsight Systems.
|
|
5690
|
+
|
|
5691
|
+
Runs nsys profile on the specified command and generates a .nsys-rep file.
|
|
5692
|
+
Profiling requires an NVIDIA GPU. Use --target to run on a remote GPU server
|
|
5693
|
+
or workspace.
|
|
5694
|
+
|
|
5695
|
+
Examples:
|
|
5696
|
+
wafer nvidia nsys profile -- python train.py
|
|
5697
|
+
wafer nvidia nsys profile -o gemm_profile -- ./gemm_kernel
|
|
5698
|
+
wafer nvidia nsys profile --trace cuda,nvtx -- python model.py
|
|
5699
|
+
wafer nvidia nsys profile --duration 60 -- ./long_running_app
|
|
5700
|
+
wafer nvidia nsys profile --target workspace:abc123 -- python test.py
|
|
5701
|
+
wafer nvidia nsys profile --target vultr-b200 -- ./benchmark
|
|
5702
|
+
wafer nvidia nsys profile --analyze -- python train.py
|
|
5703
|
+
wafer nvidia nsys profile --analyze --json -- ./kernel > results.json
|
|
5704
|
+
"""
|
|
5705
|
+
# Parse command
|
|
5706
|
+
import shlex
|
|
5707
|
+
|
|
5708
|
+
from .nsys_analyze import _parse_target
|
|
5709
|
+
from .nsys_profile import (
|
|
5710
|
+
NSYSProfileOptions,
|
|
5711
|
+
profile_and_analyze,
|
|
5712
|
+
profile_local,
|
|
5713
|
+
profile_remote_ssh,
|
|
5714
|
+
profile_workspace,
|
|
5715
|
+
)
|
|
5716
|
+
|
|
5717
|
+
if isinstance(command, list):
|
|
5718
|
+
# Remove leading "--" if present
|
|
5719
|
+
if command and command[0] == "--":
|
|
5720
|
+
command = command[1:]
|
|
5721
|
+
if len(command) == 1:
|
|
5722
|
+
command_str = command[0]
|
|
5723
|
+
else:
|
|
5724
|
+
command_str = shlex.join(command)
|
|
5725
|
+
else:
|
|
5726
|
+
command_str = command
|
|
5727
|
+
|
|
5728
|
+
if not command_str:
|
|
5729
|
+
typer.echo("Error: No command specified", err=True)
|
|
5730
|
+
raise typer.Exit(1)
|
|
5731
|
+
|
|
5732
|
+
# Parse trace options
|
|
5733
|
+
trace_list = trace.split(",") if trace else None
|
|
5734
|
+
|
|
5735
|
+
# Build options
|
|
5736
|
+
options = NSYSProfileOptions(
|
|
5737
|
+
command=command_str,
|
|
5738
|
+
output=output,
|
|
5739
|
+
trace=trace_list,
|
|
5740
|
+
duration=duration,
|
|
5741
|
+
extra_args=extra_args,
|
|
5742
|
+
)
|
|
5743
|
+
|
|
5744
|
+
if verbose:
|
|
5745
|
+
typer.echo(f"[nsys] Command: {command_str}", err=True)
|
|
5746
|
+
if target:
|
|
5747
|
+
typer.echo(f"[nsys] Target: {target}", err=True)
|
|
5748
|
+
|
|
5749
|
+
# Execute
|
|
5750
|
+
if analyze:
|
|
5751
|
+
profile_result, analysis_result = profile_and_analyze(
|
|
5752
|
+
options,
|
|
5753
|
+
target=target,
|
|
5754
|
+
json_output=json_output,
|
|
5755
|
+
verbose=verbose,
|
|
5756
|
+
)
|
|
5757
|
+
else:
|
|
5758
|
+
if target:
|
|
5759
|
+
target_type, target_id = _parse_target(target)
|
|
5760
|
+
if target_type == "workspace":
|
|
5761
|
+
profile_result = profile_workspace(target_id, options, verbose=verbose)
|
|
5762
|
+
else:
|
|
5763
|
+
profile_result = profile_remote_ssh(target_id, options, verbose=verbose)
|
|
5764
|
+
else:
|
|
5765
|
+
profile_result = profile_local(options, verbose=verbose)
|
|
5766
|
+
analysis_result = None
|
|
5767
|
+
|
|
5768
|
+
# Report results
|
|
5769
|
+
if not profile_result.success:
|
|
5770
|
+
typer.echo(f"Error: {profile_result.error}", err=True)
|
|
5771
|
+
if profile_result.stderr:
|
|
5772
|
+
typer.echo(f"stderr: {profile_result.stderr}", err=True)
|
|
5773
|
+
raise typer.Exit(1)
|
|
5774
|
+
|
|
5775
|
+
if verbose or not analyze:
|
|
5776
|
+
typer.echo(f"Profile created: {profile_result.output_path}")
|
|
5777
|
+
|
|
5778
|
+
if analysis_result:
|
|
5779
|
+
if not analysis_result.success:
|
|
5780
|
+
typer.echo(f"Analysis error: {analysis_result.error}", err=True)
|
|
5781
|
+
raise typer.Exit(1)
|
|
5782
|
+
|
|
5783
|
+
|
|
3886
5784
|
# =============================================================================
|
|
3887
5785
|
# ROCprof-Compute commands
|
|
3888
5786
|
# =============================================================================
|
|
@@ -4441,8 +6339,8 @@ def _setup_wafer_core_env() -> None:
|
|
|
4441
6339
|
- WAFER_API_URL: If already set, uses that instead of config
|
|
4442
6340
|
- WAFER_AUTH_TOKEN: If already set, uses that instead of cached token
|
|
4443
6341
|
"""
|
|
4444
|
-
from .global_config import get_api_url
|
|
4445
6342
|
from .auth import get_valid_token
|
|
6343
|
+
from .global_config import get_api_url
|
|
4446
6344
|
|
|
4447
6345
|
# Set API URL (get_api_url already respects WAFER_API_URL env var)
|
|
4448
6346
|
os.environ["WAFER_API_URL"] = get_api_url()
|
|
@@ -4746,8 +6644,8 @@ def capture_command( # noqa: PLR0915
|
|
|
4746
6644
|
import os
|
|
4747
6645
|
import tomllib
|
|
4748
6646
|
|
|
4749
|
-
from .global_config import get_api_url
|
|
4750
6647
|
from .auth import get_valid_token
|
|
6648
|
+
from .global_config import get_api_url
|
|
4751
6649
|
|
|
4752
6650
|
# Set environment variables for wafer-core BEFORE importing it
|
|
4753
6651
|
# wafer-core backend.py reads WAFER_API_URL and WAFER_AUTH_TOKEN from env
|
|
@@ -4951,8 +6849,8 @@ def capture_list_command(
|
|
|
4951
6849
|
"""
|
|
4952
6850
|
import os
|
|
4953
6851
|
|
|
4954
|
-
from .global_config import get_api_url
|
|
4955
6852
|
from .auth import get_valid_token
|
|
6853
|
+
from .global_config import get_api_url
|
|
4956
6854
|
|
|
4957
6855
|
# Set environment variables for wafer-core BEFORE importing it
|
|
4958
6856
|
os.environ["WAFER_API_URL"] = get_api_url()
|
|
@@ -5015,13 +6913,14 @@ def capture_list_command(
|
|
|
5015
6913
|
|
|
5016
6914
|
@corpus_app.command("download")
|
|
5017
6915
|
def corpus_download(
|
|
5018
|
-
name: str = typer.Argument(..., help="Corpus name (cuda, cutlass, hip)"),
|
|
6916
|
+
name: str = typer.Argument(..., help="Corpus name (cuda, cutlass, hip, amd)"),
|
|
5019
6917
|
force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
|
|
5020
6918
|
) -> None:
|
|
5021
6919
|
"""Download a documentation corpus for agent filesystem access.
|
|
5022
6920
|
|
|
5023
6921
|
Examples:
|
|
5024
6922
|
wafer corpus download cuda
|
|
6923
|
+
wafer corpus download amd
|
|
5025
6924
|
wafer corpus download cutlass --force
|
|
5026
6925
|
"""
|
|
5027
6926
|
from .corpus import CORPORA, download_corpus
|
|
@@ -5236,71 +7135,107 @@ def tracelens_collective(
|
|
|
5236
7135
|
|
|
5237
7136
|
|
|
5238
7137
|
# =============================================================================
|
|
5239
|
-
# ISA Analysis Commands
|
|
7138
|
+
# Unified ISA Analysis Commands (wafer amd isa ...)
|
|
5240
7139
|
# =============================================================================
|
|
5241
7140
|
|
|
5242
7141
|
|
|
5243
7142
|
@isa_app.command("analyze")
|
|
5244
7143
|
def isa_analyze(
|
|
5245
|
-
|
|
5246
|
-
json_output: bool = typer.Option(False, "--json", help="Output as JSON"),
|
|
7144
|
+
path: Path = typer.Argument(..., help="Path to file or directory to analyze"),
|
|
7145
|
+
json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
|
|
7146
|
+
csv_output: bool = typer.Option(False, "--csv", help="Output as CSV"),
|
|
7147
|
+
recursive: bool = typer.Option(
|
|
7148
|
+
True, "--recursive/--no-recursive", "-r", help="Scan directories recursively"
|
|
7149
|
+
),
|
|
7150
|
+
filter_expr: str | None = typer.Option(
|
|
7151
|
+
None, "--filter", "-f", help="Filter results (e.g., 'spills > 0')"
|
|
7152
|
+
),
|
|
7153
|
+
output_file: Path | None = typer.Option(None, "--output", "-o", help="Write output to file"),
|
|
7154
|
+
kernel_index: int = typer.Option(0, "--kernel", "-k", help="Kernel index if multiple in file"),
|
|
5247
7155
|
) -> None:
|
|
5248
|
-
"""Analyze AMD GPU
|
|
7156
|
+
"""Analyze AMD GPU ISA files (.co, .s, .ll, .ttgir).
|
|
5249
7157
|
|
|
5250
|
-
|
|
5251
|
-
spills, and
|
|
7158
|
+
Performs static analysis to extract performance metrics like register
|
|
7159
|
+
pressure, spills, MFMA density, and occupancy limits.
|
|
5252
7160
|
|
|
5253
|
-
|
|
5254
|
-
|
|
7161
|
+
Supports:
|
|
7162
|
+
- AMD GPU code objects (.co) - Requires API authentication
|
|
7163
|
+
- AMDGCN ISA assembly (.s, .gcn, .asm) - Local parsing
|
|
7164
|
+
- LLVM-IR files (.ll) - Local parsing
|
|
7165
|
+
- TTGIR files (.ttgir, .ttir, .mlir) - Local parsing
|
|
5255
7166
|
|
|
5256
7167
|
Examples:
|
|
5257
|
-
wafer isa analyze kernel.co
|
|
5258
|
-
wafer isa analyze kernel.
|
|
7168
|
+
wafer amd isa analyze kernel.co # Code object (needs login)
|
|
7169
|
+
wafer amd isa analyze kernel.s # ISA assembly
|
|
7170
|
+
wafer amd isa analyze kernel.s --json # Output as JSON
|
|
7171
|
+
wafer amd isa analyze ~/.triton/cache/ --filter 'spills > 0'
|
|
7172
|
+
wafer amd isa analyze . -r --csv -o metrics.csv
|
|
5259
7173
|
"""
|
|
5260
|
-
from dataclasses import asdict
|
|
5261
|
-
|
|
5262
|
-
from wafer_core.tools.isa_analysis_tools import analyze_isa, format_isa_summary
|
|
5263
|
-
|
|
5264
7174
|
from .auth import get_auth_headers
|
|
5265
7175
|
from .global_config import get_api_url
|
|
7176
|
+
from .kernel_scope import analyze_command
|
|
5266
7177
|
|
|
5267
|
-
#
|
|
5268
|
-
if not file.exists():
|
|
5269
|
-
typer.echo(f"Error: File not found: {file}", err=True)
|
|
5270
|
-
raise typer.Exit(1)
|
|
5271
|
-
|
|
5272
|
-
if not file.suffix == ".co":
|
|
5273
|
-
typer.echo(f"Error: Expected .co file, got: {file.suffix}", err=True)
|
|
5274
|
-
raise typer.Exit(1)
|
|
5275
|
-
|
|
5276
|
-
# Get API URL and auth
|
|
7178
|
+
# Get API credentials for .co files
|
|
5277
7179
|
api_url = get_api_url()
|
|
5278
7180
|
auth_headers = get_auth_headers()
|
|
5279
7181
|
|
|
5280
|
-
if not auth_headers:
|
|
5281
|
-
typer.echo("Error: Not logged in. Run 'wafer login' first.", err=True)
|
|
5282
|
-
raise typer.Exit(1)
|
|
5283
|
-
|
|
5284
7182
|
try:
|
|
5285
|
-
|
|
5286
|
-
|
|
7183
|
+
output = analyze_command(
|
|
7184
|
+
path=str(path),
|
|
7185
|
+
json_output=json_output,
|
|
7186
|
+
csv_output=csv_output,
|
|
7187
|
+
recursive=recursive,
|
|
7188
|
+
filter_expr=filter_expr,
|
|
7189
|
+
output_file=str(output_file) if output_file else None,
|
|
7190
|
+
kernel_index=kernel_index,
|
|
5287
7191
|
api_url=api_url,
|
|
5288
7192
|
auth_headers=auth_headers,
|
|
5289
7193
|
)
|
|
5290
|
-
|
|
5291
|
-
if json_output:
|
|
5292
|
-
typer.echo(json.dumps(asdict(result)))
|
|
5293
|
-
else:
|
|
5294
|
-
typer.echo(format_isa_summary(result))
|
|
7194
|
+
typer.echo(output)
|
|
5295
7195
|
|
|
5296
7196
|
except FileNotFoundError as e:
|
|
5297
7197
|
typer.echo(f"Error: {e}", err=True)
|
|
5298
7198
|
raise typer.Exit(1) from None
|
|
7199
|
+
except RuntimeError as e:
|
|
7200
|
+
typer.echo(f"Error: {e}", err=True)
|
|
7201
|
+
raise typer.Exit(1) from None
|
|
5299
7202
|
except Exception as e:
|
|
5300
7203
|
typer.echo(f"Error: {e}", err=True)
|
|
5301
7204
|
raise typer.Exit(1) from None
|
|
5302
7205
|
|
|
5303
7206
|
|
|
7207
|
+
@isa_app.command("metrics")
|
|
7208
|
+
def isa_metrics() -> None:
|
|
7209
|
+
"""List available metrics for ISA analysis.
|
|
7210
|
+
|
|
7211
|
+
Shows all metrics that can be extracted from AMD GPU ISA files,
|
|
7212
|
+
along with their derivation.
|
|
7213
|
+
|
|
7214
|
+
Examples:
|
|
7215
|
+
wafer amd isa metrics
|
|
7216
|
+
"""
|
|
7217
|
+
from .kernel_scope import metrics_command
|
|
7218
|
+
|
|
7219
|
+
output = metrics_command()
|
|
7220
|
+
typer.echo(output)
|
|
7221
|
+
|
|
7222
|
+
|
|
7223
|
+
@isa_app.command("targets")
|
|
7224
|
+
def isa_targets() -> None:
|
|
7225
|
+
"""List supported GPU targets and their specifications.
|
|
7226
|
+
|
|
7227
|
+
Shows hardware specs (VGPRs, SGPRs, LDS, etc.) for each supported
|
|
7228
|
+
AMD GPU architecture.
|
|
7229
|
+
|
|
7230
|
+
Examples:
|
|
7231
|
+
wafer amd isa targets
|
|
7232
|
+
"""
|
|
7233
|
+
from .kernel_scope import targets_command
|
|
7234
|
+
|
|
7235
|
+
output = targets_command()
|
|
7236
|
+
typer.echo(output)
|
|
7237
|
+
|
|
7238
|
+
|
|
5304
7239
|
def main() -> None:
|
|
5305
7240
|
"""Entry point for wafer CLI."""
|
|
5306
7241
|
app()
|