wafer-cli 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/GUIDE.md ADDED
@@ -0,0 +1,118 @@
1
+ # Wafer CLI Guide
2
+
3
+ GPU development primitives for LLM agents.
4
+
5
+ ## Quick Start: Cloud GPU (No Setup)
6
+
7
+ Run code on cloud GPUs instantly with workspaces:
8
+
9
+ ```bash
10
+ wafer login # One-time auth
11
+ wafer workspaces create dev --gpu B200 # Create workspace (NVIDIA B200)
12
+ wafer workspaces exec dev -- python -c "import torch; print(torch.cuda.get_device_name(0))"
13
+ wafer workspaces sync dev ./my-project # Sync files
14
+ wafer workspaces exec dev -- python train.py
15
+ ```
16
+
17
+ **Available GPUs:**
18
+
19
+ - `MI300X` - AMD Instinct MI300X (192GB HBM3, ROCm)
20
+ - `B200` - NVIDIA Blackwell B200 (180GB HBM3e, CUDA) - default
21
+
22
+ ## Documentation Lookup
23
+
24
+ Answer GPU programming questions from indexed documentation.
25
+
26
+ ```bash
27
+ # Download corpus (one-time)
28
+ wafer corpus download cuda
29
+ wafer corpus download cutlass
30
+ wafer corpus download hip
31
+
32
+ # Query documentation
33
+ wafer agent -t ask-docs --corpus cuda "What is warp divergence?"
34
+ wafer agent -t ask-docs --corpus cutlass "What is a TiledMma?"
35
+ ```
36
+
37
+ ## Trace Analysis
38
+
39
+ Analyze performance traces from NCU, NSYS, or PyTorch profiler.
40
+
41
+ ```bash
42
+ # AI-assisted analysis
43
+ wafer agent -t trace-analyze --args trace=./profile.ncu-rep "Why is this kernel slow?"
44
+ wafer agent -t trace-analyze --args trace=./trace.json "What's the bottleneck?"
45
+
46
+ # Direct trace queries (PyTorch/Perfetto JSON)
47
+ wafer nvidia perfetto tables trace.json
48
+ wafer nvidia perfetto query trace.json \
49
+ "SELECT name, dur/1e6 as ms FROM slice WHERE cat='kernel' ORDER BY dur DESC LIMIT 10"
50
+
51
+ # NCU/NSYS analysis
52
+ wafer nvidia ncu analyze profile.ncu-rep
53
+ wafer nvidia nsys analyze profile.nsys-rep
54
+ ```
55
+
56
+ ## Kernel Evaluation
57
+
58
+ Test kernel correctness and measure speedup against a reference.
59
+
60
+ ```bash
61
+ # Using workspaces (no target setup required):
62
+ wafer workspaces create dev --gpu B200
63
+ wafer workspaces exec --sync ./my-kernel dev -- python test_kernel.py
64
+
65
+ # Or using configured targets (for your own hardware):
66
+ wafer evaluate make-template ./my-kernel
67
+ wafer evaluate \
68
+ --impl ./my-kernel/kernel.py \
69
+ --reference ./my-kernel/reference.py \
70
+ --test-cases ./my-kernel/test_cases.json \
71
+ --target <target-name>
72
+ ```
73
+
74
+ For target setup, see `wafer config targets --help`.
75
+
76
+ ## Kernel Optimization (AI-assisted)
77
+
78
+ Iteratively optimize a kernel with evaluation feedback.
79
+
80
+ ```bash
81
+ wafer agent -t optimize-kernel \
82
+ --args kernel=./my_kernel.cu \
83
+ --args target=H100 \
84
+ "Optimize this GEMM for memory bandwidth"
85
+ ```
86
+
87
+ ## Workspaces
88
+
89
+ Cloud GPU environments with no setup required.
90
+
91
+ **Available GPUs:**
92
+
93
+ - `MI300X` - AMD Instinct MI300X (192GB HBM3, ROCm)
94
+ - `B200` - NVIDIA Blackwell B200 (180GB HBM3e, CUDA) - default
95
+
96
+ ```bash
97
+ wafer workspaces create dev --gpu B200 --wait # NVIDIA B200
98
+ wafer workspaces create amd-dev --gpu MI300X # AMD MI300X
99
+ wafer workspaces list # List all
100
+ wafer workspaces sync dev ./project # Sync files
101
+ wafer workspaces exec dev -- ./run.sh # Run commands
102
+ wafer workspaces ssh dev # Interactive SSH
103
+ wafer workspaces delete dev # Cleanup
104
+ ```
105
+
106
+ See `wafer workspaces --help` for details.
107
+
108
+ ## Command Reference
109
+
110
+ ```bash
111
+ wafer corpus list|download|path # Manage documentation corpora
112
+ wafer workspaces # Cloud GPU environments (no setup)
113
+ wafer evaluate # Test kernel correctness/performance
114
+ wafer nvidia ncu|nsys|perfetto # NVIDIA profiling tools
115
+ wafer amd isa|rocprof-compute # AMD profiling tools
116
+ wafer agent -t <template> # AI-assisted workflows
117
+ wafer config targets # Configure your own GPU targets
118
+ ```
wafer/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """Wafer - Standalone tool for Docker execution on remote GPUs."""
2
+
3
+ __version__ = "0.1.0"
wafer/analytics.py ADDED
@@ -0,0 +1,306 @@
1
+ """PostHog analytics for Wafer CLI.
2
+
3
+ Tracks CLI command usage and user activity for product analytics.
4
+ Mirrors the analytics implementation in apps/wevin-extension/src/services/analytics.ts.
5
+
6
+ Usage:
7
+ from .analytics import track_command, identify_user, shutdown_analytics
8
+
9
+ # Track a command execution
10
+ track_command("evaluate", {"subcommand": "kernelbench", "outcome": "success"})
11
+
12
+ # Identify user after login
13
+ identify_user("user-id", "user@example.com")
14
+ """
15
+
16
+ import atexit
17
+ import platform
18
+ import uuid
19
+ from pathlib import Path
20
+ from typing import Any
21
+
22
+ # PostHog configuration - same as wevin-extension
23
+ POSTHOG_API_KEY = "phc_9eDjkY72ud9o4l1mA1Gr1dnRT1yx71rP3XY9z66teFh"
24
+ POSTHOG_HOST = "https://us.i.posthog.com"
25
+
26
+ # Anonymous ID storage
27
+ ANONYMOUS_ID_FILE = Path.home() / ".wafer" / ".analytics_id"
28
+
29
+ # Global state
30
+ _posthog_client: Any = None
31
+ _distinct_id: str | None = None
32
+ _initialized: bool = False
33
+
34
+
35
def _get_anonymous_id() -> str:
    """Get or create a persistent anonymous ID for users who aren't logged in.

    The ID is cached in ``~/.wafer/.analytics_id``. An empty, whitespace-only,
    or unreadable file (e.g. left by an interrupted write) is treated as
    missing and the ID is regenerated, instead of returning "" or raising.

    Returns:
        A stable identifier of the form ``anon_<32 hex chars>``.
    """
    if ANONYMOUS_ID_FILE.exists():
        try:
            cached = ANONYMOUS_ID_FILE.read_text().strip()
        except OSError:
            # Unreadable file: fall through and regenerate rather than crash.
            cached = ""
        if cached:
            return cached

    # Generate (or regenerate) a new anonymous ID and persist it.
    anonymous_id = f"anon_{uuid.uuid4().hex}"
    ANONYMOUS_ID_FILE.parent.mkdir(parents=True, exist_ok=True)
    ANONYMOUS_ID_FILE.write_text(anonymous_id)
    return anonymous_id
