wafer-cli 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli.py CHANGED
@@ -18,9 +18,11 @@ Setup:
18
18
  config CLI configuration and local GPU targets
19
19
  """
20
20
 
21
+ import atexit
21
22
  import json
22
23
  import os
23
24
  import sys
25
+ import time
24
26
  from pathlib import Path
25
27
 
26
28
  import trio
@@ -28,12 +30,112 @@ import typer
28
30
 
29
31
  from .config import WaferConfig, WaferEnvironment
30
32
  from .inference import infer_upload_files, resolve_environment
33
+ from .problems import (
34
+ download_problems,
35
+ get_problem_path,
36
+ get_problems_path,
37
+ )
38
+ from .problems import (
39
+ list_problems as list_problems_fn,
40
+ )
31
41
 
32
42
  app = typer.Typer(
33
43
  help="GPU development toolkit for LLM coding agents",
34
44
  no_args_is_help=True,
35
45
  )
36
46
 
47
+ # =============================================================================
48
+ # Analytics tracking
49
+ # =============================================================================
50
+
51
+ # Track command start time for duration calculation
52
+ _command_start_time: float | None = None
53
+ # Track command outcome. Initialised to "failure" at import time; main_callback resets it to "success" when a command starts.
54
+ _command_outcome: str = "failure"
55
+
56
+
57
def _get_command_path(ctx: typer.Context) -> tuple[str, str | None]:
    """Split a Typer context into a ``(command, subcommand)`` pair.

    If ``ctx`` has a parent whose ``info_name`` is non-empty and is not
    ``"wafer"``, the parent's name is the command and this context's own
    ``info_name`` (``""`` when unset) is the subcommand.

    Otherwise the command is this context's ``info_name`` (``"unknown"`` when
    unset) and the subcommand is ``ctx.invoked_subcommand``, which may be
    ``None``.
    """
    own_name = ctx.info_name or ""

    # A named parent other than the root "wafer" group takes the command slot.
    parent = ctx.parent
    parent_name = parent.info_name if parent else None
    if parent_name and parent_name != "wafer":
        return parent_name, own_name

    return own_name or "unknown", ctx.invoked_subcommand
75
+
76
+
77
def _mark_command_success() -> None:
    """Record the current command as successful.

    Sets the module-level ``_command_outcome`` to ``"success"``. Intended to be
    called at the end of a command that completed without error.
    """
    global _command_outcome
    _command_outcome = "success"
85
+
86
+
87
@app.callback()
def main_callback(ctx: typer.Context) -> None:
    """Initialize analytics and track command execution."""
    # Flow:
    #   1. Record the start time and assume the command will succeed.
    #   2. Install hooks that flip the outcome to "failure" when the process
    #      ends with an uncaught exception or a non-zero exit status.
    #   3. Register an atexit handler that reports command, outcome and
    #      duration to analytics.
    global _command_start_time, _command_outcome
    _command_start_time = time.time()
    _command_outcome = "success"

    # Lazy import to avoid slowing down --help.
    from . import analytics

    analytics.init_analytics()

    def record_exit_code(code: object) -> None:
        """Mark the command as failed for any exit status other than 0/None."""
        global _command_outcome
        if code is not None and code != 0:
            _command_outcome = "failure"

    # Uncaught exceptions (other than SystemExit) reach sys.excepthook.
    original_excepthook = sys.excepthook

    def custom_excepthook(exc_type, exc_value, exc_traceback):
        if exc_type is SystemExit:
            # The interpreter never routes SystemExit here; this branch only
            # serves code that invokes the hook explicitly.
            record_exit_code(getattr(exc_value, "code", 1))
        else:
            record_exit_code(1)
        original_excepthook(exc_type, exc_value, exc_traceback)

    sys.excepthook = custom_excepthook

    # sys.excepthook is NOT called for SystemExit, and in standalone mode
    # Click/Typer turn `raise typer.Exit(code)` into `sys.exit(code)`. Without
    # this wrapper every `typer.Exit(1)` would be reported as a success.
    original_exit = sys.exit

    def tracking_exit(code=None):
        record_exit_code(code)
        original_exit(code)

    sys.exit = tracking_exit

    def track_on_exit() -> None:
        """Report the finished command to analytics (runs at interpreter exit)."""
        # Click enables resilient_parsing for things like shell completion;
        # no command actually ran, so there is nothing to report.
        if ctx.resilient_parsing:
            return

        command, subcommand = _get_command_path(ctx)

        duration_ms = None
        if _command_start_time is not None:
            duration_ms = int((time.time() - _command_start_time) * 1000)

        analytics.track_command(
            command=command,
            subcommand=subcommand,
            outcome=_command_outcome,
            duration_ms=duration_ms,
        )

    atexit.register(track_on_exit)
138
+
37
139
 
38
140
  # =============================================================================
39
141
  # Autocompletion helpers
@@ -106,6 +208,13 @@ kernelbench_app = typer.Typer(
106
208
  )
107
209
  evaluate_app.add_typer(kernelbench_app, name="kernelbench")
108
210
 
211
# `wafer evaluate gpumode ...`: sub-app for the GPUMode (functional) format.
# invoke_without_command lets the bare `gpumode` callback run the evaluation.
gpumode_app = typer.Typer(
    invoke_without_command=True,
    help="Evaluate kernels in GPUMode format (custom_kernel/ref_kernel functions)",
)
evaluate_app.add_typer(gpumode_app, name="gpumode")
217
+
109
218
  # =============================================================================
110
219
  # Dev commands (internal, used by web app proxy)
111
220
  # =============================================================================
@@ -302,6 +411,124 @@ def skill_status() -> None:
302
411
  typer.echo(f"{tool_name}: Not installed")
303
412
 
304
413
 
414
+ # =============================================================================
415
+ # Provider auth management (wafer auth ...)
416
+ # =============================================================================
417
+
418
# Sub-app mounted as `wafer auth ...` (login / logout / status for provider keys).
provider_auth_app = typer.Typer(
    help="Manage API keys for cloud GPU providers",
)
app.add_typer(provider_auth_app, name="auth")
420
+
421
+
422
@provider_auth_app.command("login")
def provider_auth_login(
    provider: str = typer.Argument(
        ...,
        help="Provider name: runpod, digitalocean, or modal",
    ),
    api_key: str | None = typer.Option(
        None,
        "--api-key",
        "-k",
        help="API key (if not provided, reads from stdin)",
    ),
) -> None:
    """Save API key for a cloud GPU provider.

    Stores the key in ~/.wafer/auth.json. Environment variables
    (e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.

    Examples:
        wafer auth login runpod --api-key rp_xxx
        wafer auth login digitalocean --api-key dop_v1_xxx
        echo $API_KEY | wafer auth login runpod
    """
    from wafer_core.auth import PROVIDERS, save_api_key

    # Reject providers wafer_core does not know about.
    if provider not in PROVIDERS:
        typer.echo(f"Error: Unknown provider '{provider}'", err=True)
        typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
        raise typer.Exit(1)

    # Key source priority: --api-key, then an interactive hidden prompt when
    # stdin is a terminal, otherwise whatever is piped on stdin.
    if api_key is None:
        if sys.stdin.isatty():
            typer.echo(f"Enter API key for {PROVIDERS[provider]['display_name']}:")
            api_key = typer.prompt("API key", hide_input=True)
        else:
            api_key = sys.stdin.read()

    # Strip on every path (previously only piped input was stripped) so that a
    # pasted key with stray whitespace is stored clean and a whitespace-only
    # key is rejected instead of saved.
    api_key = api_key.strip()
    if not api_key:
        typer.echo("Error: No API key provided", err=True)
        raise typer.Exit(1)

    save_api_key(provider, api_key)
    typer.echo(f"API key saved for {PROVIDERS[provider]['display_name']}")
    typer.echo("Stored in: ~/.wafer/auth.json")
471
+
472
+
473
@provider_auth_app.command("logout")
def provider_auth_logout(
    provider: str = typer.Argument(
        ...,
        help="Provider name: runpod, digitalocean, or modal",
    ),
) -> None:
    """Remove stored API key for a cloud GPU provider.

    Examples:
        wafer auth logout runpod
        wafer auth logout digitalocean
    """
    from wafer_core.auth import PROVIDERS, remove_api_key

    # Unknown provider: list the valid names and exit with status 1.
    if provider not in PROVIDERS:
        typer.echo(f"Error: Unknown provider '{provider}'", err=True)
        valid_names = ", ".join(PROVIDERS.keys())
        typer.echo(f"Valid providers: {valid_names}", err=True)
        raise typer.Exit(1)

    # remove_api_key's truthiness selects which of the two messages is shown.
    was_removed = remove_api_key(provider)
    display_name = PROVIDERS[provider]["display_name"]
    if was_removed:
        message = f"API key removed for {display_name}"
    else:
        message = f"No stored API key found for {display_name}"
    typer.echo(message)
498
+
499
+
500
@provider_auth_app.command("status")
def provider_auth_status() -> None:
    """Show authentication status for all cloud GPU providers.

    Displays which providers have API keys configured and where
    the keys are coming from (environment variable or auth.json).

    Example:
        wafer auth status
    """
    from wafer_core.auth import get_all_auth_status

    provider_statuses = get_all_auth_status()

    typer.echo("Cloud GPU Provider Authentication Status")
    typer.echo("=" * 45)

    for entry in provider_statuses:
        if not entry.is_authenticated:
            # Unconfigured provider: show how to set it up.
            typer.echo(f" {entry.display_name}: ✗ Not configured")
            typer.echo(f" Run: wafer auth login {entry.provider}")
            typer.echo(f" Or set: {entry.key_url}")
            continue

        # Configured provider: key preview plus its source, when one is reported.
        source_str = f"({entry.source})" if entry.source else ""
        typer.echo(f" {entry.display_name}: ✓ {entry.key_preview} {source_str}")

    typer.echo("")
    typer.echo("Note: Environment variables take precedence over stored keys.")
530
+
531
+
305
532
  @app.command(hidden=True)
306
533
  def run(
307
534
  command: str = typer.Argument(..., help="Command to run in Docker container"),
@@ -1195,13 +1422,25 @@ def evaluate( # noqa: PLR0913
1195
1422
  --benchmark --defensive
1196
1423
 
1197
1424
  Subcommands:
1198
- make-template Generate template files for this format
1425
+ gpumode Use GPUMode format (functional) - RECOMMENDED
1199
1426
  kernelbench Use KernelBench format (ModelNew class)
1427
+ make-template Generate template files for this format (deprecated)
1200
1428
  """
1201
1429
  # If a subcommand is being invoked, skip the main evaluation logic
1202
1430
  if ctx.invoked_subcommand is not None:
1203
1431
  return
1204
1432
 
1433
+ # Deprecation warning for bare evaluate
1434
+ typer.echo(
1435
+ "⚠️ Deprecation warning: 'wafer evaluate' will be removed in a future version.",
1436
+ err=True,
1437
+ )
1438
+ typer.echo(
1439
+ " Use 'wafer evaluate gpumode' instead for the functional format.",
1440
+ err=True,
1441
+ )
1442
+ typer.echo("", err=True)
1443
+
1205
1444
  # Validate required args when running evaluation (not subcommands)
1206
1445
  missing_args = []
1207
1446
  if implementation is None:
@@ -1216,12 +1455,12 @@ def evaluate( # noqa: PLR0913
1216
1455
  typer.echo(f" Required: {', '.join(missing_args)}", err=True)
1217
1456
  typer.echo("", err=True)
1218
1457
  typer.echo(
1219
- "Usage: wafer evaluate --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
1458
+ "Usage: wafer evaluate gpumode --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
1220
1459
  err=True,
1221
1460
  )
1222
1461
  typer.echo("", err=True)
1223
- typer.echo("Run 'wafer evaluate --help' for full options.", err=True)
1224
- typer.echo("Run 'wafer evaluate make-template DIR' to generate starter files.", err=True)
1462
+ typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
1463
+ typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
1225
1464
  raise typer.Exit(1)
1226
1465
 
1227
1466
  from .evaluate import EvaluateArgs, run_evaluate
@@ -1409,8 +1648,59 @@ def evaluate_make_template(
1409
1648
  # KernelBench format evaluation
1410
1649
  # =============================================================================
1411
1650
 
1412
- # Path to KernelBench problems (relative to wafer root)
1413
- KERNELBENCH_ROOT = Path(__file__).parent.parent.parent.parent / "research" / "KernelBench"
1651
+
1652
def _get_kernelbench_root() -> Path | None:
    """Locate the directory holding the KernelBench problem levels.

    Lookup order:
      1. The path returned by ``get_problems_path("kernelbench")``. Its
         ``KernelBench`` subdirectory is returned when it exists, otherwise
         the path itself.
      2. ``research/KernelBench/KernelBench`` four directory levels above this
         file (legacy development layout), if it exists.

    Returns:
        The problems root, or ``None`` when neither location is available.
    """
    cache_dir = get_problems_path("kernelbench")
    if cache_dir is not None:
        nested = cache_dir / "KernelBench"
        return nested if nested.exists() else cache_dir

    repo_root = Path(__file__).parent.parent.parent.parent
    dev_checkout = repo_root / "research" / "KernelBench" / "KernelBench"
    return dev_checkout if dev_checkout.exists() else None
1668
+
1669
+
1670
@kernelbench_app.command("download")
def kernelbench_download(
    force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
) -> None:
    """Download KernelBench problems from GitHub.

    Downloads the problem set to ~/.cache/wafer/problems/kernelbench/

    Examples:
        wafer evaluate kernelbench download
        wafer evaluate kernelbench download --force  # Re-download
    """
    # Any failure while downloading or reporting becomes a one-line error
    # and exit status 1 (the original exception is deliberately not chained).
    try:
        problems_dir = download_problems("kernelbench", force=force, verbose=True)
        for line in (
            "",
            f"Problems available at: {problems_dir}",
            "Run 'wafer evaluate kernelbench list-problems' to see available problems.",
        ):
            typer.echo(line)
    except Exception as exc:
        typer.echo(f"Error downloading problems: {exc}", err=True)
        raise typer.Exit(1) from None
1690
+
1691
+
1692
@kernelbench_app.command("list-problems")
def kernelbench_list_problems() -> None:
    """List available KernelBench problems.

    Examples:
        wafer evaluate kernelbench list-problems
    """
    # list_problems_fn prints the listing itself (verbose=True); a ValueError
    # from it is shown on stderr and mapped to exit status 1.
    try:
        list_problems_fn("kernelbench", verbose=True)
    except ValueError as exc:
        typer.echo(f"{exc}", err=True)
        raise typer.Exit(1) from None
1414
1704
 
1415
1705
 
1416
1706
  @kernelbench_app.callback(invoke_without_command=True)
@@ -1436,6 +1726,10 @@ def kernelbench_evaluate( # noqa: PLR0913
1436
1726
  ),
1437
1727
  benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
1438
1728
  profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
1729
+ inputs: Path | None = typer.Option(
1730
+ None, "--inputs", help="Custom inputs file to override get_inputs()"
1731
+ ),
1732
+ seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
1439
1733
  defensive: bool = typer.Option(
1440
1734
  False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
1441
1735
  ),
@@ -1500,6 +1794,8 @@ def kernelbench_evaluate( # noqa: PLR0913
1500
1794
  target_name=target or "",
1501
1795
  benchmark=benchmark,
1502
1796
  profile=profile,
1797
+ inputs=inputs,
1798
+ seed=seed,
1503
1799
  defensive=defensive,
1504
1800
  sync_artifacts=sync_artifacts,
1505
1801
  gpu_id=gpu_id,
@@ -1561,6 +1857,13 @@ def kernelbench_make_template(
1561
1857
  # Overwrite existing
1562
1858
  wafer evaluate kernelbench make-template level1/1 --force
1563
1859
  """
1860
+ # Get problems root (downloaded or legacy)
1861
+ kb_root = _get_kernelbench_root()
1862
+ if kb_root is None:
1863
+ typer.echo("Error: KernelBench problems not found.", err=True)
1864
+ typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
1865
+ raise typer.Exit(1)
1866
+
1564
1867
  # Parse problem ID
1565
1868
  parts = problem.split("/")
1566
1869
  if len(parts) != 2:
@@ -1572,10 +1875,10 @@ def kernelbench_make_template(
1572
1875
  level_str = f"level{level_str}"
1573
1876
 
1574
1877
  # Find the problem file
1575
- problem_dir = KERNELBENCH_ROOT / "KernelBench" / level_str
1878
+ problem_dir = kb_root / level_str
1576
1879
  if not problem_dir.exists():
1577
1880
  typer.echo(f"Error: KernelBench level directory not found: {problem_dir}", err=True)
1578
- typer.echo(f"Make sure KernelBench is at: {KERNELBENCH_ROOT}", err=True)
1881
+ typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
1579
1882
  raise typer.Exit(1)
1580
1883
 
1581
1884
  # Find matching problem file
@@ -1642,6 +1945,252 @@ def kernelbench_make_template(
1642
1945
  typer.echo(f" wafer evaluate kernelbench --impl my_kernel.py --reference {output}")
1643
1946
 
1644
1947
 
1948
+ # =============================================================================
1949
+ # GPUMode format evaluation
1950
+ # =============================================================================
1951
+
1952
+
1953
@gpumode_app.command("download")
def gpumode_download(
    force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
) -> None:
    """Download GPUMode reference kernels from GitHub.

    Downloads the problem set to ~/.cache/wafer/problems/gpumode/

    Examples:
        wafer evaluate gpumode download
        wafer evaluate gpumode download --force  # Re-download
    """
    # Any failure while downloading or reporting becomes a one-line error
    # and exit status 1 (the original exception is deliberately not chained).
    try:
        problems_dir = download_problems("gpumode", force=force, verbose=True)
        for line in (
            "",
            f"Problems available at: {problems_dir}",
            "Run 'wafer evaluate gpumode list-problems' to see available problems.",
        ):
            typer.echo(line)
    except Exception as exc:
        typer.echo(f"Error downloading problems: {exc}", err=True)
        raise typer.Exit(1) from None
1973
+
1974
+
1975
@gpumode_app.command("list-problems")
def gpumode_list_problems() -> None:
    """List available GPUMode problems.

    Examples:
        wafer evaluate gpumode list-problems
    """
    # list_problems_fn prints the listing itself (verbose=True); a ValueError
    # from it is shown on stderr and mapped to exit status 1.
    try:
        list_problems_fn("gpumode", verbose=True)
    except ValueError as exc:
        typer.echo(f"{exc}", err=True)
        raise typer.Exit(1) from None
1987
+
1988
+
1989
@gpumode_app.command("make-template")
def gpumode_make_template(
    problem: str = typer.Option(
        ...,
        "--problem",
        "-p",
        help="Problem ID (e.g., 'pmpp/vectoradd_py' or 'amd/fp8-mm')",
    ),
    output: Path | None = typer.Option(
        None, "--output", "-o", help="Output directory (default: ./<problem_name>/)"
    ),
    force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"),
) -> None:
    """Extract a GPUMode problem as template files.

    Creates a directory with reference.py, task.yml, and other problem files.
    You then create kernel.py with your custom_kernel implementation.

    Examples:
        # Extract pmpp vectoradd problem
        wafer evaluate gpumode make-template --problem pmpp/vectoradd_py

        # Extract to specific directory
        wafer evaluate gpumode make-template --problem pmpp/vectoradd_py --output ./my-kernel/
    """
    import shutil

    # Resolve the problem; distinguish "nothing downloaded yet" from
    # "downloaded, but no such problem" so the hint is actionable.
    problem_path = get_problem_path("gpumode", problem)
    if problem_path is None:
        if get_problems_path("gpumode") is None:
            typer.echo("Error: GPUMode problems not downloaded.", err=True)
            typer.echo("Run 'wafer evaluate gpumode download' first.", err=True)
        else:
            typer.echo(f"Error: Problem '{problem}' not found.", err=True)
            typer.echo(
                "Run 'wafer evaluate gpumode list-problems' to see available problems.", err=True
            )
        raise typer.Exit(1)

    # Default destination: ./<problem id with "/" replaced by "_">.
    if output is None:
        output = Path.cwd() / problem.replace("/", "_")

    output = output.resolve()

    # Refuse to clobber an existing path unless --force was given.
    if output.exists() and not force:
        typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
        raise typer.Exit(1)

    # --force: clear whatever is at the destination. shutil.rmtree only
    # accepts directories, so a regular file is unlinked instead of crashing.
    if output.exists():
        if output.is_dir():
            shutil.rmtree(output)
        else:
            output.unlink()
    shutil.copytree(problem_path, output)

    typer.echo(f"Created {output}/")
    typer.echo("")
    typer.echo("Contents:")
    for f in sorted(output.iterdir()):
        # Hidden entries (dotfiles) are copied but not listed.
        if not f.name.startswith("."):
            typer.echo(f" {f.name}")
    typer.echo("")
    typer.echo("Next steps:")
    typer.echo(" 1. Read reference.py to understand the kernel interface")
    typer.echo(" 2. Create kernel.py with your custom_kernel implementation:")
    typer.echo("")
    typer.echo(" def custom_kernel(data):")
    typer.echo(" # Your optimized implementation")
    typer.echo(" ...")
    typer.echo("")
    typer.echo(" 3. Run evaluation:")
    typer.echo(
        f" wafer evaluate gpumode --impl {output}/kernel.py --reference {output}/reference.py \\"
    )
    typer.echo(f" --test-cases {output}/test_cases.json --target <target>")
2066
+
2067
+
2068
@gpumode_app.callback(invoke_without_command=True)
def gpumode_evaluate(  # noqa: PLR0913
    ctx: typer.Context,
    implementation: Path | None = typer.Option(
        None, "--impl", "-i", help="Path to implementation kernel file"
    ),
    reference: Path | None = typer.Option(
        None, "--reference", help="Path to reference kernel file"
    ),
    test_cases: Path | None = typer.Option(
        None, "--test-cases", help="Path to test cases JSON file"
    ),
    target: str | None = typer.Option(
        None,
        "--target",
        "-t",
        help="GPU target name. See 'wafer config targets list' for available targets.",
        autocompletion=complete_target_name,
    ),
    benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
    profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
    defensive: bool = typer.Option(
        False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
    ),
    sync_artifacts: bool = typer.Option(
        True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
    ),
    gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
) -> None:
    """Run kernel evaluation in GPUMode format (functional).

    This format expects:
    - Implementation: Python file with `custom_kernel(inputs)` function
    - Reference: Python file with `ref_kernel(inputs)` and `generate_input(**kwargs)` functions
    - Test cases: JSON file with test parameters

    Examples:
        # Basic correctness check
        wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json

        # With benchmarking
        wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
            --target vultr-b200 --benchmark

    Subcommands:
        download Download GPUMode problems from GitHub
        list-problems List available problems
        make-template Extract a problem as template files
    """
    # This callback also fires before every subcommand; only the bare
    # `wafer evaluate gpumode` invocation runs an evaluation.
    if ctx.invoked_subcommand is not None:
        return

    # The three paths are optional at the Typer level (so subcommands work
    # without them) and therefore validated by hand here.
    required_options = (
        ("--impl/-i", implementation),
        ("--reference", reference),
        ("--test-cases", test_cases),
    )
    missing_args = [flag for flag, value in required_options if value is None]

    if missing_args:
        usage_lines = (
            "Error: Missing required arguments",
            f" Required: {', '.join(missing_args)}",
            "",
            "Usage: wafer evaluate gpumode --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
            "",
            "Run 'wafer evaluate gpumode --help' for full options.",
            "Run 'wafer evaluate gpumode download' to download problem sets.",
        )
        for line in usage_lines:
            typer.echo(line, err=True)
        raise typer.Exit(1)

    # Same evaluation pipeline as plain `wafer evaluate`.
    from .evaluate import EvaluateArgs, run_evaluate

    args = EvaluateArgs(
        implementation=implementation,
        reference=reference,
        test_cases=test_cases,
        target_name=target or "",
        benchmark=benchmark,
        profile=profile,
        defensive=defensive,
        sync_artifacts=sync_artifacts,
        gpu_id=gpu_id,
    )

    try:
        import trio_asyncio

        result = trio_asyncio.run(run_evaluate, args)
    except KeyboardInterrupt:
        typer.echo("\nInterrupted by user", err=True)
        raise typer.Exit(130) from None
    except Exception as e:
        # Exceptions exposing a non-empty `.exceptions` are reported one
        # line per member; everything else is reported as a single line.
        members = getattr(e, "exceptions", None)
        if members:
            for exc in members:
                typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
        else:
            typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    # The evaluation itself could not complete.
    if not result.success:
        typer.echo(f"Error: {result.error_message}", err=True)
        raise typer.Exit(1)

    # Summary banner.
    banner = "=" * 60
    typer.echo("")
    typer.echo(banner)
    verdict = "PASS" if result.all_correct else "FAIL"
    typer.echo(f"Result: {verdict}")
    score_pct = f"{result.correctness_score:.1%}"
    typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
    if result.geomean_speedup > 0:
        typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
    if result.artifact_path:
        typer.echo(f"Artifacts: {result.artifact_path}")
    typer.echo(banner)

    # An incorrect kernel exits with status 1 even though the run completed.
    if not result.all_correct:
        raise typer.Exit(1)
2192
+
2193
+
1645
2194
  # =============================================================================
1646
2195
  # Push and Remote-Run commands
1647
2196
  # =============================================================================
@@ -1773,7 +2322,7 @@ def _run_direct_mode(
1773
2322
  typer.echo(f"Uploading {upload_dir.name}...")
1774
2323
  try:
1775
2324
  push_result = push_direct(upload_dir, target)
1776
- workspace_name = push_result.workspace_path
2325
+ workspace_name = push_result.workspace_name
1777
2326
  typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
1778
2327
  except Exception as e:
1779
2328
  typer.echo(f"Error uploading: {e}", err=True)
@@ -1945,17 +2494,34 @@ def login(
1945
2494
  token: str | None = typer.Option(
1946
2495
  None, "--token", "-t", help="Access token (skip browser OAuth)"
1947
2496
  ),
2497
+ port: int | None = typer.Option(
2498
+ None,
2499
+ "--port",
2500
+ "-p",
2501
+ help="Port for OAuth callback server (default: 8765 for SSH, random for local)",
2502
+ ),
1948
2503
  ) -> None:
1949
2504
  """Authenticate CLI with wafer-api via GitHub OAuth.
1950
2505
 
1951
2506
  Opens browser for GitHub authentication. Use --token to skip browser.
1952
2507
  Uses the API environment from config (see 'wafer config show').
1953
2508
 
2509
+ SSH Users:
2510
+ - Automatically uses port 8765 (just set up port forwarding once)
2511
+ - On local machine: ssh -L 8765:localhost:8765 user@host
2512
+ - On remote machine: wafer login
2513
+ - Browser opens locally, redirect works through tunnel
2514
+
2515
+ Manual token option:
2516
+ - Visit auth.wafer.ai, authenticate, copy token from URL
2517
+ - Run: wafer login --token <paste-token>
2518
+
1954
2519
  Examples:
1955
- wafer login # opens browser for GitHub OAuth
1956
- wafer login --token xyz # use existing token
2520
+ wafer login # auto-detects SSH, uses appropriate port
2521
+ wafer login --port 9000 # override port
2522
+ wafer login --token xyz # manual token (no browser)
1957
2523
 
1958
- # To login to a different environment:
2524
+ # Change environment:
1959
2525
  wafer config set api.environment staging
1960
2526
  wafer login
1961
2527
  """
@@ -1971,11 +2537,21 @@ def login(
1971
2537
  typer.echo(f"Auth: {get_supabase_url()}")
1972
2538
  typer.echo("")
1973
2539
 
2540
+ # Auto-detect SSH and use fixed port
2541
+ if port is None:
2542
+ is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
2543
+ if is_ssh:
2544
+ port = 8765
2545
+ typer.echo("🔒 SSH session detected - using port 8765 for OAuth callback")
2546
+ typer.echo(" Make sure you have port forwarding set up:")
2547
+ typer.echo(" ssh -L 8765:localhost:8765 user@host")
2548
+ typer.echo("")
2549
+
1974
2550
  # Browser OAuth if no token provided
1975
2551
  refresh_token = None
1976
2552
  if token is None:
1977
2553
  try:
1978
- token, refresh_token = browser_login()
2554
+ token, refresh_token = browser_login(port=port)
1979
2555
  except TimeoutError as e:
1980
2556
  typer.echo(f"Error: {e}", err=True)
1981
2557
  raise typer.Exit(1) from None
@@ -2009,6 +2585,11 @@ def login(
2009
2585
  # Save credentials (with refresh token if available)
2010
2586
  save_credentials(token, refresh_token, user_info.email)
2011
2587
 
2588
+ # Track login event with analytics
2589
+ from . import analytics
2590
+
2591
+ analytics.track_login(user_info.user_id, user_info.email)
2592
+
2012
2593
  if user_info.email:
2013
2594
  typer.echo(f"Logged in as {user_info.email}")
2014
2595
  else:
@@ -2019,8 +2600,14 @@ def login(
2019
2600
  @app.command("logout")
2020
2601
  def logout() -> None:
2021
2602
  """Remove stored credentials."""
2603
+ from . import analytics
2022
2604
  from .auth import clear_credentials
2023
2605
 
2606
+ # Track logout event first (while credentials still exist for user identification)
2607
+ # Note: track_logout() handles the case where user is not logged in
2608
+ analytics.track_logout()
2609
+
2610
+ # Clear credentials and report result
2024
2611
  if clear_credentials():
2025
2612
  typer.echo("Logged out. Credentials removed.")
2026
2613
  else:
@@ -2985,7 +3572,9 @@ def billing_usage(
2985
3572
  @billing_app.command("topup")
2986
3573
  def billing_topup(
2987
3574
  amount: int = typer.Argument(25, help="Amount in dollars ($10-$500)"),
2988
- no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
3575
+ no_browser: bool = typer.Option(
3576
+ False, "--no-browser", help="Print URL instead of opening browser"
3577
+ ),
2989
3578
  ) -> None:
2990
3579
  """Add credits to your account.
2991
3580
 
@@ -3031,7 +3620,9 @@ def billing_topup(
3031
3620
 
3032
3621
  @billing_app.command("portal")
3033
3622
  def billing_portal(
3034
- no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
3623
+ no_browser: bool = typer.Option(
3624
+ False, "--no-browser", help="Print URL instead of opening browser"
3625
+ ),
3035
3626
  ) -> None:
3036
3627
  """Open Stripe billing portal.
3037
3628
 
@@ -4298,6 +4889,29 @@ autotuner_app = typer.Typer(help="Hyperparameter sweep for performance engineeri
4298
4889
  app.add_typer(autotuner_app, name="autotuner", hidden=True)
4299
4890
 
4300
4891
 
4892
def _setup_wafer_core_env() -> None:
    """Export the API URL and auth token into ``os.environ`` for wafer-core.

    Intended to run before any wafer-core call that needs API access.

    - ``WAFER_API_URL`` is always (re)written with the value of
      ``get_api_url()``. According to the original author's note that helper
      already honours a pre-set ``WAFER_API_URL``.
    - ``WAFER_AUTH_TOKEN`` is left untouched when already present, so CI or
      service accounts can supply their own token. Otherwise it is set from
      ``get_valid_token()`` when that returns a truthy token.
    """
    from .auth import get_valid_token
    from .global_config import get_api_url

    os.environ["WAFER_API_URL"] = get_api_url()

    # An explicitly provided token always wins over the cached one.
    if "WAFER_AUTH_TOKEN" in os.environ:
        return

    cached_token = get_valid_token()
    if cached_token:
        os.environ["WAFER_AUTH_TOKEN"] = cached_token
+
4914
+
4301
4915
  @autotuner_app.command("list")
4302
4916
  def autotuner_list(
4303
4917
  show_all: bool = typer.Option(
@@ -4313,6 +4927,7 @@ def autotuner_list(
4313
4927
  wafer autotuner list
4314
4928
  wafer autotuner list --all
4315
4929
  """
4930
+ _setup_wafer_core_env()
4316
4931
  from .autotuner import list_command
4317
4932
 
4318
4933
  try:
@@ -4353,6 +4968,7 @@ def autotuner_delete(
4353
4968
  wafer autotuner delete --all --status pending
4354
4969
  wafer autotuner delete --all --status failed --yes
4355
4970
  """
4971
+ _setup_wafer_core_env()
4356
4972
  from .autotuner import delete_all_command, delete_command
4357
4973
 
4358
4974
  # Validate arguments
@@ -4419,6 +5035,7 @@ def autotuner_run(
4419
5035
  wafer autotuner run --resume <sweep-id>
4420
5036
  wafer autotuner run --resume <sweep-id> --parallel 8
4421
5037
  """
5038
+ _setup_wafer_core_env()
4422
5039
  from .autotuner import run_sweep_command
4423
5040
 
4424
5041
  # Validate arguments
@@ -4583,8 +5200,23 @@ def capture_command( # noqa: PLR0915
4583
5200
  wafer capture grid-search "python train.py --lr {LR} --bs {BS}" --sweep "LR=0.001,0.01,0.1" --sweep "BS=16,32"
4584
5201
  """
4585
5202
  import itertools
5203
+ import os
4586
5204
  import tomllib
4587
5205
 
5206
+ from .auth import get_valid_token
5207
+ from .global_config import get_api_url
5208
+
5209
+ # Set environment variables for wafer-core BEFORE importing it
5210
+ # wafer-core backend.py reads WAFER_API_URL and WAFER_AUTH_TOKEN from env
5211
+ os.environ["WAFER_API_URL"] = get_api_url()
5212
+
5213
+ # Only set auth token if not explicitly provided in environment
5214
+ # This allows CI/service accounts to override with their own tokens
5215
+ if "WAFER_AUTH_TOKEN" not in os.environ:
5216
+ token = get_valid_token()
5217
+ if token:
5218
+ os.environ["WAFER_AUTH_TOKEN"] = token
5219
+
4588
5220
  import trio
4589
5221
  from wafer_core.tools.capture_tool import ( # pragma: no cover
4590
5222
  CaptureConfig,
@@ -4774,6 +5406,20 @@ def capture_list_command(
4774
5406
  # Pagination
4775
5407
  wafer capture-list --limit 20 --offset 20
4776
5408
  """
5409
+ import os
5410
+
5411
+ from .auth import get_valid_token
5412
+ from .global_config import get_api_url
5413
+
5414
+ # Set environment variables for wafer-core BEFORE importing it
5415
+ os.environ["WAFER_API_URL"] = get_api_url()
5416
+
5417
+ # Only set auth token if not explicitly provided in environment
5418
+ # This allows CI/service accounts to override with their own tokens
5419
+ if "WAFER_AUTH_TOKEN" not in os.environ:
5420
+ token = get_valid_token()
5421
+ if token:
5422
+ os.environ["WAFER_AUTH_TOKEN"] = token
4777
5423
 
4778
5424
  import trio
4779
5425
  from wafer_core.utils.backend import list_captures # pragma: no cover