wafer-cli 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli.py CHANGED
@@ -18,9 +18,11 @@ Setup:
18
18
  config CLI configuration and local GPU targets
19
19
  """
20
20
 
21
+ import atexit
21
22
  import json
22
23
  import os
23
24
  import sys
25
+ import time
24
26
  from pathlib import Path
25
27
 
26
28
  import trio
@@ -34,6 +36,98 @@ app = typer.Typer(
34
36
  no_args_is_help=True,
35
37
  )
36
38
 
39
+ # =============================================================================
40
+ # Analytics tracking
41
+ # =============================================================================
42
+
43
+ # Track command start time for duration calculation
44
+ _command_start_time: float | None = None
45
+ # Track command outcome (defaults to failure, set to success on clean exit)
46
+ _command_outcome: str = "failure"
47
+
48
+
49
def _get_command_path(ctx: typer.Context) -> tuple[str, str | None]:
    """Derive the (command, subcommand) pair for a Typer context.

    If the context has a parent with a non-empty name other than "wafer",
    the parent's name is reported as the command and this context's own
    name as the subcommand. Otherwise this context's name is the command
    ("unknown" when it has no name) and the invoked subcommand, which may
    be None, is the subcommand.

    Returns:
        Tuple of (command, subcommand). subcommand may be None.
    """
    own_name = ctx.info_name or ""
    parent = ctx.parent

    # Nested invocation: the parent names the command group.
    if parent and parent.info_name and parent.info_name != "wafer":
        return parent.info_name, own_name

    # Root-level context (or a direct child of the "wafer" root).
    return own_name or "unknown", ctx.invoked_subcommand
67
+
68
+
69
def _mark_command_success() -> None:
    """Record the current command's outcome as "success".

    Overwrites the module-level ``_command_outcome`` flag; intended to be
    called once a command has finished its work without error.
    """
    # Write straight into the module namespace; equivalent to a `global` rebind.
    globals()["_command_outcome"] = "success"
77
+
78
+
79
@app.callback()
def main_callback(ctx: typer.Context) -> None:
    """Initialize analytics and track command execution."""
    # How the outcome is determined:
    #   * The outcome starts as "success" for every invocation.
    #   * Uncaught exceptions reach sys.excepthook and flip it to "failure".
    #   * SystemExit never reaches sys.excepthook, so non-zero exits
    #     (typer.Exit(1), Click usage errors, aborts) are observed by wrapping
    #     sys.exit, which the CLI runner calls in standalone mode.
    #   * A bare `raise SystemExit(n)` that bypasses sys.exit is not observed.
    global _command_start_time, _command_outcome
    _command_start_time = time.time()
    _command_outcome = "success"

    # Initialize analytics (lazy import to avoid slowing down --help)
    from . import analytics

    analytics.init_analytics()

    def mark_failure() -> None:
        global _command_outcome
        _command_outcome = "failure"

    # Any exception that reaches the interpreter's top level is a failure.
    original_excepthook = sys.excepthook

    def custom_excepthook(exc_type, exc_value, exc_traceback):
        mark_failure()
        original_excepthook(exc_type, exc_value, exc_traceback)

    sys.excepthook = custom_excepthook

    # Record non-zero exit statuses before the process terminates.
    original_exit = sys.exit

    def tracking_exit(code=None):
        # None and 0 mean success; any other value (including a message
        # string) results in a non-zero process exit status.
        if code is not None and code != 0:
            mark_failure()
        original_exit(code)

    sys.exit = tracking_exit

    # Register tracking at exit to capture command outcome
    def track_on_exit() -> None:
        # Skip tracking when the context is only being parsed resiliently
        # (e.g. shell completion).
        if ctx.resilient_parsing:
            return

        command, subcommand = _get_command_path(ctx)

        # Calculate duration
        duration_ms = None
        if _command_start_time is not None:
            duration_ms = int((time.time() - _command_start_time) * 1000)

        # Track the command execution with the recorded outcome
        analytics.track_command(
            command=command,
            subcommand=subcommand,
            outcome=_command_outcome,
            duration_ms=duration_ms,
        )

    atexit.register(track_on_exit)
130
+
37
131
 
38
132
  # =============================================================================
39
133
  # Autocompletion helpers
@@ -57,13 +151,37 @@ config_app = typer.Typer(help="Manage CLI configuration and local GPU targets")
57
151
  app.add_typer(config_app, name="config")
58
152
 
59
153
  # Target management - nested under config
60
- targets_app = typer.Typer(help="Manage local GPU targets (TOML files)")
154
+ targets_app = typer.Typer(
155
+ help="""Manage GPU targets for remote evaluation.
156
+
157
+ Targets define how to access GPUs. Use 'wafer config targets init' to set up:
158
+
159
+ wafer config targets init ssh # Your own GPU via SSH
160
+ wafer config targets init runpod # RunPod cloud GPUs (needs WAFER_RUNPOD_API_KEY)
161
+ wafer config targets init digitalocean # DigitalOcean AMD GPUs
162
+
163
+ Then use with: wafer evaluate --target <name> ..."""
164
+ )
61
165
  config_app.add_typer(targets_app, name="targets")
62
166
 
63
167
  # Workspace management (remote API-backed)
64
- workspaces_app = typer.Typer(help="Manage cloud GPU workspaces")
168
+ workspaces_app = typer.Typer(
169
+ help="""Manage cloud GPU workspaces for remote development.
170
+
171
+ Workspaces are on-demand cloud GPU environments. Requires authentication (wafer login).
172
+
173
+ wafer workspaces create dev --gpu H100 # Create workspace
174
+ wafer workspaces exec dev -- python x.py # Run commands
175
+ wafer workspaces ssh dev # Interactive SSH
176
+ wafer workspaces sync dev ./project # Sync files
177
+ wafer workspaces delete dev # Clean up"""
178
+ )
65
179
  app.add_typer(workspaces_app, name="workspaces")