45
+
46
+
47
def _get_user_id_from_credentials() -> tuple[str | None, str | None]:
    """Resolve (user_id, email) from stored credentials.

    Either element may be None: both when the user is not logged in, and
    the user_id alone when the access token cannot be verified.
    """
    # Local import: keeps module import-time light and avoids a circular
    # import with the auth module.
    from .auth import load_credentials, verify_token

    stored = load_credentials()
    if not stored:
        return None, None

    try:
        info = verify_token(stored.access_token)
    except Exception:
        # Token verification failed; fall back to the stored email only.
        return None, stored.email
    return info.user_id, info.email or stored.email
67
+
68
+
69
def _is_analytics_enabled() -> bool:
    """Report whether the user has analytics enabled in preferences.

    Defaults to True when the preference is absent or unreadable.
    """
    from .global_config import get_preferences

    try:
        return getattr(get_preferences(), "analytics_enabled", True)
    except Exception:
        # Unreadable preferences must not disable (or break) the CLI.
        return True
82
+
83
+
84
def init_analytics() -> bool:
    """Initialize PostHog client.

    Idempotent: the first call records its outcome via the module-level
    ``_initialized`` flag, and every later call just reports whether a
    client exists. All failures are swallowed so analytics can never
    break the CLI.

    Returns:
        True if initialization succeeded, False otherwise.
    """
    global _posthog_client, _distinct_id, _initialized

    # Already attempted once — report the cached outcome.
    if _initialized:
        return _posthog_client is not None

    _initialized = True

    # Check if analytics is enabled
    if not _is_analytics_enabled():
        return False

    try:
        # Imported lazily: posthog is an optional dependency (see the
        # ImportError handler below).
        from posthog import Posthog

        _posthog_client = Posthog(
            api_key=POSTHOG_API_KEY,
            host=POSTHOG_HOST,
            # Flush immediately for CLI - commands are short-lived
            flush_at=1,
            flush_interval=1,
            # Disable debug logging
            debug=False,
        )

        # Set up distinct ID - prefer authenticated user, fall back to anonymous
        user_id, email = _get_user_id_from_credentials()
        if user_id:
            _distinct_id = user_id
            # Identify the user with their email
            if email:
                _posthog_client.identify(
                    distinct_id=user_id,
                    properties={
                        "email": email,
                        "auth_provider": "github",
                    },
                )
        else:
            _distinct_id = _get_anonymous_id()

        # Register shutdown handler to flush events
        atexit.register(shutdown_analytics)

        return True

    except ImportError:
        # PostHog not installed - analytics disabled
        return False
    except Exception:
        # Any other error - fail silently, don't break CLI
        return False
