wafer-cli 0.2.14__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +42 -0
- wafer/auth.py +7 -0
- wafer/billing.py +6 -6
- wafer/cli.py +905 -131
- wafer/cli_instructions.py +143 -0
- wafer/corpus.py +313 -15
- wafer/evaluate.py +480 -146
- wafer/global_config.py +13 -0
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +22 -6
- wafer/specs_cli.py +157 -0
- wafer/ssh_keys.py +6 -6
- wafer/targets_cli.py +472 -0
- wafer/targets_ops.py +29 -2
- wafer/templates/ask_docs.py +1 -1
- wafer/templates/optimize_kernel.py +3 -1
- wafer/templates/optimize_kernelbench.py +17 -62
- wafer/templates/trace_analyze.py +1 -1
- wafer/tests/test_eval_cli_parity.py +199 -0
- wafer/trace_compare.py +274 -0
- wafer/wevin_cli.py +125 -26
- wafer/workspaces.py +163 -16
- wafer_cli-0.2.30.dist-info/METADATA +107 -0
- wafer_cli-0.2.30.dist-info/RECORD +47 -0
- wafer_cli-0.2.14.dist-info/METADATA +0 -16
- wafer_cli-0.2.14.dist-info/RECORD +0 -41
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/top_level.txt +0 -0
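The cli.py diff below groups commands into Rich help panels and relocates auth, billing, and ssh-keys under `wafer auth` and `wafer config`. As a reading aid, here is a minimal standalone sketch of the Typer help-panel pattern the diff applies throughout; the panel names mirror the diff, while the toy `roofline` command body is illustrative only, not wafer code.

```python
# Minimal sketch of the rich_help_panel grouping used in cli.py (illustrative).
import typer

app = typer.Typer(no_args_is_help=True, pretty_exceptions_show_locals=False)
config_app = typer.Typer(help="Manage CLI configuration")

# Sub-apps and commands given the same rich_help_panel are grouped under that
# heading in `wafer --help` output instead of one flat command list.
app.add_typer(config_app, name="config", rich_help_panel="Configuration")


@app.command("roofline", rich_help_panel="Kernel Development")
def roofline() -> None:
    """Example command placed in the 'Kernel Development' panel."""
    typer.echo("roofline")


if __name__ == "__main__":
    app()
```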
wafer/cli.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
|
-
# ruff: noqa: PLR0913
|
|
1
|
+
# ruff: noqa: PLR0913, E402
|
|
2
2
|
# PLR0913 (too many arguments) is suppressed because Typer CLI commands
|
|
3
3
|
# naturally have many parameters - each --flag becomes a function argument.
|
|
4
|
+
# E402 (module level import not at top) is suppressed because we intentionally
|
|
5
|
+
# load .env files before importing other modules that may read env vars.
|
|
4
6
|
"""Wafer CLI - GPU development toolkit for LLM coding agents.
|
|
5
7
|
|
|
6
8
|
Core commands:
|
|
@@ -27,6 +29,12 @@ from pathlib import Path
|
|
|
27
29
|
|
|
28
30
|
import trio
|
|
29
31
|
import typer
|
|
32
|
+
from dotenv import load_dotenv
|
|
33
|
+
|
|
34
|
+
# Auto-load .env from current directory and ~/.wafer/.env
|
|
35
|
+
# This runs at import time so env vars are available before any config is accessed
|
|
36
|
+
load_dotenv() # cwd/.env
|
|
37
|
+
load_dotenv(Path.home() / ".wafer" / ".env") # ~/.wafer/.env
|
|
30
38
|
|
|
31
39
|
from .config import WaferConfig, WaferEnvironment
|
|
32
40
|
from .inference import infer_upload_files, resolve_environment
|
|
@@ -42,6 +50,7 @@ from .problems import (
|
|
|
42
50
|
app = typer.Typer(
|
|
43
51
|
help="GPU development toolkit for LLM coding agents",
|
|
44
52
|
no_args_is_help=True,
|
|
53
|
+
pretty_exceptions_show_locals=False, # Don't dump local vars (makes tracebacks huge)
|
|
45
54
|
)
|
|
46
55
|
|
|
47
56
|
# =============================================================================
|
|
@@ -58,11 +67,11 @@ def _show_version() -> None:
|
|
|
58
67
|
"""Show CLI version and environment, then exit."""
|
|
59
68
|
from .analytics import _get_cli_version
|
|
60
69
|
from .global_config import load_global_config
|
|
61
|
-
|
|
70
|
+
|
|
62
71
|
version = _get_cli_version()
|
|
63
72
|
config = load_global_config()
|
|
64
73
|
environment = config.environment
|
|
65
|
-
|
|
74
|
+
|
|
66
75
|
typer.echo(f"wafer-cli {version} ({environment})")
|
|
67
76
|
raise typer.Exit()
|
|
68
77
|
|
|
@@ -110,7 +119,7 @@ def main_callback(
|
|
|
110
119
|
if version:
|
|
111
120
|
_show_version()
|
|
112
121
|
return
|
|
113
|
-
|
|
122
|
+
|
|
114
123
|
global _command_start_time, _command_outcome
|
|
115
124
|
_command_start_time = time.time()
|
|
116
125
|
_command_outcome = "success" # Default to success, mark failure on exceptions
|
|
@@ -121,6 +130,7 @@ def main_callback(
|
|
|
121
130
|
analytics.init_analytics()
|
|
122
131
|
|
|
123
132
|
# Install exception hook to catch SystemExit and mark failures
|
|
133
|
+
# Also prints error message FIRST so it's visible even when traceback is truncated
|
|
124
134
|
original_excepthook = sys.excepthook
|
|
125
135
|
|
|
126
136
|
def custom_excepthook(
|
|
@@ -136,7 +146,11 @@ def main_callback(
|
|
|
136
146
|
_command_outcome = "failure"
|
|
137
147
|
else:
|
|
138
148
|
_command_outcome = "failure"
|
|
139
|
-
|
|
149
|
+
# Print error summary FIRST (before traceback) so it's visible even if truncated
|
|
150
|
+
print(
|
|
151
|
+
f"\n\033[1;31m>>> ERROR: {exc_type.__name__}: {exc_value}\033[0m\n", file=sys.stderr
|
|
152
|
+
)
|
|
153
|
+
# Call original excepthook (prints the full traceback)
|
|
140
154
|
original_excepthook(exc_type, exc_value, exc_traceback)
|
|
141
155
|
|
|
142
156
|
sys.excepthook = custom_excepthook
|
|
@@ -180,11 +194,16 @@ def complete_target_name(incomplete: str) -> list[str]:
|
|
|
180
194
|
|
|
181
195
|
# =============================================================================
|
|
182
196
|
# Core subcommand groups (visible in --help)
|
|
197
|
+
#
|
|
198
|
+
# TODO: Further consolidate top-level commands to reduce --help surface area.
|
|
199
|
+
# Candidates:
|
|
200
|
+
# - compare → wafer nvidia compare or keep top-level (cross-platform)
|
|
201
|
+
# - guide/skill/demo → wafer onboard {guide,skill,demo}
|
|
183
202
|
# =============================================================================
|
|
184
203
|
|
|
185
204
|
# Config management (includes targets as nested subcommand)
|
|
186
205
|
config_app = typer.Typer(help="Manage CLI configuration and local GPU targets")
|
|
187
|
-
app.add_typer(config_app, name="config")
|
|
206
|
+
app.add_typer(config_app, name="config", rich_help_panel="Configuration")
|
|
188
207
|
|
|
189
208
|
# Target management - nested under config
|
|
190
209
|
targets_app = typer.Typer(
|
|
@@ -204,7 +223,7 @@ config_app.add_typer(targets_app, name="targets")
|
|
|
204
223
|
workspaces_app = typer.Typer(
|
|
205
224
|
help="""Manage cloud GPU workspaces for remote development.
|
|
206
225
|
|
|
207
|
-
Workspaces are on-demand cloud GPU environments. Requires authentication (wafer login).
|
|
226
|
+
Workspaces are on-demand cloud GPU environments. Requires authentication (wafer auth login).
|
|
208
227
|
|
|
209
228
|
Available GPUs:
|
|
210
229
|
MI300X AMD Instinct MI300X (192GB HBM3, ROCm)
|
|
@@ -217,21 +236,21 @@ Commands:
|
|
|
217
236
|
wafer workspaces sync dev ./project # Sync files
|
|
218
237
|
wafer workspaces delete dev # Clean up"""
|
|
219
238
|
)
|
|
220
|
-
app.add_typer(workspaces_app, name="workspaces")
|
|
239
|
+
app.add_typer(workspaces_app, name="workspaces", rich_help_panel="Infrastructure")
|
|
221
240
|
|
|
222
|
-
# SSH Key management (BYOK - Bring Your Own Key)
|
|
241
|
+
# SSH Key management (BYOK - Bring Your Own Key) - nested under config
|
|
223
242
|
ssh_keys_app = typer.Typer(
|
|
224
243
|
help="""Manage SSH public keys for workspace access.
|
|
225
244
|
|
|
226
245
|
Register your SSH public keys here. These keys are installed in all workspaces
|
|
227
246
|
you provision, enabling SSH access from any machine with your private key.
|
|
228
247
|
|
|
229
|
-
wafer ssh-keys list # List registered keys
|
|
230
|
-
wafer ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
|
|
231
|
-
wafer ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
|
|
232
|
-
wafer ssh-keys remove <key-id> # Remove a key"""
|
|
248
|
+
wafer config ssh-keys list # List registered keys
|
|
249
|
+
wafer config ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
|
|
250
|
+
wafer config ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
|
|
251
|
+
wafer config ssh-keys remove <key-id> # Remove a key"""
|
|
233
252
|
)
|
|
234
|
-
|
|
253
|
+
config_app.add_typer(ssh_keys_app, name="ssh-keys")
|
|
235
254
|
|
|
236
255
|
# Target operations (exec/ssh/sync on configured targets)
|
|
237
256
|
targets_ops_app = typer.Typer(
|
|
@@ -247,22 +266,48 @@ Useful for exploratory work, debugging, or custom scripts.
|
|
|
247
266
|
Supports: RunPod, DigitalOcean (auto-provisions), SSH targets (baremetal/vm).
|
|
248
267
|
Configure targets with: wafer config targets init ..."""
|
|
249
268
|
)
|
|
250
|
-
app.add_typer(targets_ops_app, name="targets")
|
|
269
|
+
app.add_typer(targets_ops_app, name="targets", rich_help_panel="Infrastructure")
|
|
270
|
+
|
|
271
|
+
# Specs management (new: local TOML configs)
|
|
272
|
+
from wafer.specs_cli import specs_app
|
|
273
|
+
|
|
274
|
+
app.add_typer(specs_app, name="specs", rich_help_panel="Configuration")
|
|
275
|
+
|
|
276
|
+
# Live resource management (new: API-backed commands on `wafer targets`)
|
|
277
|
+
# These become: wafer targets list, wafer targets terminate, etc.
|
|
278
|
+
from wafer.targets_cli import (
|
|
279
|
+
targets_list as _targets_list_cmd,
|
|
280
|
+
)
|
|
281
|
+
from wafer.targets_cli import (
|
|
282
|
+
targets_provision as _targets_provision_cmd,
|
|
283
|
+
)
|
|
284
|
+
from wafer.targets_cli import (
|
|
285
|
+
targets_reconcile as _targets_reconcile_cmd,
|
|
286
|
+
)
|
|
287
|
+
from wafer.targets_cli import (
|
|
288
|
+
targets_terminate as _targets_terminate_cmd,
|
|
289
|
+
)
|
|
290
|
+
from wafer.targets_cli import (
|
|
291
|
+
targets_pools as _targets_pools_cmd,
|
|
292
|
+
)
|
|
293
|
+
from wafer.targets_cli import (
|
|
294
|
+
targets_probe as _targets_probe_cmd,
|
|
295
|
+
)
|
|
251
296
|
|
|
252
|
-
# Billing management
|
|
297
|
+
# Billing management - nested under config
|
|
253
298
|
billing_app = typer.Typer(help="Manage billing, credits, and subscription")
|
|
254
|
-
|
|
299
|
+
config_app.add_typer(billing_app, name="billing")
|
|
255
300
|
|
|
256
301
|
# Corpus management
|
|
257
302
|
corpus_app = typer.Typer(help="Download and manage GPU documentation")
|
|
258
|
-
app.add_typer(corpus_app, name="corpus")
|
|
303
|
+
app.add_typer(corpus_app, name="corpus", rich_help_panel="Kernel Development")
|
|
259
304
|
|
|
260
305
|
# Evaluate (supports multiple kernel formats)
|
|
261
306
|
evaluate_app = typer.Typer(
|
|
262
307
|
help="Test kernel correctness and performance",
|
|
263
308
|
invoke_without_command=True,
|
|
264
309
|
)
|
|
265
|
-
app.add_typer(evaluate_app, name="evaluate")
|
|
310
|
+
app.add_typer(evaluate_app, name="evaluate", rich_help_panel="Kernel Development")
|
|
266
311
|
|
|
267
312
|
# Nested subcommand for kernelbench format
|
|
268
313
|
kernelbench_app = typer.Typer(
|
|
@@ -291,7 +336,7 @@ app.add_typer(dev_app, name="dev")
|
|
|
291
336
|
# =============================================================================
|
|
292
337
|
|
|
293
338
|
nvidia_app = typer.Typer(help="NVIDIA GPU profiling and analysis tools")
|
|
294
|
-
app.add_typer(nvidia_app, name="nvidia")
|
|
339
|
+
app.add_typer(nvidia_app, name="nvidia", rich_help_panel="Profiling")
|
|
295
340
|
|
|
296
341
|
# NCU analysis - under nvidia
|
|
297
342
|
ncu_app = typer.Typer(help="Nsight Compute profile analysis")
|
|
@@ -314,18 +359,25 @@ nvidia_app.add_typer(tracelens_app, name="tracelens")
|
|
|
314
359
|
# =============================================================================
|
|
315
360
|
|
|
316
361
|
amd_app = typer.Typer(help="AMD GPU profiling and analysis tools")
|
|
317
|
-
app.add_typer(amd_app, name="amd")
|
|
362
|
+
app.add_typer(amd_app, name="amd", rich_help_panel="Profiling")
|
|
318
363
|
|
|
319
364
|
# Unified ISA Analyzer - supports both .co files and Triton artifacts
|
|
320
365
|
isa_app = typer.Typer(help="ISA analysis for AMD GPU kernels (.co, .s, .ll, .ttgir files)")
|
|
321
366
|
amd_app.add_typer(isa_app, name="isa")
|
|
322
367
|
|
|
368
|
+
# =============================================================================
|
|
369
|
+
# Trace comparison (wafer compare)
|
|
370
|
+
# =============================================================================
|
|
371
|
+
|
|
372
|
+
compare_app = typer.Typer(help="Compare GPU traces across platforms (AMD vs NVIDIA)")
|
|
373
|
+
app.add_typer(compare_app, name="compare", rich_help_panel="Profiling")
|
|
374
|
+
|
|
323
375
|
# =============================================================================
|
|
324
376
|
# Roofline analysis (wafer roofline)
|
|
325
377
|
# =============================================================================
|
|
326
378
|
|
|
327
379
|
|
|
328
|
-
@app.command("roofline")
|
|
380
|
+
@app.command("roofline", rich_help_panel="Kernel Development")
|
|
329
381
|
def roofline_cmd(
|
|
330
382
|
gpu: str | None = typer.Option(
|
|
331
383
|
None, "--gpu", "-g", help="GPU name (e.g., H100, B200, MI300X, A100)"
|
|
@@ -416,7 +468,7 @@ def roofline_cmd(
|
|
|
416
468
|
# =============================================================================
|
|
417
469
|
|
|
418
470
|
skill_app = typer.Typer(help="Manage AI coding assistant skills (Claude Code, Codex)")
|
|
419
|
-
app.add_typer(skill_app, name="skill")
|
|
471
|
+
app.add_typer(skill_app, name="skill", rich_help_panel="Onboarding")
|
|
420
472
|
|
|
421
473
|
|
|
422
474
|
@skill_app.command("install")
|
|
@@ -580,18 +632,23 @@ def skill_status() -> None:
|
|
|
580
632
|
|
|
581
633
|
|
|
582
634
|
# =============================================================================
|
|
583
|
-
#
|
|
635
|
+
# Authentication (wafer auth ...)
|
|
584
636
|
# =============================================================================
|
|
585
637
|
|
|
586
|
-
|
|
587
|
-
app.add_typer(
|
|
638
|
+
auth_app = typer.Typer(help="Authenticate with Wafer and cloud GPU providers")
|
|
639
|
+
app.add_typer(auth_app, name="auth", rich_help_panel="Configuration")
|
|
588
640
|
|
|
641
|
+
providers_app = typer.Typer(
|
|
642
|
+
help="Manage API keys for cloud GPU providers (RunPod, DigitalOcean, etc.)"
|
|
643
|
+
)
|
|
644
|
+
auth_app.add_typer(providers_app, name="providers")
|
|
589
645
|
|
|
590
|
-
|
|
646
|
+
|
|
647
|
+
@providers_app.command("login")
|
|
591
648
|
def provider_auth_login(
|
|
592
649
|
provider: str = typer.Argument(
|
|
593
650
|
...,
|
|
594
|
-
help="Provider name: runpod, digitalocean, or
|
|
651
|
+
help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
|
|
595
652
|
),
|
|
596
653
|
api_key: str | None = typer.Option(
|
|
597
654
|
None,
|
|
@@ -600,15 +657,16 @@ def provider_auth_login(
|
|
|
600
657
|
help="API key (if not provided, reads from stdin)",
|
|
601
658
|
),
|
|
602
659
|
) -> None:
|
|
603
|
-
"""Save API key for a
|
|
660
|
+
"""Save API key for a provider.
|
|
604
661
|
|
|
605
662
|
Stores the key in ~/.wafer/auth.json. Environment variables
|
|
606
|
-
(e.g.,
|
|
663
|
+
(e.g., ANTHROPIC_API_KEY) take precedence over stored keys.
|
|
607
664
|
|
|
608
665
|
Examples:
|
|
609
|
-
wafer auth login
|
|
610
|
-
wafer auth login
|
|
611
|
-
|
|
666
|
+
wafer auth providers login anthropic --api-key sk-ant-xxx
|
|
667
|
+
wafer auth providers login runpod --api-key rp_xxx
|
|
668
|
+
wafer auth providers login openai --api-key sk-xxx
|
|
669
|
+
echo $API_KEY | wafer auth providers login anthropic
|
|
612
670
|
"""
|
|
613
671
|
import sys
|
|
614
672
|
|
|
@@ -638,18 +696,18 @@ def provider_auth_login(
|
|
|
638
696
|
typer.echo("Stored in: ~/.wafer/auth.json")
|
|
639
697
|
|
|
640
698
|
|
|
641
|
-
@
|
|
699
|
+
@providers_app.command("logout")
|
|
642
700
|
def provider_auth_logout(
|
|
643
701
|
provider: str = typer.Argument(
|
|
644
702
|
...,
|
|
645
|
-
help="Provider name: runpod, digitalocean, or
|
|
703
|
+
help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
|
|
646
704
|
),
|
|
647
705
|
) -> None:
|
|
648
706
|
"""Remove stored API key for a cloud GPU provider.
|
|
649
707
|
|
|
650
708
|
Examples:
|
|
651
|
-
wafer auth logout runpod
|
|
652
|
-
wafer auth logout digitalocean
|
|
709
|
+
wafer auth providers logout runpod
|
|
710
|
+
wafer auth providers logout digitalocean
|
|
653
711
|
"""
|
|
654
712
|
from wafer_core.auth import PROVIDERS, remove_api_key
|
|
655
713
|
|
|
@@ -665,7 +723,7 @@ def provider_auth_logout(
|
|
|
665
723
|
typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
|
|
666
724
|
|
|
667
725
|
|
|
668
|
-
@
|
|
726
|
+
@providers_app.command("status")
|
|
669
727
|
def provider_auth_status() -> None:
|
|
670
728
|
"""Show authentication status for all cloud GPU providers.
|
|
671
729
|
|
|
@@ -673,7 +731,7 @@ def provider_auth_status() -> None:
|
|
|
673
731
|
the keys are coming from (environment variable or auth.json).
|
|
674
732
|
|
|
675
733
|
Example:
|
|
676
|
-
wafer auth status
|
|
734
|
+
wafer auth providers status
|
|
677
735
|
"""
|
|
678
736
|
from wafer_core.auth import get_all_auth_status
|
|
679
737
|
|
|
@@ -688,7 +746,7 @@ def provider_auth_status() -> None:
|
|
|
688
746
|
typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
|
|
689
747
|
else:
|
|
690
748
|
typer.echo(f" {status.display_name}: ✗ Not configured")
|
|
691
|
-
typer.echo(f" Run: wafer auth login {status.provider}")
|
|
749
|
+
typer.echo(f" Run: wafer auth providers login {status.provider}")
|
|
692
750
|
typer.echo(f" Or set: {status.key_url}")
|
|
693
751
|
|
|
694
752
|
typer.echo("")
|
|
@@ -1233,7 +1291,7 @@ def config_show_legacy() -> None:
|
|
|
1233
1291
|
config_show_new()
|
|
1234
1292
|
|
|
1235
1293
|
|
|
1236
|
-
@app.command()
|
|
1294
|
+
@app.command(rich_help_panel="Kernel Development")
|
|
1237
1295
|
def agent( # noqa: PLR0913
|
|
1238
1296
|
prompt: str | None = typer.Argument(
|
|
1239
1297
|
None,
|
|
@@ -1303,7 +1361,7 @@ def agent( # noqa: PLR0913
|
|
|
1303
1361
|
None,
|
|
1304
1362
|
"--model",
|
|
1305
1363
|
"-m",
|
|
1306
|
-
help="Model override (default: claude-
|
|
1364
|
+
help="Model override (default: claude-opus-4-5)",
|
|
1307
1365
|
),
|
|
1308
1366
|
json_output: bool = typer.Option(
|
|
1309
1367
|
False,
|
|
@@ -1327,6 +1385,16 @@ def agent( # noqa: PLR0913
|
|
|
1327
1385
|
"-c",
|
|
1328
1386
|
help="Documentation corpus to use (cuda, cutlass, hip, amd). Must be downloaded first.",
|
|
1329
1387
|
),
|
|
1388
|
+
no_sandbox: bool = typer.Option(
|
|
1389
|
+
False,
|
|
1390
|
+
"--no-sandbox",
|
|
1391
|
+
help="Disable OS-level sandboxing (YOU accept liability for any damage caused by the agent)",
|
|
1392
|
+
),
|
|
1393
|
+
no_proxy: bool = typer.Option(
|
|
1394
|
+
False,
|
|
1395
|
+
"--no-proxy",
|
|
1396
|
+
help="Skip wafer proxy, use ANTHROPIC_API_KEY directly",
|
|
1397
|
+
),
|
|
1330
1398
|
) -> None:
|
|
1331
1399
|
"""AI assistant for GPU kernel development.
|
|
1332
1400
|
|
|
@@ -1408,6 +1476,13 @@ def agent( # noqa: PLR0913
|
|
|
1408
1476
|
raise typer.Exit(1) from None
|
|
1409
1477
|
corpus_path = str(path)
|
|
1410
1478
|
|
|
1479
|
+
# Warn user about sandbox disabled
|
|
1480
|
+
if no_sandbox:
|
|
1481
|
+
print(
|
|
1482
|
+
"Warning: Sandbox disabled. You accept liability for any damage caused by the agent.",
|
|
1483
|
+
file=sys.stderr,
|
|
1484
|
+
)
|
|
1485
|
+
|
|
1411
1486
|
wevin_main(
|
|
1412
1487
|
prompt=actual_prompt,
|
|
1413
1488
|
interactive=use_tui,
|
|
@@ -1425,6 +1500,8 @@ def agent( # noqa: PLR0913
|
|
|
1425
1500
|
template=template,
|
|
1426
1501
|
template_args=parsed_template_args,
|
|
1427
1502
|
corpus_path=corpus_path,
|
|
1503
|
+
no_sandbox=no_sandbox,
|
|
1504
|
+
no_proxy=no_proxy,
|
|
1428
1505
|
)
|
|
1429
1506
|
|
|
1430
1507
|
|
|
@@ -1455,6 +1532,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
|
|
|
1455
1532
|
template: str | None = typer.Option(None, "--template", "-t"),
|
|
1456
1533
|
template_args: list[str] | None = typer.Option(None, "--args"),
|
|
1457
1534
|
corpus: str | None = typer.Option(None, "--corpus"),
|
|
1535
|
+
no_sandbox: bool = typer.Option(False, "--no-sandbox"),
|
|
1458
1536
|
) -> None:
|
|
1459
1537
|
agent(
|
|
1460
1538
|
prompt=prompt,
|
|
@@ -1474,6 +1552,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
|
|
|
1474
1552
|
template=template,
|
|
1475
1553
|
template_args=template_args,
|
|
1476
1554
|
corpus=corpus,
|
|
1555
|
+
no_sandbox=no_sandbox,
|
|
1477
1556
|
)
|
|
1478
1557
|
|
|
1479
1558
|
alias_cmd.__doc__ = doc
|
|
@@ -1497,7 +1576,11 @@ def evaluate( # noqa: PLR0913
|
|
|
1497
1576
|
None, "--reference", help="Path to reference kernel file"
|
|
1498
1577
|
),
|
|
1499
1578
|
test_cases: Path | None = typer.Option(
|
|
1500
|
-
None,
|
|
1579
|
+
None,
|
|
1580
|
+
"--test-cases",
|
|
1581
|
+
help="Path to test cases JSON file. "
|
|
1582
|
+
'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
|
|
1583
|
+
"Run 'wafer evaluate make-template' to generate an example.",
|
|
1501
1584
|
),
|
|
1502
1585
|
target: str | None = typer.Option(
|
|
1503
1586
|
None,
|
|
@@ -1527,20 +1610,20 @@ def evaluate( # noqa: PLR0913
|
|
|
1527
1610
|
|
|
1528
1611
|
Examples:
|
|
1529
1612
|
# Basic correctness check
|
|
1530
|
-
wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json
|
|
1613
|
+
wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
|
|
1531
1614
|
|
|
1532
1615
|
# With benchmarking on a specific target
|
|
1533
|
-
wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
1616
|
+
wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
1534
1617
|
--target vultr-b200 --benchmark
|
|
1535
1618
|
|
|
1536
1619
|
# Full evaluation with defensive timing (detects cheating)
|
|
1537
|
-
wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
1620
|
+
wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
1538
1621
|
--benchmark --defensive
|
|
1539
1622
|
|
|
1540
1623
|
Subcommands:
|
|
1541
1624
|
gpumode Use GPUMode format (functional) - RECOMMENDED
|
|
1542
1625
|
kernelbench Use KernelBench format (ModelNew class)
|
|
1543
|
-
make-template Generate template files for this format
|
|
1626
|
+
make-template Generate template files for this format
|
|
1544
1627
|
"""
|
|
1545
1628
|
# If a subcommand is being invoked, skip the main evaluation logic
|
|
1546
1629
|
if ctx.invoked_subcommand is not None:
|
|
@@ -1694,7 +1777,7 @@ def evaluate_make_template(
|
|
|
1694
1777
|
typer.echo(f" 2. Edit {output_dir / 'reference.py'} with the ground truth + input generator")
|
|
1695
1778
|
typer.echo(f" 3. Edit {output_dir / 'test_cases.json'} with your test parameters")
|
|
1696
1779
|
typer.echo(" 4. Run:")
|
|
1697
|
-
typer.echo(f" wafer evaluate --impl {output_dir / 'kernel.py'} \\")
|
|
1780
|
+
typer.echo(f" wafer evaluate gpumode --impl {output_dir / 'kernel.py'} \\")
|
|
1698
1781
|
typer.echo(f" --reference {output_dir / 'reference.py'} \\")
|
|
1699
1782
|
typer.echo(f" --test-cases {output_dir / 'test_cases.json'} --benchmark")
|
|
1700
1783
|
|
|
@@ -1758,6 +1841,93 @@ def kernelbench_list_problems() -> None:
|
|
|
1758
1841
|
raise typer.Exit(1) from None
|
|
1759
1842
|
|
|
1760
1843
|
|
|
1844
|
+
def _resolve_pool_query(pool: str, collector) -> tuple[str, object]:
|
|
1845
|
+
"""Resolve a PoolQuery pool to a target spec name + lock context.
|
|
1846
|
+
|
|
1847
|
+
Queries live providers, matches by pool query, locks one target,
|
|
1848
|
+
returns (spec_name, lock_context) for the evaluator.
|
|
1849
|
+
"""
|
|
1850
|
+
import trio
|
|
1851
|
+
from wafer_core.targets.pool import resolve_pool
|
|
1852
|
+
|
|
1853
|
+
from .target_lock import acquire_from_pool
|
|
1854
|
+
|
|
1855
|
+
matched_targets = trio.run(resolve_pool, pool)
|
|
1856
|
+
|
|
1857
|
+
if not matched_targets:
|
|
1858
|
+
collector.set_error("pool", "NoMatchingTargets", pool=pool)
|
|
1859
|
+
collector.finalize()
|
|
1860
|
+
raise typer.Exit(1)
|
|
1861
|
+
|
|
1862
|
+
# Filter to targets with a spec (evaluator needs spec fields)
|
|
1863
|
+
spec_targets = [t for t in matched_targets if t.spec_name]
|
|
1864
|
+
if not spec_targets:
|
|
1865
|
+
collector.set_error(
|
|
1866
|
+
"pool", "NoSpecTargets", pool=pool,
|
|
1867
|
+
message="Matched targets have no spec binding — evaluator needs spec fields",
|
|
1868
|
+
)
|
|
1869
|
+
collector.finalize()
|
|
1870
|
+
raise typer.Exit(1)
|
|
1871
|
+
|
|
1872
|
+
# Lock one by resource_id
|
|
1873
|
+
resource_ids = [t.resource_id for t in spec_targets]
|
|
1874
|
+
collector.emit("pool_acquire", pool=pool, count=len(resource_ids))
|
|
1875
|
+
|
|
1876
|
+
lock_ctx = acquire_from_pool(resource_ids)
|
|
1877
|
+
acquired_id = lock_ctx.__enter__()
|
|
1878
|
+
|
|
1879
|
+
if acquired_id is None:
|
|
1880
|
+
lock_ctx.__exit__(None, None, None)
|
|
1881
|
+
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=resource_ids)
|
|
1882
|
+
collector.finalize()
|
|
1883
|
+
raise typer.Exit(1)
|
|
1884
|
+
|
|
1885
|
+
# Map resource_id back to spec_name
|
|
1886
|
+
acquired_target = next(t for t in spec_targets if t.resource_id == acquired_id)
|
|
1887
|
+
spec_name = acquired_target.spec_name
|
|
1888
|
+
|
|
1889
|
+
collector.emit("pool_acquired", target=spec_name, resource_id=acquired_id)
|
|
1890
|
+
return spec_name, lock_ctx
|
|
1891
|
+
|
|
1892
|
+
|
|
1893
|
+
def _resolve_pool_legacy(pool: str, collector) -> tuple[str, object]:
|
|
1894
|
+
"""Resolve an old-style pool (static target name list) to a target name + lock context.
|
|
1895
|
+
|
|
1896
|
+
Old format: [pools.name] targets = ["t1", "t2"]
|
|
1897
|
+
"""
|
|
1898
|
+
from .target_lock import acquire_from_pool
|
|
1899
|
+
from .targets import filter_pool_by_auth, get_pool
|
|
1900
|
+
|
|
1901
|
+
try:
|
|
1902
|
+
pool_targets = get_pool(pool)
|
|
1903
|
+
except FileNotFoundError as e:
|
|
1904
|
+
collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
|
|
1905
|
+
collector.finalize()
|
|
1906
|
+
raise typer.Exit(1) from None
|
|
1907
|
+
|
|
1908
|
+
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
1909
|
+
if skipped:
|
|
1910
|
+
collector.emit("pool_auth_skip", targets=skipped)
|
|
1911
|
+
|
|
1912
|
+
if not usable_targets:
|
|
1913
|
+
collector.set_error("pool", "NoUsableTargets", pool=pool)
|
|
1914
|
+
collector.finalize()
|
|
1915
|
+
raise typer.Exit(1) from None
|
|
1916
|
+
|
|
1917
|
+
collector.emit("pool_acquire", pool=pool, count=len(usable_targets))
|
|
1918
|
+
lock_ctx = acquire_from_pool(usable_targets)
|
|
1919
|
+
acquired_target = lock_ctx.__enter__()
|
|
1920
|
+
|
|
1921
|
+
if acquired_target is None:
|
|
1922
|
+
lock_ctx.__exit__(None, None, None)
|
|
1923
|
+
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
|
|
1924
|
+
collector.finalize()
|
|
1925
|
+
raise typer.Exit(1)
|
|
1926
|
+
|
|
1927
|
+
collector.emit("pool_acquired", target=acquired_target)
|
|
1928
|
+
return acquired_target, lock_ctx
|
|
1929
|
+
|
|
1930
|
+
|
|
1761
1931
|
@kernelbench_app.callback(invoke_without_command=True)
|
|
1762
1932
|
def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
1763
1933
|
ctx: typer.Context,
|
|
@@ -1888,39 +2058,12 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
|
1888
2058
|
pool_lock_context = None
|
|
1889
2059
|
|
|
1890
2060
|
if pool:
|
|
1891
|
-
from .
|
|
1892
|
-
from .targets import filter_pool_by_auth, get_pool
|
|
1893
|
-
|
|
1894
|
-
try:
|
|
1895
|
-
pool_targets = get_pool(pool)
|
|
1896
|
-
except FileNotFoundError as e:
|
|
1897
|
-
collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
|
|
1898
|
-
collector.finalize()
|
|
1899
|
-
raise typer.Exit(1) from None
|
|
1900
|
-
|
|
1901
|
-
# Filter to only targets with valid auth
|
|
1902
|
-
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
1903
|
-
if skipped:
|
|
1904
|
-
collector.emit("pool_auth_skip", targets=skipped)
|
|
1905
|
-
|
|
1906
|
-
if not usable_targets:
|
|
1907
|
-
collector.set_error("pool", "NoUsableTargets", pool=pool)
|
|
1908
|
-
collector.finalize()
|
|
1909
|
-
raise typer.Exit(1) from None
|
|
2061
|
+
from wafer_core.targets.pool import is_query_pool
|
|
1910
2062
|
|
|
1911
|
-
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
if acquired_target is None:
|
|
1916
|
-
# Exit context manager before raising to avoid resource leak
|
|
1917
|
-
pool_lock_context.__exit__(None, None, None)
|
|
1918
|
-
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
|
|
1919
|
-
collector.finalize()
|
|
1920
|
-
raise typer.Exit(1)
|
|
1921
|
-
|
|
1922
|
-
collector.emit("pool_acquired", target=acquired_target)
|
|
1923
|
-
resolved_target = acquired_target
|
|
2063
|
+
if is_query_pool(pool):
|
|
2064
|
+
resolved_target, pool_lock_context = _resolve_pool_query(pool, collector)
|
|
2065
|
+
else:
|
|
2066
|
+
resolved_target, pool_lock_context = _resolve_pool_legacy(pool, collector)
|
|
1924
2067
|
|
|
1925
2068
|
collector.target = resolved_target
|
|
1926
2069
|
|
|
@@ -2245,7 +2388,11 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2245
2388
|
None, "--reference", help="Path to reference kernel file"
|
|
2246
2389
|
),
|
|
2247
2390
|
test_cases: Path | None = typer.Option(
|
|
2248
|
-
None,
|
|
2391
|
+
None,
|
|
2392
|
+
"--test-cases",
|
|
2393
|
+
help="Path to test cases JSON file. "
|
|
2394
|
+
'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
|
|
2395
|
+
"Run 'wafer evaluate make-template' to generate an example.",
|
|
2249
2396
|
),
|
|
2250
2397
|
target: str | None = typer.Option(
|
|
2251
2398
|
None,
|
|
@@ -2313,6 +2460,13 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2313
2460
|
err=True,
|
|
2314
2461
|
)
|
|
2315
2462
|
typer.echo("", err=True)
|
|
2463
|
+
if "--test-cases" in missing_args:
|
|
2464
|
+
typer.echo(
|
|
2465
|
+
"Tip: Run 'wafer evaluate make-template' to generate template files "
|
|
2466
|
+
"including test_cases.json.",
|
|
2467
|
+
err=True,
|
|
2468
|
+
)
|
|
2469
|
+
typer.echo("", err=True)
|
|
2316
2470
|
typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
|
|
2317
2471
|
typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
|
|
2318
2472
|
raise typer.Exit(1)
|
|
@@ -2719,7 +2873,7 @@ def remote_run( # noqa: PLR0913
|
|
|
2719
2873
|
# =============================================================================
|
|
2720
2874
|
|
|
2721
2875
|
|
|
2722
|
-
@
|
|
2876
|
+
@auth_app.command("login")
|
|
2723
2877
|
def login(
|
|
2724
2878
|
token: str | None = typer.Option(
|
|
2725
2879
|
None, "--token", "-t", help="Access token (skip browser OAuth)"
|
|
@@ -2744,7 +2898,7 @@ def login(
|
|
|
2744
2898
|
Uses the API environment from config (see 'wafer config show').
|
|
2745
2899
|
|
|
2746
2900
|
SSH Users (Easiest):
|
|
2747
|
-
- Just run: wafer login
|
|
2901
|
+
- Just run: wafer auth login
|
|
2748
2902
|
- Visit the URL and enter the code shown
|
|
2749
2903
|
- No port forwarding needed!
|
|
2750
2904
|
|
|
@@ -2754,17 +2908,17 @@ def login(
|
|
|
2754
2908
|
|
|
2755
2909
|
Manual token option:
|
|
2756
2910
|
- Visit auth.wafer.ai, authenticate, copy token from URL
|
|
2757
|
-
- Run: wafer login --token <paste-token>
|
|
2911
|
+
- Run: wafer auth login --token <paste-token>
|
|
2758
2912
|
|
|
2759
2913
|
Examples:
|
|
2760
|
-
wafer login # device code on SSH, browser on local
|
|
2761
|
-
wafer login --no-device-code # force browser (needs port forwarding on SSH)
|
|
2762
|
-
wafer login --port 9000 # custom port for browser flow
|
|
2763
|
-
wafer login --token xyz # manual token (no browser)
|
|
2914
|
+
wafer auth login # device code on SSH, browser on local
|
|
2915
|
+
wafer auth login --no-device-code # force browser (needs port forwarding on SSH)
|
|
2916
|
+
wafer auth login --port 9000 # custom port for browser flow
|
|
2917
|
+
wafer auth login --token xyz # manual token (no browser)
|
|
2764
2918
|
|
|
2765
2919
|
# Change environment:
|
|
2766
2920
|
wafer config set api.environment staging
|
|
2767
|
-
wafer login
|
|
2921
|
+
wafer auth login
|
|
2768
2922
|
"""
|
|
2769
2923
|
import httpx
|
|
2770
2924
|
|
|
@@ -2848,7 +3002,7 @@ def login(
|
|
|
2848
3002
|
typer.echo("Token saved to ~/.wafer/credentials.json")
|
|
2849
3003
|
|
|
2850
3004
|
|
|
2851
|
-
@
|
|
3005
|
+
@auth_app.command("logout")
|
|
2852
3006
|
def logout() -> None:
|
|
2853
3007
|
"""Remove stored credentials."""
|
|
2854
3008
|
from . import analytics
|
|
@@ -2865,7 +3019,7 @@ def logout() -> None:
|
|
|
2865
3019
|
typer.echo("Not logged in (no credentials found).")
|
|
2866
3020
|
|
|
2867
3021
|
|
|
2868
|
-
@
|
|
3022
|
+
@auth_app.command("whoami")
|
|
2869
3023
|
def whoami(
|
|
2870
3024
|
verify: bool = typer.Option(False, "--verify", "-v", help="Verify token with API"),
|
|
2871
3025
|
refresh: bool = typer.Option(False, "--refresh", "-r", help="Refresh token if expired"),
|
|
@@ -2879,7 +3033,7 @@ def whoami(
|
|
|
2879
3033
|
|
|
2880
3034
|
creds = load_credentials()
|
|
2881
3035
|
if creds is None:
|
|
2882
|
-
typer.echo("Not logged in. Run: wafer login")
|
|
3036
|
+
typer.echo("Not logged in. Run: wafer auth login")
|
|
2883
3037
|
raise typer.Exit(1)
|
|
2884
3038
|
|
|
2885
3039
|
if verify or refresh:
|
|
@@ -2887,7 +3041,7 @@ def whoami(
|
|
|
2887
3041
|
# Try to get valid token with auto-refresh
|
|
2888
3042
|
token = get_valid_token()
|
|
2889
3043
|
if token is None:
|
|
2890
|
-
typer.echo("Token expired and refresh failed. Run: wafer login", err=True)
|
|
3044
|
+
typer.echo("Token expired and refresh failed. Run: wafer auth login", err=True)
|
|
2891
3045
|
raise typer.Exit(1)
|
|
2892
3046
|
if token != creds.access_token:
|
|
2893
3047
|
typer.echo("Token refreshed successfully")
|
|
@@ -2900,10 +3054,10 @@ def whoami(
|
|
|
2900
3054
|
except Exception as e:
|
|
2901
3055
|
if creds.refresh_token and not refresh:
|
|
2902
3056
|
typer.echo(f"Token expired: {e}", err=True)
|
|
2903
|
-
typer.echo("Try: wafer whoami --refresh", err=True)
|
|
3057
|
+
typer.echo("Try: wafer auth whoami --refresh", err=True)
|
|
2904
3058
|
else:
|
|
2905
3059
|
typer.echo(f"Token invalid or expired: {e}", err=True)
|
|
2906
|
-
typer.echo("Run: wafer login", err=True)
|
|
3060
|
+
typer.echo("Run: wafer auth login", err=True)
|
|
2907
3061
|
raise typer.Exit(1) from None
|
|
2908
3062
|
elif creds.email:
|
|
2909
3063
|
typer.echo(creds.email)
|
|
@@ -2911,7 +3065,7 @@ def whoami(
|
|
|
2911
3065
|
typer.echo("Logged in (email not available)")
|
|
2912
3066
|
|
|
2913
3067
|
|
|
2914
|
-
@app.command("guide")
|
|
3068
|
+
@app.command("guide", rich_help_panel="Onboarding")
|
|
2915
3069
|
def guide() -> None:
|
|
2916
3070
|
"""Show the Wafer CLI usage guide.
|
|
2917
3071
|
|
|
@@ -2942,7 +3096,7 @@ demo_app = typer.Typer(
|
|
|
2942
3096
|
wafer demo trace Analyze a sample performance trace
|
|
2943
3097
|
wafer demo eval Run kernel evaluation on cloud GPU (requires login)"""
|
|
2944
3098
|
)
|
|
2945
|
-
app.add_typer(demo_app, name="demo")
|
|
3099
|
+
app.add_typer(demo_app, name="demo", rich_help_panel="Onboarding")
|
|
2946
3100
|
|
|
2947
3101
|
DEMO_TRACES_URL = "https://github.com/wafer-ai/wafer/raw/main/apps/wafer-cli/wafer/demo_data"
|
|
2948
3102
|
DEMO_DIR = Path.home() / ".cache" / "wafer" / "demo"
|
|
@@ -3162,7 +3316,7 @@ def demo_eval(
|
|
|
3162
3316
|
"""Demo: Evaluate a kernel on a cloud GPU.
|
|
3163
3317
|
|
|
3164
3318
|
Creates a workspace, runs a sample Triton kernel evaluation, and cleans up.
|
|
3165
|
-
Requires authentication (wafer login).
|
|
3319
|
+
Requires authentication (wafer auth login).
|
|
3166
3320
|
|
|
3167
3321
|
Example:
|
|
3168
3322
|
wafer demo eval
|
|
@@ -3177,7 +3331,7 @@ def demo_eval(
|
|
|
3177
3331
|
# Check auth first
|
|
3178
3332
|
creds = load_credentials()
|
|
3179
3333
|
if not creds:
|
|
3180
|
-
typer.echo("Error: Not authenticated. Run: wafer login")
|
|
3334
|
+
typer.echo("Error: Not authenticated. Run: wafer auth login")
|
|
3181
3335
|
raise typer.Exit(1)
|
|
3182
3336
|
|
|
3183
3337
|
if not yes:
|
|
@@ -3458,7 +3612,7 @@ def init_runpod(
|
|
|
3458
3612
|
gpu_configs = {
|
|
3459
3613
|
"MI300X": {
|
|
3460
3614
|
"gpu_type_id": "AMD Instinct MI300X OAM",
|
|
3461
|
-
"image": "
|
|
3615
|
+
"image": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
|
|
3462
3616
|
"compute_capability": "9.4",
|
|
3463
3617
|
},
|
|
3464
3618
|
"H100": {
|
|
@@ -3554,7 +3708,7 @@ def init_digitalocean(
|
|
|
3554
3708
|
"ssh_key": ssh_key,
|
|
3555
3709
|
"region": region,
|
|
3556
3710
|
"size_slug": "gpu-mi300x1-192gb-devcloud",
|
|
3557
|
-
"image": "
|
|
3711
|
+
"image": "amd-pytorchrocm7", # PyTorch (ROCm7) marketplace image
|
|
3558
3712
|
"provision_timeout": 600,
|
|
3559
3713
|
"eval_timeout": 600,
|
|
3560
3714
|
"keep_alive": keep_alive,
|
|
@@ -3826,12 +3980,16 @@ def targets_add(
|
|
|
3826
3980
|
|
|
3827
3981
|
@targets_app.command("list")
|
|
3828
3982
|
def targets_list() -> None:
|
|
3829
|
-
"""List all configured targets.
|
|
3983
|
+
"""List all configured targets with live provider status.
|
|
3830
3984
|
|
|
3831
3985
|
Example:
|
|
3832
3986
|
wafer config targets list
|
|
3833
3987
|
"""
|
|
3834
|
-
|
|
3988
|
+
import socket
|
|
3989
|
+
|
|
3990
|
+
import trio
|
|
3991
|
+
|
|
3992
|
+
from .targets import get_default_target, list_targets, load_target, remove_target
|
|
3835
3993
|
|
|
3836
3994
|
targets = list_targets()
|
|
3837
3995
|
default = get_default_target()
|
|
@@ -3841,10 +3999,146 @@ def targets_list() -> None:
|
|
|
3841
3999
|
typer.echo("Add one with: wafer config targets add <path/to/target.toml>")
|
|
3842
4000
|
return
|
|
3843
4001
|
|
|
4002
|
+
def _parse_ssh_target(ssh_target: str) -> tuple[str, int]:
|
|
4003
|
+
"""Extract (host, port) from user@host:port string."""
|
|
4004
|
+
parts = ssh_target.rsplit(":", 1)
|
|
4005
|
+
host_part = parts[0]
|
|
4006
|
+
port = int(parts[1]) if len(parts) > 1 else 22
|
|
4007
|
+
if "@" in host_part:
|
|
4008
|
+
host = host_part.split("@", 1)[1]
|
|
4009
|
+
else:
|
|
4010
|
+
host = host_part
|
|
4011
|
+
return (host, port)
|
|
4012
|
+
|
|
4013
|
+
async def _get_live_provider_endpoints() -> set[tuple[str, int]]:
|
|
4014
|
+
"""Query RunPod + DO APIs. Returns set of live (ip, port) endpoints."""
|
|
4015
|
+
from wafer_core.targets.digitalocean import list_running_droplets
|
|
4016
|
+
from wafer_core.targets.runpod import sync_pods_from_api
|
|
4017
|
+
|
|
4018
|
+
live_endpoints: set[tuple[str, int]] = set()
|
|
4019
|
+
|
|
4020
|
+
async def _fetch_runpod() -> None:
|
|
4021
|
+
try:
|
|
4022
|
+
pods = await sync_pods_from_api()
|
|
4023
|
+
for p in pods:
|
|
4024
|
+
live_endpoints.add((p.public_ip, p.ssh_port))
|
|
4025
|
+
except Exception:
|
|
4026
|
+
pass
|
|
4027
|
+
|
|
4028
|
+
async def _fetch_do() -> None:
|
|
4029
|
+
try:
|
|
4030
|
+
droplets = await list_running_droplets()
|
|
4031
|
+
for d in droplets:
|
|
4032
|
+
live_endpoints.add((d.public_ip, d.ssh_port))
|
|
4033
|
+
except Exception:
|
|
4034
|
+
pass
|
|
4035
|
+
|
|
4036
|
+
async with trio.open_nursery() as nursery:
|
|
4037
|
+
nursery.start_soon(_fetch_runpod)
|
|
4038
|
+
nursery.start_soon(_fetch_do)
|
|
4039
|
+
|
|
4040
|
+
return live_endpoints
|
|
4041
|
+
|
|
4042
|
+
async def _get_target_status(
|
|
4043
|
+
name: str,
|
|
4044
|
+
live_endpoints: set[tuple[str, int]],
|
|
4045
|
+
) -> tuple[str, str, str]:
|
|
4046
|
+
"""Returns (name, status, ssh_info)."""
|
|
4047
|
+
from wafer_core.targets.digitalocean import (
|
|
4048
|
+
_remove_droplet_from_state,
|
|
4049
|
+
check_droplet_running,
|
|
4050
|
+
get_droplet_state,
|
|
4051
|
+
)
|
|
4052
|
+
from wafer_core.targets.runpod import (
|
|
4053
|
+
_remove_pod_from_state,
|
|
4054
|
+
check_pod_running,
|
|
4055
|
+
get_pod_state,
|
|
4056
|
+
)
|
|
4057
|
+
from wafer_core.utils.kernel_utils.targets.config import (
|
|
4058
|
+
BaremetalTarget,
|
|
4059
|
+
DigitalOceanTarget,
|
|
4060
|
+
ModalTarget,
|
|
4061
|
+
RunPodTarget,
|
|
4062
|
+
)
|
|
4063
|
+
|
|
4064
|
+
try:
|
|
4065
|
+
target = load_target(name)
|
|
4066
|
+
except (FileNotFoundError, ValueError, AssertionError, TypeError):
|
|
4067
|
+
return (name, "error", "")
|
|
4068
|
+
|
|
4069
|
+
if isinstance(target, RunPodTarget):
|
|
4070
|
+
pod = get_pod_state(name)
|
|
4071
|
+
if not pod:
|
|
4072
|
+
return (name, "no instance", "")
|
|
4073
|
+
if await check_pod_running(pod.pod_id):
|
|
4074
|
+
return (name, "running", f"{pod.ssh_username}@{pod.public_ip}:{pod.ssh_port}")
|
|
4075
|
+
_remove_pod_from_state(name)
|
|
4076
|
+
return (name, "stopped", "")
|
|
4077
|
+
|
|
4078
|
+
if isinstance(target, DigitalOceanTarget):
|
|
4079
|
+
droplet = get_droplet_state(name)
|
|
4080
|
+
if not droplet:
|
|
4081
|
+
return (name, "no instance", "")
|
|
4082
|
+
if await check_droplet_running(droplet.droplet_id):
|
|
4083
|
+
return (
|
|
4084
|
+
name,
|
|
4085
|
+
"running",
|
|
4086
|
+
f"{droplet.ssh_username}@{droplet.public_ip}:{droplet.ssh_port}",
|
|
4087
|
+
)
|
|
4088
|
+
_remove_droplet_from_state(name)
|
|
4089
|
+
return (name, "stopped", "")
|
|
4090
|
+
|
|
4091
|
+
if isinstance(target, BaremetalTarget):
|
|
4092
|
+
ssh_target = target.ssh_target
|
|
4093
|
+
host, port = _parse_ssh_target(ssh_target)
|
|
4094
|
+
|
|
4095
|
+
def _tcp_check() -> bool:
|
|
4096
|
+
try:
|
|
4097
|
+
sock = socket.create_connection((host, port), timeout=2)
|
|
4098
|
+
sock.close()
|
|
4099
|
+
return True
|
|
4100
|
+
except OSError:
|
|
4101
|
+
return False
|
|
4102
|
+
|
|
4103
|
+
reachable = await trio.to_thread.run_sync(_tcp_check)
|
|
4104
|
+
if reachable:
|
|
4105
|
+
return (name, "reachable", ssh_target)
|
|
4106
|
+
|
|
4107
|
+
# Unreachable + has a provider = backed by an ephemeral instance.
|
|
4108
|
+
# If not in the live provider listing, the instance is gone — remove config.
|
|
4109
|
+
if target.provider and (host, port) not in live_endpoints:
|
|
4110
|
+
remove_target(name)
|
|
4111
|
+
return (name, "removed (dead pod)", ssh_target)
|
|
4112
|
+
|
|
4113
|
+
return (name, "unreachable", ssh_target)
|
|
4114
|
+
|
|
4115
|
+
if isinstance(target, ModalTarget):
|
|
4116
|
+
return (name, "serverless", "")
|
|
4117
|
+
|
|
4118
|
+
# Unknown target type
|
|
4119
|
+
return (name, "unknown", "")
|
|
4120
|
+
|
|
4121
|
+
async def _gather_statuses() -> list[tuple[str, str, str]]:
|
|
4122
|
+
live_endpoints = await _get_live_provider_endpoints()
|
|
4123
|
+
results: list[tuple[str, str, str]] = [("", "", "")] * len(targets)
|
|
4124
|
+
|
|
4125
|
+
async def _check(i: int, name: str) -> None:
|
|
4126
|
+
results[i] = await _get_target_status(name, live_endpoints)
|
|
4127
|
+
|
|
4128
|
+
async with trio.open_nursery() as nursery:
|
|
4129
|
+
for i, name in enumerate(targets):
|
|
4130
|
+
nursery.start_soon(_check, i, name)
|
|
4131
|
+
|
|
4132
|
+
return results
|
|
4133
|
+
|
|
4134
|
+
statuses = trio.run(_gather_statuses)
|
|
4135
|
+
|
|
3844
4136
|
typer.echo("Configured targets:")
|
|
3845
|
-
for name in
|
|
4137
|
+
for name, status, ssh_info in statuses:
|
|
3846
4138
|
marker = " (default)" if name == default else ""
|
|
3847
|
-
|
|
4139
|
+
label = f" {name}{marker}"
|
|
4140
|
+
detail = f" {ssh_info}" if ssh_info else ""
|
|
4141
|
+
typer.echo(f"{label:<40}{status}{detail}")
|
|
3848
4142
|
|
|
3849
4143
|
|
|
3850
4144
|
@targets_app.command("show")
|
|
@@ -4056,6 +4350,216 @@ def targets_cleanup(
|
|
|
4056
4350
|
raise typer.Exit(1) from None
|
|
4057
4351
|
|
|
4058
4352
|
|
|
4353
|
+
# Known libraries that can be installed on targets
|
|
4354
|
+
# TODO: Consider adding HipKittens to the default RunPod/DO Docker images
|
|
4355
|
+
# so this install step isn't needed. For now, this command handles it.
|
|
4356
|
+
# Architecture → branch mapping for libraries that ship per-arch branches.
|
|
4357
|
+
# "default" is used when the detected arch has no explicit entry.
|
|
4358
|
+
_ARCH_BRANCHES: dict[str, dict[str, str]] = {
|
|
4359
|
+
"hipkittens": {
|
|
4360
|
+
"gfx942": "cdna3", # MI300X, MI325X
|
|
4361
|
+
"default": "main", # MI350X, MI355X, and future CDNA4+
|
|
4362
|
+
},
|
|
4363
|
+
}
|
|
4364
|
+
|
|
4365
|
+
INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
|
|
4366
|
+
"hipkittens": {
|
|
4367
|
+
"description": "HipKittens - AMD port of ThunderKittens",
|
|
4368
|
+
"git_url": "https://github.com/HazyResearch/HipKittens.git",
|
|
4369
|
+
"install_path": "/opt/hipkittens",
|
|
4370
|
+
"requires_amd": True,
|
|
4371
|
+
},
|
|
4372
|
+
# CK is already installed with ROCm 7.0, no action needed
|
|
4373
|
+
"repair-headers": {
|
|
4374
|
+
"description": "Repair ROCm thrust headers (fixes hipify corruption)",
|
|
4375
|
+
"custom_script": "apt-get update -qq && apt-get install --reinstall -y rocthrust >/dev/null 2>&1 && echo REPAIRED",
|
|
4376
|
+
"requires_amd": True,
|
|
4377
|
+
},
|
|
4378
|
+
}
|
|
4379
|
+
|
|
4380
|
+
|
|
4381
|
+
def _resolve_gfx_arch(target: object, ssh_cmd: list[str]) -> str | None:
|
|
4382
|
+
"""Return the gfx architecture string for *target*.
|
|
4383
|
+
|
|
4384
|
+
1. If the target config already carries a compute_capability, map it.
|
|
4385
|
+
2. Otherwise SSH in and probe with ``rocminfo``.
|
|
4386
|
+
Returns None only if detection fails entirely.
|
|
4387
|
+
"""
|
|
4388
|
+
import subprocess
|
|
4389
|
+
|
|
4390
|
+
from .evaluate import AMD_CC_TO_ARCH
|
|
4391
|
+
|
|
4392
|
+
cc = getattr(target, "compute_capability", None)
|
|
4393
|
+
if cc and cc in AMD_CC_TO_ARCH:
|
|
4394
|
+
return AMD_CC_TO_ARCH[cc]
|
|
4395
|
+
|
|
4396
|
+
typer.echo(" Detecting GPU architecture via rocminfo...")
|
|
4397
|
+
probe_script = "rocminfo 2>/dev/null | grep -oP 'gfx\\d+' | head -1"
|
|
4398
|
+
result = subprocess.run(
|
|
4399
|
+
ssh_cmd + [probe_script],
|
|
4400
|
+
capture_output=True,
|
|
4401
|
+
text=True,
|
|
4402
|
+
timeout=30,
|
|
4403
|
+
)
|
|
4404
|
+
arch = result.stdout.strip()
|
|
4405
|
+
if result.returncode == 0 and arch.startswith("gfx"):
|
|
4406
|
+
typer.echo(f" Detected: {arch}")
|
|
4407
|
+
return arch
|
|
4408
|
+
|
|
4409
|
+
typer.echo(" Warning: could not detect GPU architecture", err=True)
|
|
4410
|
+
return None
|
|
4411
|
+
|
|
4412
|
+
|
|
4413
|
+
@targets_app.command("install")
|
|
4414
|
+
def targets_install(
|
|
4415
|
+
name: str = typer.Argument(..., help="Target name"),
|
|
4416
|
+
library: str = typer.Argument(..., help="Library to install (hipkittens, repair-headers)"),
|
|
4417
|
+
) -> None:
|
|
4418
|
+
"""Install a library or run maintenance on a target (idempotent).
|
|
4419
|
+
|
|
4420
|
+
Installs header-only libraries like HipKittens on remote targets.
|
|
4421
|
+
Safe to run multiple times - will skip if already installed.
|
|
4422
|
+
|
|
4423
|
+
For libraries with per-architecture branches (e.g. HipKittens), the
|
|
4424
|
+
correct branch is selected automatically based on the target's GPU.
|
|
4425
|
+
|
|
4426
|
+
Available libraries:
|
|
4427
|
+
hipkittens - HipKittens (AMD ThunderKittens port)
|
|
4428
|
+
repair-headers - Fix ROCm thrust headers (after hipify corruption)
|
|
4429
|
+
|
|
4430
|
+
Examples:
|
|
4431
|
+
wafer config targets install runpod-mi300x hipkittens
|
|
4432
|
+
wafer config targets install runpod-mi300x repair-headers
|
|
4433
|
+
wafer config targets install do-mi300x hipkittens
|
|
4434
|
+
"""
|
|
4435
|
+
import subprocess
|
|
4436
|
+
|
|
4437
|
+
from .targets import load_target
|
|
4438
|
+
from .targets_ops import get_target_ssh_info
|
|
4439
|
+
|
|
4440
|
+
if library not in INSTALLABLE_LIBRARIES:
|
|
4441
|
+
available = ", ".join(INSTALLABLE_LIBRARIES.keys())
|
|
4442
|
+
typer.echo(f"Error: Unknown library '{library}'. Available: {available}", err=True)
|
|
4443
|
+
raise typer.Exit(1)
|
|
4444
|
+
|
|
4445
|
+
lib_info = INSTALLABLE_LIBRARIES[library]
|
|
4446
|
+
|
|
4447
|
+
try:
|
|
4448
|
+
target = load_target(name)
|
|
4449
|
+
except FileNotFoundError as e:
|
|
4450
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4451
|
+
raise typer.Exit(1) from None
|
|
4452
|
+
|
|
4453
|
+
# Check if target is AMD (for AMD-only libraries)
|
|
4454
|
+
if lib_info.get("requires_amd"):
|
|
4455
|
+
from wafer_core.utils.kernel_utils.targets.config import (
|
|
4456
|
+
DigitalOceanTarget,
|
|
4457
|
+
RunPodTarget,
|
|
4458
|
+
)
|
|
4459
|
+
|
|
4460
|
+
is_amd = isinstance(target, (RunPodTarget, DigitalOceanTarget))
|
|
4461
|
+
if not is_amd and hasattr(target, "compute_capability"):
|
|
4462
|
+
# Check compute capability for MI300X (gfx942 = 9.4)
|
|
4463
|
+
is_amd = target.compute_capability.startswith("9.")
|
|
4464
|
+
if not is_amd:
|
|
4465
|
+
typer.echo(f"Error: {library} requires an AMD GPU target", err=True)
|
|
4466
|
+
raise typer.Exit(1)
|
|
4467
|
+
|
|
4468
|
+
typer.echo(f"Installing {library} on {name}...")
|
|
4469
|
+
typer.echo(f" {lib_info['description']}")
|
|
4470
|
+
|
|
4471
|
+
async def _install() -> bool:
|
|
4472
|
+
# get_target_ssh_info uses pure trio async (no asyncio bridging needed)
|
|
4473
|
+
# and we use subprocess for SSH, not AsyncSSHClient
|
|
4474
|
+
ssh_info = await get_target_ssh_info(target)
|
|
4475
|
+
|
|
4476
|
+
ssh_cmd = [
|
|
4477
|
+
"ssh",
|
|
4478
|
+
"-o",
|
|
4479
|
+
"StrictHostKeyChecking=no",
|
|
4480
|
+
"-o",
|
|
4481
|
+
"UserKnownHostsFile=/dev/null",
|
|
4482
|
+
"-o",
|
|
4483
|
+
"ConnectTimeout=30",
|
|
4484
|
+
"-i",
|
|
4485
|
+
str(ssh_info.key_path),
|
|
4486
|
+
"-p",
|
|
4487
|
+
str(ssh_info.port),
|
|
4488
|
+
f"{ssh_info.user}@{ssh_info.host}",
|
|
4489
|
+
]
|
|
4490
|
+
|
|
4491
|
+
# Handle custom scripts (like repair-headers) vs git installs
|
|
4492
|
+
if "custom_script" in lib_info:
|
|
4493
|
+
install_script = str(lib_info["custom_script"])
|
|
4494
|
+
success_marker = "REPAIRED"
|
|
4495
|
+
else:
|
|
4496
|
+
install_path = lib_info["install_path"]
|
|
4497
|
+
git_url = lib_info["git_url"]
|
|
4498
|
+
|
|
4499
|
+
# Resolve the branch for arch-aware libraries
|
|
4500
|
+
branch = "main"
|
|
4501
|
+
arch_map = _ARCH_BRANCHES.get(library)
|
|
4502
|
+
if arch_map:
|
|
4503
|
+
gfx = await trio.to_thread.run_sync(lambda: _resolve_gfx_arch(target, ssh_cmd))
|
|
4504
|
+
branch = arch_map.get(gfx, arch_map["default"]) if gfx else arch_map["default"]
|
|
4505
|
+
typer.echo(f" Branch: {branch} (arch={gfx or 'unknown'})")
|
|
4506
|
+
|
|
4507
|
+
# Idempotent: if already cloned, ensure correct branch & pull
|
|
4508
|
+
install_script = f"""
|
|
4509
|
+
if [ -d "{install_path}" ]; then
|
|
4510
|
+
echo "ALREADY_INSTALLED: {install_path} exists"
|
|
4511
|
+
cd {install_path} && git fetch --quiet origin && git checkout {branch} --quiet && git pull --quiet origin {branch}
|
|
4512
|
+
else
|
|
4513
|
+
echo "INSTALLING: cloning to {install_path}"
|
|
4514
|
+
git clone --quiet --branch {branch} {git_url} {install_path}
|
|
4515
|
+
fi
|
|
4516
|
+
echo "DONE"
|
|
4517
|
+
"""
|
|
4518
|
+
success_marker = "DONE"
|
|
4519
|
+
|
|
4520
|
+
def run_ssh() -> subprocess.CompletedProcess[str]:
|
|
4521
|
+
return subprocess.run(
|
|
4522
|
+
ssh_cmd + [install_script],
|
|
4523
|
+
capture_output=True,
|
|
4524
|
+
text=True,
|
|
4525
|
+
timeout=120,
|
|
4526
|
+
)
|
|
4527
|
+
|
|
4528
|
+
result = await trio.to_thread.run_sync(run_ssh)
|
|
4529
|
+
|
|
4530
|
+
if result.returncode != 0:
|
|
4531
|
+
typer.echo(f"Error: {result.stderr}", err=True)
|
|
4532
|
+
return False
|
|
4533
|
+
|
|
4534
|
+
output = result.stdout.strip()
|
|
4535
|
+
if "ALREADY_INSTALLED" in output:
|
|
4536
|
+
typer.echo(f" Already installed at {lib_info.get('install_path', 'N/A')}")
|
|
4537
|
+
elif "INSTALLING" in output:
|
|
4538
|
+
typer.echo(f" Installed to {lib_info.get('install_path', 'N/A')}")
|
|
4539
|
+
elif "REPAIRED" in output:
|
|
4540
|
+
typer.echo(" ROCm headers repaired")
|
|
4541
|
+
|
|
4542
|
+
return success_marker in output
|
|
4543
|
+
|
|
4544
|
+
try:
|
|
4545
|
+
success = trio.run(_install)
|
|
4546
|
+
except Exception as e:
|
|
4547
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4548
|
+
raise typer.Exit(1) from None
|
|
4549
|
+
|
|
4550
|
+
if success:
|
|
4551
|
+
typer.echo(f"✓ {library} ready on {name}")
|
|
4552
|
+
|
|
4553
|
+
# Print usage hint
|
|
4554
|
+
if library == "hipkittens":
|
|
4555
|
+
typer.echo("")
|
|
4556
|
+
typer.echo("Usage in load_inline:")
|
|
4557
|
+
typer.echo(' extra_include_paths=["/opt/hipkittens/include", "/opt/rocm/include/hip"]')
|
|
4558
|
+
else:
|
|
4559
|
+
typer.echo(f"Failed to install {library}", err=True)
|
|
4560
|
+
raise typer.Exit(1)
|
|
4561
|
+
|
|
4562
|
+
|
|
4059
4563
|
@targets_app.command("pods")
|
|
4060
4564
|
def targets_pods() -> None:
|
|
4061
4565
|
"""List all running RunPod pods.
|
|
@@ -4185,8 +4689,8 @@ def billing_usage(
|
|
|
4185
4689
|
"""Show current billing usage and subscription info.
|
|
4186
4690
|
|
|
4187
4691
|
Example:
|
|
4188
|
-
wafer billing
|
|
4189
|
-
wafer billing --json
|
|
4692
|
+
wafer config billing
|
|
4693
|
+
wafer config billing --json
|
|
4190
4694
|
"""
|
|
4191
4695
|
# Only show usage if no subcommand was invoked
|
|
4192
4696
|
if ctx.invoked_subcommand is not None:
|
|
@@ -4214,9 +4718,9 @@ def billing_topup(
|
|
|
4214
4718
|
Opens a Stripe checkout page to add credits. Default amount is $25.
|
|
4215
4719
|
|
|
4216
4720
|
Example:
|
|
4217
|
-
wafer billing topup # Add $25
|
|
4218
|
-
wafer billing topup 100 # Add $100
|
|
4219
|
-
wafer billing topup --no-browser # Print URL instead
|
|
4721
|
+
wafer config billing topup # Add $25
|
|
4722
|
+
wafer config billing topup 100 # Add $100
|
|
4723
|
+
wafer config billing topup --no-browser # Print URL instead
|
|
4220
4724
|
"""
|
|
4221
4725
|
import webbrowser
|
|
4222
4726
|
|
|
@@ -4262,8 +4766,8 @@ def billing_portal(
|
|
|
4262
4766
|
Manage your subscription, update payment method, or view invoices.
|
|
4263
4767
|
|
|
4264
4768
|
Example:
|
|
4265
|
-
wafer billing portal
|
|
4266
|
-
wafer billing portal --no-browser
|
|
4769
|
+
wafer config billing portal
|
|
4770
|
+
wafer config billing portal --no-browser
|
|
4267
4771
|
"""
|
|
4268
4772
|
import webbrowser
|
|
4269
4773
|
|
|
@@ -4300,8 +4804,8 @@ def ssh_keys_list(
|
|
|
4300
4804
|
"""List all registered SSH public keys.
|
|
4301
4805
|
|
|
4302
4806
|
Example:
|
|
4303
|
-
wafer ssh-keys list
|
|
4304
|
-
wafer ssh-keys list --json
|
|
4807
|
+
wafer config ssh-keys list
|
|
4808
|
+
wafer config ssh-keys list --json
|
|
4305
4809
|
"""
|
|
4306
4810
|
from .ssh_keys import list_ssh_keys
|
|
4307
4811
|
|
|
@@ -4327,9 +4831,9 @@ def ssh_keys_add(
|
|
|
4327
4831
|
id_ed25519.pub, id_rsa.pub, id_ecdsa.pub.
|
|
4328
4832
|
|
|
4329
4833
|
Example:
|
|
4330
|
-
wafer ssh-keys add # Auto-detect
|
|
4331
|
-
wafer ssh-keys add ~/.ssh/id_rsa.pub # Specific file
|
|
4332
|
-
wafer ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
|
|
4834
|
+
wafer config ssh-keys add # Auto-detect
|
|
4835
|
+
wafer config ssh-keys add ~/.ssh/id_rsa.pub # Specific file
|
|
4836
|
+
wafer config ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
|
|
4333
4837
|
"""
|
|
4334
4838
|
from .ssh_keys import add_ssh_key
|
|
4335
4839
|
|
|
@@ -4348,10 +4852,10 @@ def ssh_keys_remove(
|
|
|
4348
4852
|
) -> None:
|
|
4349
4853
|
"""Remove an SSH public key.
|
|
4350
4854
|
|
|
4351
|
-
Get the key ID from 'wafer ssh-keys list'.
|
|
4855
|
+
Get the key ID from 'wafer config ssh-keys list'.
|
|
4352
4856
|
|
|
4353
4857
|
Example:
|
|
4354
|
-
wafer ssh-keys remove abc123-def456-...
|
|
4858
|
+
wafer config ssh-keys remove abc123-def456-...
|
|
4355
4859
|
"""
|
|
4356
4860
|
from .ssh_keys import remove_ssh_key
|
|
4357
4861
|
|
|
@@ -4391,9 +4895,13 @@ def workspaces_list(
|
|
|
4391
4895
|
@workspaces_app.command("create")
|
|
4392
4896
|
def workspaces_create(
|
|
4393
4897
|
name: str = typer.Argument(..., help="Workspace name"),
|
|
4394
|
-
gpu_type: str = typer.Option(
|
|
4898
|
+
gpu_type: str = typer.Option(
|
|
4899
|
+
"B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"
|
|
4900
|
+
),
|
|
4395
4901
|
image: str | None = typer.Option(None, "--image", "-i", help="Docker image (optional)"),
|
|
4396
|
-
wait: bool = typer.Option(
|
|
4902
|
+
wait: bool = typer.Option(
|
|
4903
|
+
False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"
|
|
4904
|
+
),
|
|
4397
4905
|
json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
|
|
4398
4906
|
) -> None:
|
|
4399
4907
|
"""Create a new workspace.
|
|
@@ -4702,19 +5210,25 @@ def workspaces_ssh(
     ssh_host = ws.get("ssh_host")
     ssh_port = ws.get("ssh_port")
     ssh_user = ws.get("ssh_user")
-
+
     if not ssh_host or not ssh_port or not ssh_user:
         typer.echo("Error: Workspace not ready. Wait a few seconds and retry.", err=True)
         raise typer.Exit(1)

     # Connect via SSH
-    os.execvp(
+    os.execvp(
         "ssh",
-
-
-
-
-
+        [
+            "ssh",
+            "-p",
+            str(ssh_port),
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "UserKnownHostsFile=/dev/null",
+            f"{ssh_user}@{ssh_host}",
+        ],
+    )


 @workspaces_app.command("sync")
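Note on the `os.execvp` rewrite above: `execvp` replaces the current CLI process with `ssh`, resolving the binary on PATH, and the second argument becomes the new process's argv, which is why "ssh" appears again as the first list element. A minimal standalone sketch of the same pattern, illustrative only and not wafer code:

import os

def exec_ssh(user: str, host: str, port: int) -> None:
    # Replaces the current process; on success this call never returns.
    # args[0] repeats the program name because it becomes argv[0] of ssh.
    os.execvp(
        "ssh",
        [
            "ssh",
            "-p", str(port),
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            f"{user}@{host}",
        ],
    )

# Example (hands the terminal over to ssh; hostname is made up):
# exec_ssh("dev", "gpu-box.example.com", 2222)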
@@ -4777,6 +5291,69 @@ def workspaces_sync(
         raise typer.Exit(1) from None


+@workspaces_app.command("pull")
+def workspaces_pull(
+    workspace: str = typer.Argument(..., help="Workspace name or ID"),
+    remote_path: str = typer.Argument(
+        ..., help="Remote path in workspace (relative to /workspace or absolute)"
+    ),
+    local_path: Path = typer.Argument(
+        Path("."), help="Local destination path (default: current directory)"
+    ),
+    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
+    quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
+) -> None:
+    """Pull files from workspace to local machine.
+
+    Uses rsync over SSH to download files from the workspace's /workspace directory.
+
+    Examples:
+        wafer workspaces pull dev kernel.py ./              # Pull single file
+        wafer workspaces pull dev kernel.py ./my_kernel.py  # Pull and rename
+        wafer workspaces pull dev /workspace/results ./     # Pull directory
+    """
+    from .global_config import get_preferences
+    from .workspaces import pull_files
+
+    # Determine verbosity based on mode
+    prefs = get_preferences()
+    if quiet:
+        show_status = False
+    elif verbose:
+        show_status = True
+    else:
+        show_status = prefs.mode == "explicit"
+
+    if show_status:
+        typer.echo(f"[wafer] Pulling {remote_path} from workspace {workspace}...", err=True)
+
+    def on_progress(msg: str) -> None:
+        if show_status:
+            typer.echo(f"[wafer] {msg}", err=True)
+
+    try:
+        file_count = pull_files(
+            workspace, remote_path, local_path.resolve(), on_progress=on_progress
+        )
+        if show_status:
+            typer.echo(f"[wafer] Pulled {file_count} files to {local_path}", err=True)
+    except RuntimeError as e:
+        typer.echo(f"Error: {e}", err=True)
+        raise typer.Exit(1) from None
+
+
+# =============================================================================
+# Live resource commands (list/terminate/reconcile/provision)
+# =============================================================================
+
+targets_ops_app.command("list")(_targets_list_cmd)
+targets_ops_app.command("terminate")(_targets_terminate_cmd)
+targets_ops_app.command("reconcile")(_targets_reconcile_cmd)
+targets_ops_app.command("provision")(_targets_provision_cmd)
+targets_ops_app.command("pools")(_targets_pools_cmd)
+targets_ops_app.command("probe")(_targets_probe_cmd)
+
+
 # =============================================================================
 # Target operations commands (exec/ssh/sync)
 # =============================================================================
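The new `workspaces pull` command delegates the transfer to `pull_files` in wafer/workspaces.py, which is outside this diff. Based only on the docstring ("rsync over SSH"), here is a hedged sketch of how such a helper could be built; the function name `rsync_pull` and its signature are assumptions for illustration, not the actual wafer API:

import subprocess
from pathlib import Path

def rsync_pull(user: str, host: str, port: int, remote_path: str, local_path: Path) -> None:
    # Hypothetical helper: download a file or directory from a remote host via rsync over SSH.
    cmd = [
        "rsync",
        "-az",  # archive mode + compression
        "-e", f"ssh -p {port} -o StrictHostKeyChecking=no",
        f"{user}@{host}:{remote_path}",
        str(local_path),
    ]
    subprocess.run(cmd, check=True)  # raises CalledProcessError if rsync fails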
@@ -5535,7 +6112,7 @@ def ncu_analyze(
     compute/memory throughput, and optimization recommendations.

     By default, uses local NCU if available, otherwise runs analysis
-    remotely via wafer-api (requires authentication: wafer login).
+    remotely via wafer-api (requires authentication: wafer auth login).

     Use --target for direct SSH mode (like wafer remote-run --direct).
     Use --include-source to fetch SASS assembly with register/instruction data.
@@ -5630,7 +6207,7 @@ def nsys_analyze(
     Returns timeline events, kernel information, memory usage, and diagnostics.

     By default, uses local nsys if available, otherwise runs analysis
-    remotely via wafer-api (requires authentication: wafer login).
+    remotely via wafer-api (requires authentication: wafer auth login).

     Supports multiple execution modes:
     - Local: Uses local nsys CLI (no GPU required for analysis)
@@ -6615,7 +7192,7 @@ def autotuner_results(
         raise typer.Exit(1) from None


-@app.command("capture")
+@app.command("capture", rich_help_panel="Kernel Development")
 def capture_command(  # noqa: PLR0915
     label: str = typer.Argument(
         ..., help="Label for this capture (e.g., 'baseline', 'optimized-v2')"
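The only change in this hunk is the added `rich_help_panel` argument. With Rich installed, Typer uses it to group commands under a named section in `wafer --help`, so `capture` now appears under a "Kernel Development" panel. A small self-contained illustration of the same Typer feature (app and command names here are illustrative):

import typer

demo = typer.Typer(no_args_is_help=True)

@demo.command("capture", rich_help_panel="Kernel Development")
def capture() -> None:
    """Grouped under the 'Kernel Development' panel in --help."""
    typer.echo("captured")

@demo.command("status", rich_help_panel="Diagnostics")
def status() -> None:
    """Grouped under a separate 'Diagnostics' panel."""
    typer.echo("ok")

if __name__ == "__main__":
    demo()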
@@ -7280,6 +7857,203 @@ def isa_targets() -> None:
     typer.echo(output)


+# =============================================================================
+# Trace comparison commands
+# =============================================================================
+
+
+@compare_app.command("analyze")
+def compare_analyze(
+    trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
+    trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
+    format: str = typer.Option(
+        "text",
+        "--format",
+        "-f",
+        help="Output format: text, text-layers, csv, csv-layers, json",
+    ),
+    output: Path | None = typer.Option(
+        None, "--output", "-o", help="Output file (default: stdout)"
+    ),
+    phase: str = typer.Option(
+        "all",
+        "--phase",
+        help="Filter by phase: all, prefill, decode",
+    ),
+    layers: bool = typer.Option(False, "--layers", help="Show layer-wise performance breakdown"),
+    all: bool = typer.Option(
+        False, "--all", help="Show all items (no truncation for layers, operations, kernels)"
+    ),
+    stack_traces: bool = typer.Option(
+        False, "--stack-traces", help="Show Python stack traces for operations"
+    ),
+    recommendations: bool = typer.Option(
+        False, "--recommendations", help="Generate prioritized recommendations for kernel team"
+    ),
+    json: bool = typer.Option(
+        False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
+    ),
+) -> None:
+    """Compare GPU traces from AMD and NVIDIA platforms.
+
+    Analyzes performance differences between traces, identifying which operations
+    are faster/slower on each platform and providing kernel-level details.
+
+    Examples:
+        # Basic comparison (stdout)
+        wafer compare analyze amd_trace.json nvidia_trace.json
+
+        # Show layer-wise breakdown
+        wafer compare analyze amd_trace.json nvidia_trace.json --layers
+        wafer compare analyze amd_trace.json nvidia_trace.json --format text-layers
+
+        # Show all layers without truncation
+        wafer compare analyze amd_trace.json nvidia_trace.json --layers --all
+
+        # Show Python stack traces
+        wafer compare analyze amd_trace.json nvidia_trace.json --stack-traces
+
+        # Show all stack traces without truncation
+        wafer compare analyze amd_trace.json nvidia_trace.json --stack-traces --all
+
+        # Save to file
+        wafer compare analyze amd_trace.json nvidia_trace.json -o report.txt
+
+        # CSV output (operations) to file
+        wafer compare analyze amd_trace.json nvidia_trace.json --format csv -o operations.csv
+
+        # CSV output (layers) to file
+        wafer compare analyze amd_trace.json nvidia_trace.json --format csv-layers -o layers.csv
+
+        # JSON output to file
+        wafer compare analyze amd_trace.json nvidia_trace.json --format json -o report.json
+
+        # Analyze only prefill phase
+        wafer compare analyze amd_trace.json nvidia_trace.json --phase prefill
+    """
+    from .trace_compare import compare_traces
+
+    compare_traces(
+        trace1=trace1,
+        trace2=trace2,
+        output=output,
+        output_format=format,
+        phase=phase,
+        show_layers=layers,
+        show_all=all,
+        show_stack_traces=stack_traces,
+        recommendations=recommendations,
+    )
+    _mark_command_success()
+
+
+@compare_app.command("fusion")
+def compare_fusion_cmd(
+    trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
+    trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
+    format: str = typer.Option(
+        "text",
+        "--format",
+        "-f",
+        help="Output format: text, csv, json",
+    ),
+    output: Path | None = typer.Option(
+        None, "--output", "-o", help="Output file (default: stdout)"
+    ),
+    min_group_size: int = typer.Option(
+        50,
+        "--min-group-size",
+        help="Minimum correlation group size to analyze",
+    ),
+    json: bool = typer.Option(
+        False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
+    ),
+) -> None:
+    """Analyze kernel fusion differences between AMD and NVIDIA traces.
+
+    Detects which operations are fused differently on each platform by analyzing
+    how many kernel launches each platform uses for the same logical operations.
+
+    Examples:
+        # Basic fusion analysis (stdout)
+        wafer compare fusion amd_trace.json nvidia_trace.json
+
+        # Save to file
+        wafer compare fusion amd_trace.json nvidia_trace.json -o fusion_report.txt
+
+        # JSON output to file
+        wafer compare fusion amd_trace.json nvidia_trace.json --format json -o fusion.json
+
+        # CSV output to file
+        wafer compare fusion amd_trace.json nvidia_trace.json --format csv -o fusion.csv
+    """
+    from .trace_compare import compare_align
+
+    compare_align(
+        trace1=trace1,
+        trace2=trace2,
+        output=output,
+        output_format=format,
+        phase="all",
+    )
+    _mark_command_success()
+
+
+@compare_app.command("align")
+def compare_align_cmd(
+    trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
+    trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
+    format: str = typer.Option(
+        "json",
+        "--format",
+        "-f",
+        help="Output format: json",
+    ),
+    output: Path | None = typer.Option(
+        None, "--output", "-o", help="Output file (default: stdout)"
+    ),
+    phase: str = typer.Option(
+        "all",
+        "--phase",
+        help="Filter by phase: all, prefill, decode",
+    ),
+    layer: int | None = typer.Option(
+        None,
+        "--layer",
+        help="Focus on specific layer number",
+    ),
+) -> None:
+    """Align kernels at layer level for exact kernel-to-kernel comparison.
+
+    Provides kernel-to-kernel mapping across AMD and NVIDIA platforms,
+    showing which kernels correspond to each other at each layer position.
+
+    Examples:
+        # Basic alignment (stdout JSON)
+        wafer compare align amd_trace.json nvidia_trace.json
+
+        # Save to file
+        wafer compare align amd_trace.json nvidia_trace.json -o alignment.json
+
+        # Focus on decode phase only
+        wafer compare align amd_trace.json nvidia_trace.json --phase decode
+
+        # Focus on specific layer
+        wafer compare align amd_trace.json nvidia_trace.json --layer 5
+    """
+    from .trace_compare import compare_align
+
+    compare_align(
+        trace1=trace1,
+        trace2=trace2,
+        output=output,
+        output_format=format,
+        phase=phase,
+        layer=layer,
+    )
+    _mark_command_success()
+
+
 def main() -> None:
     """Entry point for wafer CLI."""
     app()
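The three `compare` commands above attach to `compare_app`, a Typer sub-application that is created and mounted elsewhere in cli.py (outside this diff). A hedged sketch of that standard Typer sub-app pattern, with names chosen only for illustration:

import typer

app = typer.Typer()
compare_app = typer.Typer(help="Compare GPU traces across platforms")
# Mount the sub-app so its commands are invoked as `wafer compare <command>`
app.add_typer(compare_app, name="compare")

@compare_app.command("analyze")
def analyze(trace1: str, trace2: str) -> None:
    typer.echo(f"comparing {trace1} vs {trace2}")

if __name__ == "__main__":
    app()  # e.g. `python demo.py compare analyze amd.json nvidia.json`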