66
180
 
181
+ # Billing management
182
+ billing_app = typer.Typer(help="Manage billing, credits, and subscription")
183
+ app.add_typer(billing_app, name="billing")
184
+
67
185
  # Corpus management
68
186
  corpus_app = typer.Typer(help="Download and manage GPU documentation")
69
187
  app.add_typer(corpus_app, name="corpus")
@@ -124,6 +242,159 @@ app.add_typer(amd_app, name="amd")
124
242
  isa_app = typer.Typer(help="ISA analysis for AMD GPU code objects (.co files)")
125
243
  amd_app.add_typer(isa_app, name="isa")
126
244
 
245
+ # =============================================================================
246
+ # Skill management (wafer skill ...)
247
+ # =============================================================================
248
+
249
+ skill_app = typer.Typer(help="Manage AI coding assistant skills (Claude Code, Codex)")
250
+ app.add_typer(skill_app, name="skill")
251
+
252
+
253
@skill_app.command("install")
def skill_install(
    target: str = typer.Option(
        "all",
        "--target",
        "-t",
        help="Target tool: claude, codex, or all",
    ),
    force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing skill"),
) -> None:
    """Install the wafer-guide skill for AI coding assistants.

    Installs the bundled skill to make wafer commands discoverable by
    Claude Code and/or OpenAI Codex CLI.

    Skills follow the open agent skills specification (agentskills.io).

    Examples:
        wafer skill install              # Install for both Claude and Codex
        wafer skill install -t claude    # Install for Claude Code only
        wafer skill install -t codex     # Install for Codex CLI only
        wafer skill install --force      # Overwrite existing installation
    """
    import shutil

    # Locate bundled skill
    skill_source = Path(__file__).parent / "skills" / "wafer-guide"
    if not skill_source.exists():
        typer.echo("Error: Bundled skill not found. Package may be corrupted.", err=True)
        raise typer.Exit(1)

    targets_to_install: list[tuple[str, Path]] = []

    if target in ("all", "claude"):
        targets_to_install.append((
            "Claude Code",
            Path.home() / ".claude" / "skills" / "wafer-guide",
        ))
    if target in ("all", "codex"):
        targets_to_install.append(("Codex CLI", Path.home() / ".codex" / "skills" / "wafer-guide"))

    if not targets_to_install:
        typer.echo(f"Error: Unknown target '{target}'. Use: claude, codex, or all", err=True)
        raise typer.Exit(1)

    for tool_name, dest_path in targets_to_install:
        # Path.exists() follows symlinks, so a dangling symlink would look
        # absent and make symlink_to() fail with FileExistsError. Check the
        # link itself as well.
        if dest_path.is_symlink() or dest_path.exists():
            if not force:
                typer.echo(f" {tool_name}: Already installed at {dest_path}")
                typer.echo(" Use --force to overwrite")
                continue
            # Remove existing: links and plain files are unlinked, real
            # directories are removed recursively.
            if dest_path.is_symlink() or dest_path.is_file():
                dest_path.unlink()
            else:
                shutil.rmtree(dest_path)

        # Create parent directory
        dest_path.parent.mkdir(parents=True, exist_ok=True)

        # Create symlink
        dest_path.symlink_to(skill_source)
        typer.echo(f" {tool_name}: Installed at {dest_path}")

    typer.echo("")
    typer.echo("Restart your AI assistant to load the new skill.")
320
+
321
+
322
@skill_app.command("uninstall")
def skill_uninstall(
    target: str = typer.Option(
        "all",
        "--target",
        "-t",
        help="Target tool: claude, codex, or all",
    ),
) -> None:
    """Uninstall the wafer-guide skill.

    Examples:
        wafer skill uninstall            # Uninstall from both
        wafer skill uninstall -t claude  # Uninstall from Claude Code only
    """
    import shutil

    targets_to_uninstall: list[tuple[str, Path]] = []

    if target in ("all", "claude"):
        targets_to_uninstall.append((
            "Claude Code",
            Path.home() / ".claude" / "skills" / "wafer-guide",
        ))
    if target in ("all", "codex"):
        targets_to_uninstall.append((
            "Codex CLI",
            Path.home() / ".codex" / "skills" / "wafer-guide",
        ))

    if not targets_to_uninstall:
        typer.echo(f"Error: Unknown target '{target}'. Use: claude, codex, or all", err=True)
        raise typer.Exit(1)

    for tool_name, dest_path in targets_to_uninstall:
        # Path.exists() follows symlinks; test the link itself too so a
        # dangling symlink is still detected and removed.
        if not (dest_path.is_symlink() or dest_path.exists()):
            typer.echo(f" {tool_name}: Not installed")
            continue

        # Links and plain files are unlinked; real directories are removed
        # recursively.
        if dest_path.is_symlink() or dest_path.is_file():
            dest_path.unlink()
        else:
            shutil.rmtree(dest_path)
        typer.echo(f" {tool_name}: Uninstalled from {dest_path}")