141
+
142
+
143
def shutdown_analytics() -> None:
    """Flush any queued events and dispose of the PostHog client."""
    global _posthog_client

    client = _posthog_client
    if client is None:
        return

    try:
        client.flush()
        client.shutdown()
    except Exception:
        # Teardown must never raise, even if the client is half-broken.
        pass
    _posthog_client = None
154
+
155
+
156
def identify_user(user_id: str, email: str | None = None) -> None:
    """Associate subsequent analytics events with a logged-in user.

    Args:
        user_id: Supabase user ID
        email: User's email address
    """
    global _distinct_id

    # Bail out quietly when analytics is disabled or failed to initialize.
    if not init_analytics() or _posthog_client is None:
        return

    _distinct_id = user_id

    props: dict[str, Any] = {"auth_provider": "github"}
    if email:
        props["email"] = email

    try:
        _posthog_client.identify(
            distinct_id=user_id,
            properties=props,
        )
        _posthog_client.flush()
    except Exception:
        # Analytics must never break the CLI.
        pass
185
+
186
+
187
def reset_user_identity() -> None:
    """Drop the logged-in identity and revert to the anonymous ID (logout)."""
    global _distinct_id
    _distinct_id = _get_anonymous_id()
192
+
193
+
194
def get_distinct_id() -> str:
    """Return the distinct ID used to attribute events.

    Lazily resolves to the authenticated user ID when credentials are
    available, otherwise to the persisted anonymous ID; the result is
    cached in the module-level ``_distinct_id``.
    """
    global _distinct_id

    if _distinct_id is None:
        uid, _ = _get_user_id_from_credentials()
        _distinct_id = uid if uid else _get_anonymous_id()

    return _distinct_id
203
+
204
+
205
+ def _get_cli_version() -> str:
206
+ """Get CLI version from package metadata."""
207
+ try:
208
+ from importlib.metadata import version
209
+
210
+ return version("wafer-cli")
211
+ except Exception:
212
+ return "unknown"
213
+
214
+
215
def _get_base_properties() -> dict[str, Any]:
    """Build the property set attached to every analytics event."""
    base: dict[str, Any] = {
        "platform": "cli",
        "tool_id": "cli",
        "cli_version": _get_cli_version(),
    }
    base["os"] = platform.system().lower()
    base["os_version"] = platform.release()
    base["python_version"] = platform.python_version()
    return base
