wafer-cli 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +118 -0
- wafer/__init__.py +3 -0
- wafer/analytics.py +306 -0
- wafer/api_client.py +195 -0
- wafer/auth.py +432 -0
- wafer/autotuner.py +1080 -0
- wafer/billing.py +233 -0
- wafer/cli.py +7289 -0
- wafer/config.py +105 -0
- wafer/corpus.py +366 -0
- wafer/evaluate.py +4593 -0
- wafer/global_config.py +350 -0
- wafer/gpu_run.py +307 -0
- wafer/inference.py +148 -0
- wafer/kernel_scope.py +552 -0
- wafer/ncu_analyze.py +651 -0
- wafer/nsys_analyze.py +1042 -0
- wafer/nsys_profile.py +510 -0
- wafer/output.py +248 -0
- wafer/problems.py +357 -0
- wafer/rocprof_compute.py +490 -0
- wafer/rocprof_sdk.py +274 -0
- wafer/rocprof_systems.py +520 -0
- wafer/skills/wafer-guide/SKILL.md +129 -0
- wafer/ssh_keys.py +261 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +842 -0
- wafer/targets_ops.py +717 -0
- wafer/templates/__init__.py +0 -0
- wafer/templates/ask_docs.py +61 -0
- wafer/templates/optimize_kernel.py +71 -0
- wafer/templates/optimize_kernelbench.py +137 -0
- wafer/templates/trace_analyze.py +74 -0
- wafer/tracelens.py +218 -0
- wafer/wevin_cli.py +577 -0
- wafer/workspaces.py +852 -0
- wafer_cli-0.2.14.dist-info/METADATA +16 -0
- wafer_cli-0.2.14.dist-info/RECORD +41 -0
- wafer_cli-0.2.14.dist-info/WHEEL +5 -0
- wafer_cli-0.2.14.dist-info/entry_points.txt +2 -0
- wafer_cli-0.2.14.dist-info/top_level.txt +1 -0
wafer/auth.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
1
|
+
"""CLI authentication and credential management.
|
|
2
|
+
|
|
3
|
+
Handles storing/loading credentials from ~/.wafer/credentials.json
|
|
4
|
+
and verifying tokens against the wafer-api.
|
|
5
|
+
|
|
6
|
+
Supports automatic token refresh using Supabase refresh tokens.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
import os
import socket
import sys
import tempfile
import time
import webbrowser
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from urllib.parse import parse_qs, urlencode, urlparse

import httpx

from .global_config import get_api_url, get_supabase_url
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _safe_symbol(unicode_sym: str, ascii_fallback: str) -> str:
|
|
25
|
+
"""Return unicode symbol if terminal supports it, otherwise ASCII fallback."""
|
|
26
|
+
# Check if stdout can handle UTF-8
|
|
27
|
+
if not sys.stdout.isatty():
|
|
28
|
+
return ascii_fallback
|
|
29
|
+
try:
|
|
30
|
+
encoding = sys.stdout.encoding or "ascii"
|
|
31
|
+
unicode_sym.encode(encoding)
|
|
32
|
+
return unicode_sym
|
|
33
|
+
except (UnicodeEncodeError, LookupError):
|
|
34
|
+
return ascii_fallback
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Terminal-safe status markers: Unicode on capable TTYs, ASCII otherwise.
CHECK, CROSS = _safe_symbol("✓", "[OK]"), _safe_symbol("✗", "[FAIL]")

# Where the CLI persists its login credentials.
CREDENTIALS_DIR = Path.home().joinpath(".wafer")
CREDENTIALS_FILE = CREDENTIALS_DIR.joinpath("credentials.json")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
|
|
46
|
+
class Credentials:
|
|
47
|
+
"""Stored credentials."""
|
|
48
|
+
|
|
49
|
+
access_token: str
|
|
50
|
+
refresh_token: str | None = None
|
|
51
|
+
email: str | None = None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
|
|
55
|
+
class UserInfo:
|
|
56
|
+
"""User info from token verification."""
|
|
57
|
+
|
|
58
|
+
user_id: str
|
|
59
|
+
email: str | None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def save_credentials(
|
|
63
|
+
access_token: str,
|
|
64
|
+
refresh_token: str | None = None,
|
|
65
|
+
email: str | None = None,
|
|
66
|
+
) -> None:
|
|
67
|
+
"""Save credentials to ~/.wafer/credentials.json."""
|
|
68
|
+
CREDENTIALS_DIR.mkdir(parents=True, exist_ok=True)
|
|
69
|
+
data = {"access_token": access_token}
|
|
70
|
+
if refresh_token:
|
|
71
|
+
data["refresh_token"] = refresh_token
|
|
72
|
+
if email:
|
|
73
|
+
data["email"] = email
|
|
74
|
+
CREDENTIALS_FILE.write_text(json.dumps(data, indent=2))
|
|
75
|
+
# Set restrictive permissions (owner read/write only)
|
|
76
|
+
CREDENTIALS_FILE.chmod(0o600)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def load_credentials() -> Credentials | None:
    """Load credentials from ~/.wafer/credentials.json.

    Returns:
        The stored ``Credentials``, or None if the file doesn't exist, isn't
        valid UTF-8 JSON, isn't a JSON object, or lacks a string
        ``access_token``. Other I/O errors (e.g. permission denied) still
        propagate as ``OSError``.
    """
    try:
        # EAFP: reading directly avoids an exists()/read race if the file
        # is removed concurrently (e.g. by a parallel logout).
        raw = CREDENTIALS_FILE.read_text(encoding="utf-8")
    except (FileNotFoundError, UnicodeDecodeError):
        return None

    try:
        data = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None

    # A hand-edited or corrupted file may hold a list, string or null.
    if not isinstance(data, dict):
        return None
    access_token = data.get("access_token")
    if not isinstance(access_token, str):
        return None

    return Credentials(
        access_token=access_token,
        refresh_token=data.get("refresh_token"),
        email=data.get("email"),
    )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def clear_credentials() -> bool:
    """Remove the credentials file.

    Returns:
        True if the file was removed, False if it didn't exist.
    """
    try:
        CREDENTIALS_FILE.unlink()
    except FileNotFoundError:
        # Already gone, possibly removed concurrently; nothing to clear.
        return False
    return True
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_auth_headers() -> dict[str, str]:
    """Build the ``Authorization`` header for the currently logged-in user.

    The token comes from ``get_valid_token``, which transparently refreshes
    an expired access token when a refresh token is stored. An empty dict is
    returned when no usable token is available. Exceptions raised by
    ``get_valid_token`` (non-401 HTTP errors, network failures) are not
    caught here.
    """
    token = get_valid_token()
    return {"Authorization": f"Bearer {token}"} if token else {}
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def verify_token(token: str) -> UserInfo:
    """Ask wafer-api whether *token* is valid and return the associated user.

    Args:
        token: Access token to check; it is sent in the JSON request body.

    Returns:
        UserInfo built from the ``user_id`` and (optional) ``email`` fields
        of the response.

    Raises:
        httpx.HTTPStatusError: If token is invalid (401) or other HTTP error
        httpx.RequestError: If API is unreachable
    """
    endpoint = f"{get_api_url()}/v1/auth/verify"
    with httpx.Client(timeout=10.0) as client:
        resp = client.post(endpoint, json={"token": token})
        resp.raise_for_status()
        payload = resp.json()
    return UserInfo(user_id=payload["user_id"], email=payload.get("email"))
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def refresh_access_token(refresh_token: str) -> tuple[str, str]:
    """Exchange a Supabase refresh token for a fresh token pair.

    Args:
        refresh_token: Refresh token obtained from an earlier login or refresh.

    Returns:
        ``(new_access_token, new_refresh_token)`` taken from the Supabase
        response.

    Raises:
        httpx.HTTPStatusError: If refresh fails (e.g., refresh token expired)
        httpx.RequestError: If Supabase is unreachable
    """
    from .global_config import get_supabase_anon_key

    token_endpoint = f"{get_supabase_url()}/auth/v1/token?grant_type=refresh_token"
    request_headers = {
        "Content-Type": "application/json",
        "apikey": get_supabase_anon_key(),
    }
    with httpx.Client(timeout=10.0) as client:
        resp = client.post(
            token_endpoint,
            json={"refresh_token": refresh_token},
            headers=request_headers,
        )
        resp.raise_for_status()
        body = resp.json()
    return body["access_token"], body["refresh_token"]
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def get_valid_token() -> str | None:
    """Return an access token that wafer-api currently accepts.

    The stored access token is checked with ``verify_token``. If the API
    answers 401 and a refresh token is stored, new tokens are obtained via
    ``refresh_access_token`` and written back with ``save_credentials``.

    Returns:
        A usable access token, or None when no credentials are stored, the
        token was rejected and no refresh token exists, or the refresh
        request was rejected.

    Raises:
        httpx.HTTPStatusError: Verification failed with a status other than 401.
        httpx.RequestError: wafer-api or Supabase could not be reached.
    """
    creds = load_credentials()
    if creds is None:
        return None

    try:
        verify_token(creds.access_token)
    except httpx.HTTPStatusError as exc:
        # Only "unauthorized" means the token expired; anything else is a real error.
        if exc.response.status_code != 401:
            raise
    else:
        return creds.access_token

    # Access token rejected: fall back to the refresh token, if we have one.
    if not creds.refresh_token:
        return None

    try:
        fresh_access, fresh_refresh = refresh_access_token(creds.refresh_token)
    except httpx.HTTPStatusError:
        # Refresh rejected; the user has to log in again.
        return None

    save_credentials(fresh_access, fresh_refresh, creds.email)
    return fresh_access
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _find_free_port() -> int:
|
|
209
|
+
"""Find a free port for the callback server."""
|
|
210
|
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
211
|
+
s.bind(("", 0))
|
|
212
|
+
return s.getsockname()[1]
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class OAuthCallbackHandler(BaseHTTPRequestHandler):
|
|
216
|
+
"""HTTP handler that catches the OAuth callback with access token."""
|
|
217
|
+
|
|
218
|
+
access_token: str | None = None
|
|
219
|
+
refresh_token: str | None = None
|
|
220
|
+
error: str | None = None
|
|
221
|
+
|
|
222
|
+
def log_message(self, format: str, *args: object) -> None: # noqa: A002
|
|
223
|
+
"""Suppress default logging."""
|
|
224
|
+
pass
|
|
225
|
+
|
|
226
|
+
def do_GET(self) -> None:
|
|
227
|
+
"""Handle GET request - catch the callback or serve the HTML page."""
|
|
228
|
+
parsed = urlparse(self.path)
|
|
229
|
+
|
|
230
|
+
if parsed.path == "/callback":
|
|
231
|
+
# This is the redirect from Supabase with hash fragment
|
|
232
|
+
# But hash fragments aren't sent to server, so serve a page that extracts it
|
|
233
|
+
html = """<!DOCTYPE html>
|
|
234
|
+
<html>
|
|
235
|
+
<head><meta charset="UTF-8"><title>Wafer CLI Login</title></head>
|
|
236
|
+
<body>
|
|
237
|
+
<h2>Completing login...</h2>
|
|
238
|
+
<script>
|
|
239
|
+
// Extract tokens from hash fragment
|
|
240
|
+
const hash = window.location.hash.substring(1);
|
|
241
|
+
const params = new URLSearchParams(hash);
|
|
242
|
+
const accessToken = params.get('access_token');
|
|
243
|
+
const refreshToken = params.get('refresh_token');
|
|
244
|
+
const error = params.get('error_description') || params.get('error');
|
|
245
|
+
|
|
246
|
+
if (accessToken) {
|
|
247
|
+
// Send both tokens to our local server
|
|
248
|
+
let url = '/token?access_token=' + encodeURIComponent(accessToken);
|
|
249
|
+
if (refreshToken) {
|
|
250
|
+
url += '&refresh_token=' + encodeURIComponent(refreshToken);
|
|
251
|
+
}
|
|
252
|
+
fetch(url)
|
|
253
|
+
.then(() => {
|
|
254
|
+
document.body.innerHTML = '<h2>✓ Login successful!</h2><p>You can close this window.</p>';
|
|
255
|
+
});
|
|
256
|
+
} else if (error) {
|
|
257
|
+
fetch('/token?error=' + encodeURIComponent(error));
|
|
258
|
+
document.body.innerHTML = '<h2>✗ Login failed</h2><p>' + error + '</p>';
|
|
259
|
+
} else {
|
|
260
|
+
document.body.innerHTML = '<h2>✗ No token received</h2>';
|
|
261
|
+
}
|
|
262
|
+
</script>
|
|
263
|
+
</body>
|
|
264
|
+
</html>"""
|
|
265
|
+
self.send_response(200)
|
|
266
|
+
self.send_header("Content-Type", "text/html; charset=utf-8")
|
|
267
|
+
self.end_headers()
|
|
268
|
+
self.wfile.write(html.encode())
|
|
269
|
+
|
|
270
|
+
elif parsed.path == "/token":
|
|
271
|
+
# JavaScript sends us the tokens
|
|
272
|
+
params = parse_qs(parsed.query)
|
|
273
|
+
if "access_token" in params:
|
|
274
|
+
OAuthCallbackHandler.access_token = params["access_token"][0]
|
|
275
|
+
if "refresh_token" in params:
|
|
276
|
+
OAuthCallbackHandler.refresh_token = params["refresh_token"][0]
|
|
277
|
+
elif "error" in params:
|
|
278
|
+
OAuthCallbackHandler.error = params["error"][0]
|
|
279
|
+
|
|
280
|
+
self.send_response(200)
|
|
281
|
+
self.send_header("Content-Type", "text/plain")
|
|
282
|
+
self.end_headers()
|
|
283
|
+
self.wfile.write(b"OK")
|
|
284
|
+
|
|
285
|
+
else:
|
|
286
|
+
self.send_response(404)
|
|
287
|
+
self.end_headers()
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def browser_login(timeout: int = 120, port: int | None = None) -> tuple[str, str | None]:
    """Open browser for GitHub OAuth and return tokens.

    Starts a local HTTP server, opens browser to Supabase OAuth,
    and waits for the callback with the tokens. The server socket is always
    closed before returning or raising, including on Ctrl-C.

    Args:
        timeout: Seconds to wait for callback (default 120)
        port: Port for callback server. If None, finds a free port (default None)

    Returns:
        Tuple of (access_token, refresh_token). refresh_token may be None.

    Raises:
        TimeoutError: If no callback received within timeout
        RuntimeError: If OAuth flow failed
    """
    if port is None:
        port = _find_free_port()
    redirect_uri = f"http://localhost:{port}/callback"
    supabase_url = get_supabase_url()

    # Build OAuth URL. urlencode percent-escapes redirect_to so its own
    # ":" and "/" characters are carried as data, not URL structure.
    query = urlencode({"provider": "github", "redirect_to": redirect_uri})
    auth_url = f"{supabase_url}/auth/v1/authorize?{query}"

    # Reset class-level state left over from any previous attempt
    OAuthCallbackHandler.access_token = None
    OAuthCallbackHandler.refresh_token = None
    OAuthCallbackHandler.error = None

    # Start local server
    server = HTTPServer(("localhost", port), OAuthCallbackHandler)
    server.timeout = 1  # handle_request() waits at most ~1s, so state is re-checked each second

    try:
        # Open browser
        print("Opening browser for GitHub authentication...")
        print(f"If browser doesn't open, visit: {auth_url}")
        webbrowser.open(auth_url)

        # Wait for callback
        start = time.time()
        print("Waiting for authentication...", end="", flush=True)

        while time.time() - start < timeout:
            server.handle_request()

            if OAuthCallbackHandler.access_token:
                print(f" {CHECK}")
                return OAuthCallbackHandler.access_token, OAuthCallbackHandler.refresh_token

            if OAuthCallbackHandler.error:
                print(f" {CROSS}")
                raise RuntimeError(f"OAuth failed: {OAuthCallbackHandler.error}")
    finally:
        # Release the listening socket on every exit path.
        server.server_close()

    raise TimeoutError(f"No response within {timeout} seconds")
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def device_code_login(timeout: int = 600) -> tuple[str, str | None]:
    """Authenticate using state-based flow (no browser/port forwarding needed).

    This is the SSH-friendly auth flow similar to GitHub CLI:
    1. Request a state token from the API
    2. Display the auth URL with state parameter
    3. User visits URL on any device and signs in normally
    4. Poll API until user completes authentication

    One HTTP client (and its connection pool) is reused for the start request
    and every poll, instead of a new client per poll.

    Args:
        timeout: Seconds to wait for authentication (default 600 = 10 minutes).
            The effective wait is capped by the server-reported ``expires_in``.

    Returns:
        Tuple of (access_token, refresh_token). refresh_token may be None.

    Raises:
        TimeoutError: If user doesn't authenticate within the effective wait
            (the smaller of ``timeout`` and ``expires_in``)
        RuntimeError: If auth flow failed (unexpected poll status)
        httpx.HTTPStatusError: If the initial start request fails
    """
    api_url = get_api_url()
    poll_interval = 5  # Poll every 5 seconds

    with httpx.Client(timeout=10.0) as client:
        # Request state and auth URL
        response = client.post(f"{api_url}/v1/auth/cli-auth/start", json={})
        response.raise_for_status()
        data = response.json()

        state = data["state"]
        auth_url = data["auth_url"]
        expires_in = data["expires_in"]
        # Stop at whichever comes first: the caller's timeout or state expiry.
        wait_limit = min(timeout, expires_in)

        # Display instructions to user
        print("\n" + "=" * 60)
        print(" WAFER CLI - Authentication")
        print("=" * 60)
        print(f"\n Visit: {auth_url}")
        print("\n Sign in with GitHub to complete authentication")
        print("\n" + "=" * 60 + "\n")

        # Poll for authentication
        start = time.time()
        last_poll = 0.0  # forces an immediate first poll
        print("Waiting for authentication", end="", flush=True)

        while time.time() - start < wait_limit:
            if time.time() - last_poll >= poll_interval:
                print(".", end="", flush=True)  # one progress dot per poll
                try:
                    response = client.post(
                        f"{api_url}/v1/auth/cli-auth/token", json={"state": state}
                    )
                except httpx.RequestError:
                    # Network error: flag it and retry at the next interval
                    print("!", end="", flush=True)
                    last_poll = time.time()
                    time.sleep(1)
                    continue

                if response.status_code == 200:
                    # Success!
                    data = response.json()
                    print(f" {CHECK}\n")
                    return data["access_token"], data.get("refresh_token")

                if response.status_code == 428:
                    # Still waiting for the user to sign in
                    last_poll = time.time()
                    time.sleep(1)
                    continue

                # Some other error
                print(f" {CROSS}\n")
                raise RuntimeError(f"CLI auth flow failed: {response.status_code} {response.text}")

            time.sleep(0.5)  # Small sleep to avoid busy loop

    print(f" {CROSS}\n")
    raise TimeoutError(f"Authentication not completed within {wait_limit} seconds")
|