366
+
367
+
368
@skill_app.command("status")
def skill_status() -> None:
    """Show installation status of the wafer-guide skill.

    Examples:
        wafer skill status
    """
    skill_source = Path(__file__).parent / "skills" / "wafer-guide"

    typer.echo("Wafer Skill Status")
    typer.echo("=" * 40)
    typer.echo(f"Bundled skill: {skill_source}")
    typer.echo(f" Exists: {skill_source.exists()}")
    typer.echo("")

    installations = [
        ("Claude Code", Path.home() / ".claude" / "skills" / "wafer-guide"),
        ("Codex CLI", Path.home() / ".codex" / "skills" / "wafer-guide"),
    ]

    for tool_name, path in installations:
        if path.is_symlink():
            resolved = path.resolve()
            if path.exists():
                typer.echo(f"{tool_name}: Installed (symlink -> {resolved})")
            else:
                # Path.exists() follows the link, so a dangling symlink would
                # otherwise be misreported as "Not installed".
                typer.echo(f"{tool_name}: Broken symlink at {path} (target missing: {resolved})")
        elif path.exists():
            typer.echo(f"{tool_name}: Installed (copy at {path})")
        else:
            typer.echo(f"{tool_name}: Not installed")
397
+
127
398
 
128
399
  @app.command(hidden=True)