225
+
226
+
227
def track_event(event_name: str, properties: dict[str, Any] | None = None) -> None:
    """Send one event to PostHog, silently doing nothing on any failure.

    Args:
        event_name: Name of the event to track
        properties: Additional properties merged over the base properties
    """
    if not init_analytics() or _posthog_client is None:
        return

    try:
        payload = _get_base_properties()
        payload.update(properties or {})

        _posthog_client.capture(
            distinct_id=get_distinct_id(),
            event=event_name,
            properties=payload,
        )
    except Exception:
        # Analytics must never break the CLI.
        pass
252
+
253
+
254
def track_command(
    command: str,
    subcommand: str | None = None,
    outcome: str = "success",
    duration_ms: int | None = None,
    properties: dict[str, Any] | None = None,
) -> None:
    """Record a CLI command execution.

    This event counts towards DAU in the internal dashboard.

    Args:
        command: The main command name (e.g., "evaluate", "agent")
        subcommand: Optional subcommand (e.g., "kernelbench")
        outcome: "success" or "error"
        duration_ms: Command execution time in milliseconds
        properties: Additional properties to include
    """
    payload: dict[str, Any] = {"command": command, "outcome": outcome}

    if subcommand:
        payload["subcommand"] = subcommand
    if duration_ms is not None:
        payload["duration_ms"] = duration_ms
    payload.update(properties or {})

    track_event("cli_command_executed", payload)
287
+
288
+
289
def track_login(user_id: str, email: str | None = None) -> None:
    """Record a sign-in: identify the user, then emit the login event.

    Args:
        user_id: Supabase user ID
        email: User's email address
    """
    # Identification must precede the event so it is attributed correctly.
    identify_user(user_id, email)
    track_event("cli_user_signed_in", {"user_id": user_id})
301
+
302
+
303
def track_logout() -> None:
    """Record a sign-out event, then revert to the anonymous identity."""
    # Emit first so the event is still attributed to the signed-in user.
    track_event("cli_user_signed_out")
    reset_user_identity()
wafer/api_client.py ADDED
@@ -0,0 +1,195 @@
1
+ """Wafer API client for remote GPU operations.
2
+
3
+ Thin client that calls wafer-api endpoints instead of direct SSH.
4
+ """
5
+
6
+ import base64
7
+ import sys
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+
11
+ import httpx
12
+
13
+ from .global_config import get_api_url # noqa: F401 - re-exported for backwards compat
14
+
15
+
16
@dataclass(frozen=True)
class PushResult:
    """Result of pushing files to GPU.

    Attributes:
        workspace_id: Server-assigned ID of the workspace the files landed in.
        workspace_path: Remote filesystem path of that workspace.
        files_uploaded: Relative paths of the files that were uploaded.
    """

    workspace_id: str
    workspace_path: str
    files_uploaded: list[str]
23
+
24
+
25
@dataclass(frozen=True)
class ApiConfig:
    """API client configuration.

    NOTE(review): the module-level helpers below use get_api_url() and their
    own timeouts rather than this class — confirm whether ApiConfig is still
    consumed anywhere.
    """

    base_url: str = "http://localhost:8000"  # Only used if ApiConfig is instantiated directly
    timeout: float = 60.0
31
+
32
+
33
def _get_auth_headers() -> dict[str, str]:
    """Return auth headers derived from stored credentials.

    The auth module is imported lazily to avoid a circular import.
    """
    from .auth import get_auth_headers as _headers

    return _headers()
38
+
39
+
40
def push_directory(local_path: Path, workspace_name: str | None = None) -> PushResult:
    """Push local directory to GPU via wafer-api.

    Args:
        local_path: Local directory to upload
        workspace_name: Optional workspace name (defaults to directory name)

    Returns:
        PushResult with workspace_id and uploaded files

    Raises:
        FileNotFoundError: If local_path doesn't exist
        ValueError: If local_path is not a directory
        httpx.HTTPError: If API request fails
    """
    if not local_path.exists():
        raise FileNotFoundError(f"Path not found: {local_path}")
    if not local_path.is_dir():
        raise ValueError(f"Not a directory: {local_path}")

    # Reuse the shared collector instead of duplicating the walk/encode loop
    # (previously this function carried its own copy of _collect_files).
    request_body = {
        "files": _collect_files(local_path),
        "workspace_name": workspace_name or local_path.name,
    }

    # Call API
    api_url = get_api_url()
    headers = _get_auth_headers()
    with httpx.Client(timeout=60.0, headers=headers) as client:
        response = client.post(f"{api_url}/v1/gpu/push", json=request_body)
        response.raise_for_status()
        data = response.json()

    return PushResult(
        workspace_id=data["workspace_id"],
        workspace_path=data["workspace_path"],
        files_uploaded=data["files_uploaded"],
    )