129
400
  def run(
@@ -1768,17 +2039,31 @@ def login(
1768
2039
  token: str | None = typer.Option(
1769
2040
  None, "--token", "-t", help="Access token (skip browser OAuth)"
1770
2041
  ),
2042
+ port: int | None = typer.Option(
2043
+ None, "--port", "-p", help="Port for OAuth callback server (default: 8765 for SSH, random for local)"
2044
+ ),
1771
2045
  ) -> None:
1772
2046
  """Authenticate CLI with wafer-api via GitHub OAuth.
1773
2047
 
1774
2048
  Opens browser for GitHub authentication. Use --token to skip browser.
1775
2049
  Uses the API environment from config (see 'wafer config show').
1776
2050
 
2051
+ SSH Users:
2052
+ - Automatically uses port 8765 (just set up port forwarding once)
2053
+ - On local machine: ssh -L 8765:localhost:8765 user@host
2054
+ - On remote machine: wafer login
2055
+ - Browser opens locally, redirect works through tunnel
2056
+
2057
+ Manual token option:
2058
+ - Visit auth.wafer.ai, authenticate, copy token from URL
2059
+ - Run: wafer login --token <paste-token>
2060
+
1777
2061
  Examples:
1778
- wafer login # opens browser for GitHub OAuth
1779
- wafer login --token xyz # use existing token
2062
+ wafer login # auto-detects SSH, uses appropriate port
2063
+ wafer login --port 9000 # override port
2064
+ wafer login --token xyz # manual token (no browser)
1780
2065
 
1781
- # To login to a different environment:
2066
+ # Change environment:
1782
2067
  wafer config set api.environment staging
1783
2068
  wafer login
1784
2069
  """
@@ -1794,11 +2079,21 @@ def login(
1794
2079
  typer.echo(f"Auth: {get_supabase_url()}")
1795
2080
  typer.echo("")
1796
2081
 
2082
+ # Auto-detect SSH and use fixed port
2083
+ if port is None:
2084
+ is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
2085
+ if is_ssh:
2086
+ port = 8765
2087
+ typer.echo("🔒 SSH session detected - using port 8765 for OAuth callback")
2088
+ typer.echo(" Make sure you have port forwarding set up:")
2089
+ typer.echo(" ssh -L 8765:localhost:8765 user@host")
2090
+ typer.echo("")
2091
+
1797
2092
  # Browser OAuth if no token provided
1798
2093
  refresh_token = None
1799
2094
  if token is None:
1800
2095
  try:
1801
- token, refresh_token = browser_login()
2096
+ token, refresh_token = browser_login(port=port)
1802
2097
  except TimeoutError as e:
1803
2098
  typer.echo(f"Error: {e}", err=True)
1804
2099
  raise typer.Exit(1) from None
@@ -1832,6 +2127,11 @@ def login(
1832
2127
  # Save credentials (with refresh token if available)
1833
2128
  save_credentials(token, refresh_token, user_info.email)
1834
2129
 
2130
+ # Track login event with analytics
2131
+ from . import analytics
2132
+
2133
+ analytics.track_login(user_info.user_id, user_info.email)
2134
+
1835
2135
  if user_info.email:
1836
2136
  typer.echo(f"Logged in as {user_info.email}")
1837
2137
  else:
@@ -1844,6 +2144,13 @@ def logout() -> None:
1844
2144
  """Remove stored credentials."""
1845
2145
  from .auth import clear_credentials
1846
2146
 
2147
+ from . import analytics
2148
+
2149
+ # Track logout event first (while credentials still exist for user identification)
2150
+ # Note: track_logout() handles the case where user is not logged in
2151
+ analytics.track_logout()
2152
+
2153
+ # Clear credentials and report result
1847
2154
  if clear_credentials():
1848
2155
  typer.echo("Logged out. Credentials removed.")
1849
2156
  else:
@@ -1920,7 +2227,13 @@ def guide() -> None:
1920
2227
  # =============================================================================
1921
2228
 
1922
2229
  # Demo subcommand group
1923
- demo_app = typer.Typer(help="Demo commands and sample data")
2230
+ demo_app = typer.Typer(
2231
+ help="""Interactive demos for Wafer workflows.
2232
+
2233
+ wafer demo docs Query GPU documentation (downloads ~5MB)
2234
+ wafer demo trace Analyze a sample performance trace
2235
+ wafer demo eval Run kernel evaluation on cloud GPU (requires login)"""
2236
+ )
1924
2237
  app.add_typer(demo_app, name="demo")
1925
2238
 
1926
2239
  DEMO_TRACES_URL = "https://github.com/wafer-ai/wafer/raw/main/apps/wafer-cli/wafer/demo_data"
@@ -2009,73 +2322,309 @@ def demo_traces() -> None:
2009
2322
  )
2010
2323
 
2011
2324
 
2012
- @demo_app.command("examples")
2013
- def demo_examples() -> None:
2014
- """Show example commands for common workflows.
2325
+ @demo_app.command("docs")
2326
+ def demo_docs(
2327
+ yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
2328
+ ) -> None:
2329
+ """Demo: Ask GPU documentation questions.
2330
+
2331
+ Downloads CUDA corpus (~5MB) and asks a sample question using AI.
2015
2332
 
2016
- Prints copy-paste ready examples for:
2017
- - Analyzing traces
2018
- - Asking documentation questions
2019
- - Evaluating kernels
2333
+ Example:
2334
+ wafer demo docs
2335
+ wafer demo docs -y # skip confirmation
2020
2336
  """
2021
- typer.echo("""# Wafer CLI Examples
2337
+ import subprocess
2022
2338
 
2023
- ## 1. Set Up GPU Access (one-time)
2339
+ from .corpus import download_corpus, get_corpus_path
2024
2340
 
2025
- # Option A: Your own GPU via SSH
2026
- wafer config targets init ssh --name my-gpu --host user@hostname:22 --gpu-type H100
2341
+ # Check if already downloaded
2342
+ corpus_path = get_corpus_path("cuda")
2343
+ needs_download = corpus_path is None
2027
2344
 
2028
- # Option B: RunPod (on-demand cloud GPUs)
2029
- export WAFER_RUNPOD_API_KEY=your_key # from runpod.io/console/user/settings
2030
- wafer config targets init runpod --gpu MI300X
2345
+ if needs_download and not yes:
2346
+ typer.echo("This demo will:")
2347
+ typer.echo(" 1. Download CUDA documentation corpus (~5MB)")
2348
+ typer.echo(" 2. Ask a sample question using AI")
2349
+ typer.echo("")
2350
+ if not typer.confirm("Continue?"):
2351
+ raise typer.Exit(0)
2031
2352
 
2032
- # Option C: DigitalOcean (AMD MI300X)
2033
- export WAFER_AMD_DIGITALOCEAN_API_KEY=your_key
2034
- wafer config targets init digitalocean
2353
+ # Step 1: Download corpus if needed
2354
+ if needs_download:
2355
+ typer.echo("\n[1/2] Downloading CUDA corpus...")
2356
+ download_corpus("cuda")
2357
+ else:
2358
+ typer.echo("\n[1/2] CUDA corpus already downloaded")
2035
2359
 
2036
- # Verify your target
2037
- wafer config targets list
2360
+ # Step 2: Ask a question
2361
+ typer.echo("\n[2/2] Asking: 'What is warp divergence?'\n")
2362
+ typer.echo("-" * 60)
2363
+ result = subprocess.run(
2364
+ [
2365
+ "wafer",
2366
+ "wevin",
2367
+ "-s",
2368
+ "-t",
2369
+ "ask-docs",
2370
+ "--corpus",
2371
+ "cuda",
2372
+ "What is warp divergence? Answer in 2-3 sentences.",
2373
+ ],
2374
+ check=False,
2375
+ )
2376
+ typer.echo("-" * 60)
2038
2377
 
2039
- ## 2. Evaluate a Kernel
2378
+ if result.returncode == 0:
2379
+ typer.echo("\n✓ Demo complete! Try your own questions:")
2380
+ typer.echo(' wafer agent -t ask-docs --corpus cuda "your question here"')
2381
+ else:
2382
+ typer.echo("\n✗ Demo failed. Check your configuration.")
2383
+ raise typer.Exit(1)
2040
2384
 
2041
- # Generate template files
2042
- wafer evaluate make-template ./my-kernel
2043
2385
 
2044
- # Run evaluation
2045
- wafer evaluate \\
2046
- --impl ./my-kernel/kernel.py \\
2047
- --reference ./my-kernel/reference.py \\
2048
- --test-cases ./my-kernel/test_cases.json \\
2049
- --target my-gpu # or runpod-mi300x, do-mi300x
2386
+ @demo_app.command("trace")
2387
+ def demo_trace(
2388
+ yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
2389
+ ) -> None:
2390
+ """Demo: Analyze a performance trace.
2050
2391
 
2051
- ## 3. Ask GPU Programming Questions (no GPU needed)
2392
+ Creates a sample PyTorch trace and runs SQL queries on it.
2052
2393
 
2053
- # Download CUDA documentation (one-time)
2054
- wafer corpus download cuda
2394
+ Example:
2395
+ wafer demo trace
2396
+ wafer demo trace -y # skip confirmation
2397
+ """
2398
+ import subprocess
2055
2399
 
2056
- # Ask a question
2057
- wafer wevin -t ask-docs --corpus cuda -s "What is warp divergence?"
2400
+ if not yes:
2401
+ typer.echo("This demo will:")
2402
+ typer.echo(" 1. Create a sample PyTorch-style trace")
2403
+ typer.echo(" 2. Run SQL queries to find slowest kernels")
2404
+ typer.echo("")
2405
+ if not typer.confirm("Continue?"):
2406
+ raise typer.Exit(0)
2058
2407
 
2059
- ## 4. Analyze a PyTorch Trace (no GPU needed)
2408
+ # Step 1: Setup demo data
2409
+ typer.echo("\n[1/2] Creating sample trace...")
2410
+ DEMO_DIR.mkdir(parents=True, exist_ok=True)
2411
+ sample_trace = DEMO_DIR / "sample_trace.json"
2412
+ sample_trace.write_text("""{
2413
+ "traceEvents": [
2414
+ {"name": "matmul_kernel", "cat": "kernel", "ph": "X", "ts": 0, "dur": 1500000, "pid": 1, "tid": 1},
2415
+ {"name": "relu_kernel", "cat": "kernel", "ph": "X", "ts": 1600000, "dur": 50000, "pid": 1, "tid": 1},
2416
+ {"name": "softmax_kernel", "cat": "kernel", "ph": "X", "ts": 1700000, "dur": 200000, "pid": 1, "tid": 1},
2417
+ {"name": "attention_kernel", "cat": "kernel", "ph": "X", "ts": 2000000, "dur": 3000000, "pid": 1, "tid": 1},
2418
+ {"name": "layernorm_kernel", "cat": "kernel", "ph": "X", "ts": 5100000, "dur": 100000, "pid": 1, "tid": 1}
2419
+ ]
2420
+ }""")
2421
+ typer.echo(f" Created: {sample_trace}")
2422
+
2423
+ # Step 2: Query the trace
2424
+ typer.echo("\n[2/2] Finding slowest kernels...\n")
2425
+ typer.echo("-" * 60)
2426
+ result = subprocess.run(
2427
+ [
2428
+ "wafer",
2429
+ "nvidia",
2430
+ "perfetto",
2431
+ "query",
2432
+ str(sample_trace),
2433
+ "SELECT name, dur/1e6 as duration_ms FROM slice ORDER BY dur DESC",
2434
+ ],
2435
+ check=False,
2436
+ )
2437
+ typer.echo("-" * 60)
2438
+
2439
+ if result.returncode == 0:
2440
+ typer.echo("\n✓ Demo complete! Try your own traces:")
2441
+ typer.echo(' wafer nvidia perfetto query <your_trace.json> "SELECT name, dur FROM slice"')
2442
+ typer.echo("")
2443
+ typer.echo(" Or use AI-assisted analysis:")
2444
+ typer.echo(' wafer agent -t trace-analyze --args trace=<your_trace.json> "What\'s slow?"')
2445
+ else:
2446
+ typer.echo("\n✗ Demo failed.")
2447
+ raise typer.Exit(1)
2448
+
2449
+
2450
+ @demo_app.command("eval")
2451
+ def demo_eval(
2452
+ yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
2453
+ ) -> None:
2454
+ """Demo: Evaluate a kernel on a cloud GPU.
2455
+
2456
+ Creates a workspace, runs a sample Triton kernel evaluation, and cleans up.
2457
+ Requires authentication (wafer login).
2060
2458
 
2061
- # Setup demo data
2062
- wafer demo setup
2459
+ Example:
2460
+ wafer demo eval
2461
+ wafer demo eval -y # skip confirmation
2462
+ """
2463
+ import subprocess
2464
+ import tempfile
2465
+ import time
2466
+
2467
+ from .auth import load_credentials
2468
+
2469
+ # Check auth first
2470
+ creds = load_credentials()
2471
+ if not creds:
2472
+ typer.echo("Error: Not authenticated. Run: wafer login")
2473
+ raise typer.Exit(1)
2474
+
2475
+ if not yes:
2476
+ typer.echo("This demo will:")
2477
+ typer.echo(" 1. Create a cloud GPU workspace (B200)")
2478
+ typer.echo(" 2. Generate and upload a sample Triton kernel")
2479
+ typer.echo(" 3. Run correctness + performance evaluation")
2480
+ typer.echo(" 4. Delete the workspace")
2481
+ typer.echo("")
2482
+ typer.echo(" Note: Workspace usage is billed. Demo takes ~2-3 minutes.")
2483
+ typer.echo("")
2484
+ if not typer.confirm("Continue?"):
2485
+ raise typer.Exit(0)
2486
+
2487
+ workspace_name = f"wafer-demo-{int(time.time()) % 100000}"
2488
+
2489
+ try:
2490
+ # Step 1: Create workspace
2491
+ typer.echo(f"\n[1/4] Creating workspace '{workspace_name}'...")
2492
+ result = subprocess.run(
2493
+ ["wafer", "workspaces", "create", workspace_name, "--gpu", "B200", "--json"],
2494
+ capture_output=True,
2495
+ text=True,
2496
+ check=True,
2497
+ )
2498
+ import json
2499
+
2500
+ ws_info = json.loads(result.stdout)
2501
+ workspace_id = ws_info.get("id", workspace_name)
2502
+ typer.echo(f" Created: {workspace_id}")
2503
+
2504
+ # Step 2: Generate kernel template
2505
+ typer.echo("\n[2/4] Generating sample kernel...")
2506
+ with tempfile.TemporaryDirectory() as tmpdir:
2507
+ kernel_dir = Path(tmpdir) / "demo-kernel"
2508
+ subprocess.run(
2509
+ ["wafer", "evaluate", "make-template", str(kernel_dir)],
2510
+ capture_output=True,
2511
+ check=True,
2512
+ )
2513
+ typer.echo(" Generated Triton vector-add kernel")
2063
2514
 
2064
- # Find slowest kernels
2065
- wafer nvidia perfetto query ~/.cache/wafer/demo/sample_trace.json \\
2066
- "SELECT name, dur/1e6 as ms FROM slice WHERE cat='kernel' ORDER BY dur DESC"
2515
+ # Step 3: Run evaluation
2516
+ typer.echo("\n[3/4] Running evaluation on cloud GPU...\n")
2517
+ typer.echo("-" * 60)
2067
2518
 
2068
- ---
2069
- For more details, run: wafer guide
2519
+ # Write a simple test script to avoid escaping hell
2520
+ test_script = kernel_dir / "run_test.py"
2521
+ test_script.write_text("""
2522
+ import torch
2523
+ import kernel
2524
+ import reference
2525
+
2526
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
2527
+
2528
+ # Test correctness
2529
+ inputs = reference.generate_input(n=1048576, seed=42)
2530
+ out = kernel.custom_kernel(inputs)
2531
+ ref = reference.ref_kernel(inputs)
2532
+ correct = torch.allclose(out, ref)
2533
+ print(f"Correctness: {correct}")
2534
+
2535
+ # Benchmark
2536
+ import time
2537
+ for _ in range(10):
2538
+ kernel.custom_kernel(inputs)
2539
+ torch.cuda.synchronize()
2540
+
2541
+ t0 = time.perf_counter()
2542
+ for _ in range(100):
2543
+ kernel.custom_kernel(inputs)
2544
+ torch.cuda.synchronize()
2545
+ t1 = time.perf_counter()
2546
+
2547
+ print(f"Performance: {(t1-t0)/100*1e6:.1f} us/iter")
2070
2548
  """)
2071
2549
 
2550
+ eval_result = subprocess.run(
2551
+ [
2552
+ "wafer",
2553
+ "workspaces",
2554
+ "exec",
2555
+ "--sync",
2556
+ str(kernel_dir),
2557
+ workspace_name,
2558
+ "--",
2559
+ "bash",
2560
+ "-c",
2561
+ "cd /workspace && uv pip install -q --system triton && python run_test.py",
2562
+ ],
2563
+ check=False,
2564
+ )
2565
+ typer.echo("-" * 60)
2566
+
2567
+ # Step 4: Cleanup
2568
+ typer.echo(f"\n[4/4] Deleting workspace '{workspace_name}'...")
2569
+ subprocess.run(
2570
+ ["wafer", "workspaces", "delete", workspace_id],
2571
+ capture_output=True,
2572
+ check=False,
2573
+ )
2574
+ typer.echo(" Deleted")
2575
+
2576
+ if eval_result.returncode == 0:
2577
+ typer.echo("\n✓ Demo complete! To evaluate your own kernels:")
2578
+ typer.echo("")
2579
+ typer.echo(" # Using workspaces (no setup required):")
2580
+ typer.echo(" wafer workspaces create dev --gpu B200")
2581
+ typer.echo(" wafer workspaces exec --sync ./my-kernel dev -- python my_test.py")
2582
+ typer.echo("")
2583
+ typer.echo(" # Or using wafer evaluate with a configured target:")
2584
+ typer.echo(" wafer evaluate make-template ./my-kernel")
2585
+ typer.echo(" wafer evaluate --impl ./my-kernel/kernel.py \\")
2586
+ typer.echo(" --reference ./my-kernel/reference.py \\")
2587
+ typer.echo(" --test-cases ./my-kernel/test_cases.json \\")
2588
+ typer.echo(" --target <your-target>")
2589
+ else:
2590
+ typer.echo("\n✗ Evaluation failed, but workspace was cleaned up.")
2591
+ raise typer.Exit(1)
2592
+
2593
+ except subprocess.CalledProcessError as e:
2594
+ error_msg = e.stderr.strip() if e.stderr else str(e)
2595
+ typer.echo(f"\n✗ Error: {error_msg}")
2596
+ # Try to cleanup on failure
2597
+ typer.echo(f"Attempting to cleanup workspace '{workspace_name}'...")
2598
+ subprocess.run(
2599
+ ["wafer", "workspaces", "delete", workspace_name],
2600
+ capture_output=True,
2601
+ check=False,
2602
+ )
2603
+ raise typer.Exit(1) from None
2604
+ except KeyboardInterrupt:
2605
+ typer.echo(f"\n\nInterrupted. Cleaning up workspace '{workspace_name}'...")
2606
+ subprocess.run(
2607
+ ["wafer", "workspaces", "delete", workspace_name],
2608
+ capture_output=True,
2609
+ check=False,
2610
+ )
2611
+ raise typer.Exit(1) from None
2612
+
2072
2613
 
2073
2614
  # =============================================================================
2074
2615
  # Targets subcommands
2075
2616
  # =============================================================================
2076
2617
 
2077
2618
  # Init subcommand group for interactive target setup
2078
- init_app = typer.Typer(help="Initialize a new target interactively")
2619
+ init_app = typer.Typer(
2620
+ help="""Initialize a new GPU target.
2621
+
2622
+ Choose based on your GPU access:
2623
+
2624
+ ssh Your own hardware via SSH
2625
+ runpod RunPod cloud GPUs (needs WAFER_RUNPOD_API_KEY)
2626
+ digitalocean DigitalOcean AMD MI300X (needs WAFER_AMD_DIGITALOCEAN_API_KEY)"""
2627
+ )
2079
2628
  targets_app.add_typer(init_app, name="init")
2080
2629
 
2081
2630
 
@@ -2533,6 +3082,118 @@ def targets_pods() -> None:
2533
3082
  typer.echo()
2534
3083
 
2535
3084
 
3085
+ # =============================================================================
3086
+ # Billing commands
3087
+ # =============================================================================
3088
+
3089
+
3090
@billing_app.callback(invoke_without_command=True)
def billing_usage(
    ctx: typer.Context,
    json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
) -> None:
    """Show current billing usage and subscription info.

    Example:
        wafer billing
        wafer billing --json
    """
    # This callback also runs ahead of `billing` subcommands; print usage only
    # when `wafer billing` was invoked on its own.
    if ctx.invoked_subcommand is None:
        from .billing import get_usage

        try:
            typer.echo(get_usage(json_output=json_output))
        except RuntimeError as err:
            # Report the failure on stderr and exit non-zero without a traceback.
            typer.echo(f"Error: {err}", err=True)
            raise typer.Exit(1) from None
3113
+
3114
+
3115
@billing_app.command("topup")
def billing_topup(
    amount: int = typer.Argument(25, help="Amount in dollars ($10-$500)"),
    no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
) -> None:
    """Add credits to your account.

    Opens a Stripe checkout page to add credits. Default amount is $25.

    Example:
        wafer billing topup              # Add $25
        wafer billing topup 100          # Add $100
        wafer billing topup --no-browser # Print URL instead
    """
    import webbrowser

    from .billing import create_topup, validate_topup_amount

    # Convert dollars to cents
    amount_cents = amount * 100

    # Validate amount client-side before API call
    try:
        validate_topup_amount(amount_cents)
    except ValueError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    # Keep the try body to the API call only. typer.Exit derives from
    # RuntimeError, so raising it inside this try would be caught by the
    # handler below and reported a second time as a bogus error.
    try:
        result = create_topup(amount_cents)
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    checkout_url = result.get("checkout_url")
    if not checkout_url:
        typer.echo("Error: No checkout URL received from API", err=True)
        raise typer.Exit(1)

    if no_browser:
        typer.echo(f"Complete your purchase at:\n{checkout_url}")
        return

    typer.echo(f"Opening checkout for ${amount}...")
    # webbrowser.open() returns False when no browser could be launched
    # (e.g. headless/SSH sessions); fall back to printing the URL.
    if webbrowser.open(checkout_url):
        typer.echo("Browser opened. Complete your purchase there.")
    else:
        typer.echo(f"Could not open a browser. Complete your purchase at:\n{checkout_url}")
3160
+
3161
+
3162
@billing_app.command("portal")
def billing_portal(
    no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
) -> None:
    """Open Stripe billing portal.

    Manage your subscription, update payment method, or view invoices.

    Example:
        wafer billing portal
        wafer billing portal --no-browser
    """
    import webbrowser

    from .billing import get_portal_url

    # Keep the try body to the API call only. typer.Exit derives from
    # RuntimeError, so raising it inside this try would be caught by the
    # handler below and reported a second time as a bogus error.
    try:
        result = get_portal_url()
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    portal_url = result.get("portal_url")
    if not portal_url:
        typer.echo("Error: No portal URL received from API", err=True)
        raise typer.Exit(1)

    if no_browser:
        typer.echo(f"Billing portal:\n{portal_url}")
        return

    typer.echo("Opening billing portal...")
    # webbrowser.open() returns False when no browser could be launched
    # (e.g. headless/SSH sessions); fall back to printing the URL.
    if webbrowser.open(portal_url):
        typer.echo("Browser opened.")
    else:
        typer.echo(f"Could not open a browser. Billing portal:\n{portal_url}")
3195
+
3196
+
2536
3197
  # =============================================================================
2537
3198
  # Workspaces commands
2538
3199
  # =============================================================================
@@ -3767,6 +4428,29 @@ autotuner_app = typer.Typer(help="Hyperparameter sweep for performance engineeri
3767
4428
  app.add_typer(autotuner_app, name="autotuner", hidden=True)
3768
4429
 
3769
4430
 
4431
def _setup_wafer_core_env() -> None:
    """Export the API URL and auth token into the process environment.

    Intended to run before wafer-core code that reads these variables:

    - WAFER_API_URL is always (re)assigned from ``get_api_url()``.
      NOTE(review): an existing WAFER_API_URL override is only honoured if
      ``get_api_url()`` itself consults the environment - confirm there.
    - WAFER_AUTH_TOKEN is left untouched when already present, so CI or
      service accounts can supply their own token; otherwise it is set from
      ``get_valid_token()`` when that returns a truthy value.
    """
    from .global_config import get_api_url
    from .auth import get_valid_token

    env = os.environ
    env["WAFER_API_URL"] = get_api_url()

    # An explicitly provided token always wins over the cached one.
    if "WAFER_AUTH_TOKEN" in env:
        return

    cached_token = get_valid_token()
    if cached_token:
        env["WAFER_AUTH_TOKEN"] = cached_token
4452
+
4453
+
3770
4454
  @autotuner_app.command("list")
3771
4455
  def autotuner_list(
3772
4456
  show_all: bool = typer.Option(
@@ -3782,6 +4466,7 @@ def autotuner_list(
3782
4466
  wafer autotuner list
3783
4467
  wafer autotuner list --all
3784
4468
  """
4469
+ _setup_wafer_core_env()
3785
4470
  from .autotuner import list_command
3786
4471
 
3787
4472
  try:
@@ -3822,6 +4507,7 @@ def autotuner_delete(
3822
4507
  wafer autotuner delete --all --status pending
3823
4508
  wafer autotuner delete --all --status failed --yes
3824
4509
  """
4510
+ _setup_wafer_core_env()
3825
4511
  from .autotuner import delete_all_command, delete_command
3826
4512
 
3827
4513
  # Validate arguments
@@ -3888,6 +4574,7 @@ def autotuner_run(
3888
4574
  wafer autotuner run --resume <sweep-id>
3889
4575
  wafer autotuner run --resume <sweep-id> --parallel 8
3890
4576
  """
4577
+ _setup_wafer_core_env()
3891
4578
  from .autotuner import run_sweep_command
3892
4579
 
3893
4580
  # Validate arguments
@@ -4052,8 +4739,23 @@ def capture_command( # noqa: PLR0915
4052
4739
  wafer capture grid-search "python train.py --lr {LR} --bs {BS}" --sweep "LR=0.001,0.01,0.1" --sweep "BS=16,32"
4053
4740
  """
4054
4741
  import itertools
4742
+ import os
4055
4743
  import tomllib
4056
4744
 
4745
+ from .global_config import get_api_url
4746
+ from .auth import get_valid_token
4747
+
4748
+ # Set environment variables for wafer-core BEFORE importing it
4749
+ # wafer-core backend.py reads WAFER_API_URL and WAFER_AUTH_TOKEN from env
4750
+ os.environ["WAFER_API_URL"] = get_api_url()
4751
+
4752
+ # Only set auth token if not explicitly provided in environment
4753
+ # This allows CI/service accounts to override with their own tokens
4754
+ if "WAFER_AUTH_TOKEN" not in os.environ:
4755
+ token = get_valid_token()
4756
+ if token:
4757
+ os.environ["WAFER_AUTH_TOKEN"] = token
4758
+
4057
4759
  import trio
4058
4760
  from wafer_core.tools.capture_tool import ( # pragma: no cover
4059
4761
  CaptureConfig,
@@ -4243,6 +4945,20 @@ def capture_list_command(
4243
4945
  # Pagination
4244
4946
  wafer capture-list --limit 20 --offset 20
4245
4947
  """
4948
+ import os
4949
+
4950
+ from .global_config import get_api_url
4951
+ from .auth import get_valid_token
4952
+
4953
+ # Set environment variables for wafer-core BEFORE importing it
4954
+ os.environ["WAFER_API_URL"] = get_api_url()
4955
+
4956
+ # Only set auth token if not explicitly provided in environment
4957
+ # This allows CI/service accounts to override with their own tokens
4958
+ if "WAFER_AUTH_TOKEN" not in os.environ:
4959
+ token = get_valid_token()
4960
+ if token:
4961
+ os.environ["WAFER_AUTH_TOKEN"] = token
4246
4962
 
4247
4963
  import trio
4248
4964
  from wafer_core.utils.backend import list_captures # pragma: no cover