90
+
91
+
92
+ def _collect_files(local_path: Path) -> list[dict]:
93
+ """Collect files from directory as base64-encoded dicts."""
94
+ files = []
95
+ for file_path in local_path.rglob("*"):
96
+ if file_path.is_file():
97
+ relative_path = file_path.relative_to(local_path)
98
+ content = file_path.read_bytes()
99
+ files.append({
100
+ "path": str(relative_path),
101
+ "content": base64.b64encode(content).decode(),
102
+ })
103
+ return files
104
+
105
+
106
def run_command_stream(
    command: str,
    upload_dir: Path | None = None,
    workspace_id: str | None = None,
    gpu_id: int | None = None,
    gpu_count: int = 1,
    docker_image: str | None = None,
    docker_entrypoint: str | None = None,
    pull_image: bool = False,
    require_hardware_counters: bool = False,
    target: str | None = None,
) -> int:
    """Run command on GPU via wafer-api, streaming output.

    Two modes (mutually exclusive):
    - upload_dir: Upload files and run (stateless, high-level)
    - workspace_id: Use existing workspace (low-level)

    Output lines from the job are printed to stdout as they arrive; a
    server-reported error is printed to stderr instead.

    Args:
        command: Command to execute inside container
        upload_dir: Directory to upload (stateless mode)
        workspace_id: Workspace ID from push (low-level mode)
        gpu_id: GPU ID to use (optional)
        gpu_count: Number of GPUs needed (1-8, default 1)
        docker_image: Docker image override (optional)
        docker_entrypoint: Docker entrypoint override (optional, e.g., "bash")
        pull_image: Pull image if not available (optional, default False)
        require_hardware_counters: Require baremetal for ncu profiling (optional)
        target: Target name to use (optional, defaults to user's default)

    Returns:
        Exit code (0 = success, non-zero = failure)

    Raises:
        httpx.HTTPError: If API request fails
    """
    request_body: dict = {
        "command": command,
    }

    # Add files or workspace_id (mutually exclusive)
    # NOTE: if both are given, upload_dir silently wins.
    if upload_dir is not None:
        files = _collect_files(upload_dir)
        request_body["files"] = files
        request_body["workspace_name"] = upload_dir.name
    elif workspace_id is not None:
        request_body["workspace_id"] = workspace_id
    # else: no files, no workspace (run command in temp workspace)

    # Only non-default options are sent; omitted keys fall back to server defaults.
    if gpu_id is not None:
        request_body["gpu_id"] = gpu_id
    if gpu_count > 1:
        request_body["gpu_count"] = gpu_count
    if docker_image is not None:
        request_body["docker_image"] = docker_image
    if docker_entrypoint is not None:
        request_body["docker_entrypoint"] = docker_entrypoint
    if pull_image:
        request_body["pull_image"] = True
    if require_hardware_counters:
        request_body["require_hardware_counters"] = True
    if target is not None:
        request_body["target"] = target

    api_url = get_api_url()
    headers = _get_auth_headers()
    exit_code = 0

    with httpx.Client(timeout=None, headers=headers) as client:  # No timeout for streaming
        with client.stream("POST", f"{api_url}/v1/gpu/jobs", json=request_body) as response:
            response.raise_for_status()

            for line in response.iter_lines():
                if not line:
                    continue

                # Parse SSE format: "data: <content>"
                # Lines without the "data: " prefix are ignored.
                if line.startswith("data: "):
                    content = line[6:]  # Strip "data: " prefix

                    if content == "[DONE]":
                        break
                    elif content.startswith("[ERROR]"):
                        # NOTE(review): the slice assumes the server always
                        # sends "[ERROR] " with a trailing space (8 chars);
                        # a bare "[ERROR]" would print empty — confirm the
                        # server-side format.
                        print(content[8:], file=sys.stderr)  # Strip "[ERROR] "
                        exit_code = 1
                        break
                    else:
                        print(content)

    return exit_code