wafer-cli 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/auth.py ADDED
@@ -0,0 +1,432 @@
1
+ """CLI authentication and credential management.
2
+
3
+ Handles storing/loading credentials from ~/.wafer/credentials.json
4
+ and verifying tokens against the wafer-api.
5
+
6
+ Supports automatic token refresh using Supabase refresh tokens.
7
+ """
8
+
9
+ import json
10
+ import socket
11
+ import sys
12
+ import time
13
+ import webbrowser
14
+ from dataclasses import dataclass
15
+ from http.server import BaseHTTPRequestHandler, HTTPServer
16
+ from pathlib import Path
17
+ from urllib.parse import parse_qs, urlparse
18
+
19
+ import httpx
20
+
21
+ from .global_config import get_api_url, get_supabase_url
22
+
23
+
24
+ def _safe_symbol(unicode_sym: str, ascii_fallback: str) -> str:
25
+ """Return unicode symbol if terminal supports it, otherwise ASCII fallback."""
26
+ # Check if stdout can handle UTF-8
27
+ if not sys.stdout.isatty():
28
+ return ascii_fallback
29
+ try:
30
+ encoding = sys.stdout.encoding or "ascii"
31
+ unicode_sym.encode(encoding)
32
+ return unicode_sym
33
+ except (UnicodeEncodeError, LookupError):
34
+ return ascii_fallback
35
+
36
+
37
# Status glyphs for terminal output; ASCII fallbacks keep consoles that
# cannot render the Unicode marks readable.
CHECK, CROSS = _safe_symbol("✓", "[OK]"), _safe_symbol("✗", "[FAIL]")

# Per-user directory and JSON file holding the CLI login credentials.
CREDENTIALS_DIR = Path.home() / ".wafer"
CREDENTIALS_FILE = CREDENTIALS_DIR.joinpath("credentials.json")
43
+
44
+
45
+ @dataclass
46
+ class Credentials:
47
+ """Stored credentials."""
48
+
49
+ access_token: str
50
+ refresh_token: str | None = None
51
+ email: str | None = None
52
+
53
+
54
+ @dataclass
55
+ class UserInfo:
56
+ """User info from token verification."""
57
+
58
+ user_id: str
59
+ email: str | None
60
+
61
+
62
def save_credentials(
    access_token: str,
    refresh_token: str | None = None,
    email: str | None = None,
) -> None:
    """Save credentials to ~/.wafer/credentials.json.

    The JSON is first written to a temporary file in the same directory,
    which ``tempfile.mkstemp`` creates readable/writable only by the current
    user. That file then atomically replaces the credentials file. As a
    result there is no moment at which the tokens sit in a file with looser
    permissions, and an interrupted write cannot leave a truncated
    credentials file behind.

    Args:
        access_token: Access token to store.
        refresh_token: Refresh token; omitted from the file when falsy.
        email: Account e-mail; omitted from the file when falsy.

    Raises:
        OSError: If the directory or file cannot be created or written.
    """
    import os
    import tempfile

    # mode only takes effect when the directory is first created.
    CREDENTIALS_DIR.mkdir(mode=0o700, parents=True, exist_ok=True)

    data = {"access_token": access_token}
    if refresh_token:
        data["refresh_token"] = refresh_token
    if email:
        data["email"] = email

    fd, tmp_path = tempfile.mkstemp(dir=CREDENTIALS_DIR, prefix=".credentials-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(json.dumps(data, indent=2))
        # Owner read/write only; explicit so the result is the same on every platform.
        os.chmod(tmp_path, 0o600)
        os.replace(tmp_path, CREDENTIALS_FILE)
    except BaseException:
        # Don't leave a stray temp file containing tokens behind.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
77
+
78
+
79
def load_credentials() -> Credentials | None:
    """Load credentials from ~/.wafer/credentials.json.

    Returns None if the file doesn't exist or is invalid. Invalid means any
    of the following:
    - the file is not UTF-8 or not JSON;
    - the top level is not a JSON object;
    - ``access_token`` is missing, not a string, or empty.

    Optional fields that are present but not strings are treated as absent.

    Raises:
        OSError: For I/O failures other than a missing file (e.g. permission
            denied), so that "not logged in" is not reported for them.
    """
    try:
        raw = CREDENTIALS_FILE.read_text(encoding="utf-8")
    except FileNotFoundError:
        # Also covers the file disappearing between checks (no exists() race).
        return None

    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None

    if not isinstance(data, dict):
        return None

    access_token = data.get("access_token")
    if not isinstance(access_token, str) or not access_token:
        return None

    def _optional_str(key: str) -> str | None:
        value = data.get(key)
        return value if isinstance(value, str) else None

    return Credentials(
        access_token=access_token,
        refresh_token=_optional_str("refresh_token"),
        email=_optional_str("email"),
    )
95
+
96
+
97
def clear_credentials() -> bool:
    """Remove the credentials file.

    Returns True if the file was removed, False if it didn't exist.

    The unlink is attempted directly rather than checking ``exists()`` first.
    A file that vanishes between the check and the delete therefore yields
    False instead of an unexpected ``FileNotFoundError``.
    """
    try:
        CREDENTIALS_FILE.unlink()
    except FileNotFoundError:
        return False
    return True
106
+
107
+
108
def get_auth_headers() -> dict[str, str]:
    """Build request headers carrying a usable bearer token.

    Delegates to ``get_valid_token``, which refreshes an expired token when a
    refresh token is stored. Yields an empty dict when the user is not logged
    in or the refresh did not succeed. Errors that ``get_valid_token`` re-raises
    (non-401 HTTP errors, network errors) propagate to the caller.
    """
    bearer = get_valid_token()
    return {"Authorization": f"Bearer {bearer}"} if bearer else {}
119
+
120
+
121
def verify_token(token: str) -> UserInfo:
    """Check ``token`` against wafer-api and return the identity it belongs to.

    POSTs the token to ``/v1/auth/verify`` with a 10 second timeout.

    Raises:
        httpx.HTTPStatusError: If the API rejects the token (e.g. 401) or
            returns any other error status
        httpx.RequestError: If the API cannot be reached
    """
    endpoint = f"{get_api_url()}/v1/auth/verify"
    with httpx.Client(timeout=10.0) as client:
        resp = client.post(endpoint, json={"token": token})
        resp.raise_for_status()
        payload = resp.json()
    return UserInfo(user_id=payload["user_id"], email=payload.get("email"))
140
+
141
+
142
def refresh_access_token(refresh_token: str) -> tuple[str, str]:
    """Exchange a Supabase refresh token for a fresh token pair.

    Args:
        refresh_token: Refresh token obtained from an earlier login or refresh.

    Returns:
        ``(new_access_token, new_refresh_token)`` as returned by Supabase.

    Raises:
        httpx.HTTPStatusError: If Supabase rejects the refresh (e.g. the
            refresh token has expired)
        httpx.RequestError: If Supabase cannot be reached
    """
    from .global_config import get_supabase_anon_key

    endpoint = f"{get_supabase_url()}/auth/v1/token?grant_type=refresh_token"
    request_headers = {
        "Content-Type": "application/json",
        "apikey": get_supabase_anon_key(),
    }

    with httpx.Client(timeout=10.0) as client:
        resp = client.post(endpoint, json={"refresh_token": refresh_token}, headers=request_headers)
        resp.raise_for_status()
        tokens = resp.json()

    new_access, new_refresh = tokens["access_token"], tokens["refresh_token"]
    return new_access, new_refresh
171
+
172
+
173
def get_valid_token() -> str | None:
    """Return an access token the API currently accepts, refreshing if needed.

    The stored access token is verified first. If verification fails with
    HTTP 401 and a refresh token is stored, the tokens are refreshed and the
    new pair is saved to disk.

    Returns:
        A usable access token, or None when there are no credentials, no
        refresh token to fall back on, or Supabase rejects the refresh.

    Raises:
        httpx.HTTPStatusError: For verification failures other than 401
        httpx.RequestError: If the API or Supabase cannot be reached
    """
    creds = load_credentials()
    if creds is None:
        return None

    try:
        verify_token(creds.access_token)
    except httpx.HTTPStatusError as exc:
        if exc.response.status_code != 401:
            raise  # only "unauthorized" warrants a refresh attempt
    else:
        return creds.access_token

    # The access token was rejected; a refresh token is the only way forward.
    if not creds.refresh_token:
        return None

    try:
        fresh_access, fresh_refresh = refresh_access_token(creds.refresh_token)
    except httpx.HTTPStatusError:
        return None  # refresh token no longer valid; user must log in again

    save_credentials(fresh_access, fresh_refresh, creds.email)
    return fresh_access
206
+
207
+
208
+ def _find_free_port() -> int:
209
+ """Find a free port for the callback server."""
210
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
211
+ s.bind(("", 0))
212
+ return s.getsockname()[1]
213
+
214
+
215
+ class OAuthCallbackHandler(BaseHTTPRequestHandler):
216
+ """HTTP handler that catches the OAuth callback with access token."""
217
+
218
+ access_token: str | None = None
219
+ refresh_token: str | None = None
220
+ error: str | None = None
221
+
222
+ def log_message(self, format: str, *args: object) -> None: # noqa: A002
223
+ """Suppress default logging."""
224
+ pass
225
+
226
+ def do_GET(self) -> None:
227
+ """Handle GET request - catch the callback or serve the HTML page."""
228
+ parsed = urlparse(self.path)
229
+
230
+ if parsed.path == "/callback":
231
+ # This is the redirect from Supabase with hash fragment
232
+ # But hash fragments aren't sent to server, so serve a page that extracts it
233
+ html = """<!DOCTYPE html>
234
+ <html>
235
+ <head><meta charset="UTF-8"><title>Wafer CLI Login</title></head>
236
+ <body>
237
+ <h2>Completing login...</h2>
238
+ <script>
239
+ // Extract tokens from hash fragment
240
+ const hash = window.location.hash.substring(1);
241
+ const params = new URLSearchParams(hash);
242
+ const accessToken = params.get('access_token');
243
+ const refreshToken = params.get('refresh_token');
244
+ const error = params.get('error_description') || params.get('error');
245
+
246
+ if (accessToken) {
247
+ // Send both tokens to our local server
248
+ let url = '/token?access_token=' + encodeURIComponent(accessToken);
249
+ if (refreshToken) {
250
+ url += '&refresh_token=' + encodeURIComponent(refreshToken);
251
+ }
252
+ fetch(url)
253
+ .then(() => {
254
+ document.body.innerHTML = '<h2>✓ Login successful!</h2><p>You can close this window.</p>';
255
+ });
256
+ } else if (error) {
257
+ fetch('/token?error=' + encodeURIComponent(error));
258
+ document.body.innerHTML = '<h2>✗ Login failed</h2><p>' + error + '</p>';
259
+ } else {
260
+ document.body.innerHTML = '<h2>✗ No token received</h2>';
261
+ }
262
+ </script>
263
+ </body>
264
+ </html>"""
265
+ self.send_response(200)
266
+ self.send_header("Content-Type", "text/html; charset=utf-8")
267
+ self.end_headers()
268
+ self.wfile.write(html.encode())
269
+
270
+ elif parsed.path == "/token":
271
+ # JavaScript sends us the tokens
272
+ params = parse_qs(parsed.query)
273
+ if "access_token" in params:
274
+ OAuthCallbackHandler.access_token = params["access_token"][0]
275
+ if "refresh_token" in params:
276
+ OAuthCallbackHandler.refresh_token = params["refresh_token"][0]
277
+ elif "error" in params:
278
+ OAuthCallbackHandler.error = params["error"][0]
279
+
280
+ self.send_response(200)
281
+ self.send_header("Content-Type", "text/plain")
282
+ self.end_headers()
283
+ self.wfile.write(b"OK")
284
+
285
+ else:
286
+ self.send_response(404)
287
+ self.end_headers()
288
+
289
+
290
def browser_login(timeout: int = 120, port: int | None = None) -> tuple[str, str | None]:
    """Open browser for GitHub OAuth and return tokens.

    Starts a local HTTP server, opens the browser to Supabase OAuth with a
    redirect back to that server, and waits for ``OAuthCallbackHandler`` to
    receive the tokens. The server socket is always closed on return,
    including on timeout, OAuth error, or Ctrl-C.

    Args:
        timeout: Seconds to wait for callback (default 120)
        port: Port for callback server. If None (or 0), the OS assigns a free
            port when the server binds (default None)

    Returns:
        Tuple of (access_token, refresh_token). refresh_token may be None.

    Raises:
        TimeoutError: If no callback received within timeout
        RuntimeError: If OAuth flow failed
        OSError: If the callback server cannot bind to the requested port
    """
    supabase_url = get_supabase_url()

    # Reset results left over from any earlier attempt in this process.
    OAuthCallbackHandler.access_token = None
    OAuthCallbackHandler.refresh_token = None
    OAuthCallbackHandler.error = None

    # Binding to port 0 lets the OS choose a free port atomically. This avoids
    # probing for a port and then racing another process to bind it.
    server = HTTPServer(("localhost", 0 if port is None else port), OAuthCallbackHandler)
    try:
        server.timeout = 1  # handle_request() returns after 1s so we can re-check state
        bound_port = server.server_address[1]
        redirect_uri = f"http://localhost:{bound_port}/callback"

        # Build OAuth URL
        auth_url = f"{supabase_url}/auth/v1/authorize?provider=github&redirect_to={redirect_uri}"

        # Open browser
        print("Opening browser for GitHub authentication...")
        print(f"If browser doesn't open, visit: {auth_url}")
        webbrowser.open(auth_url)

        # Wait for callback
        start = time.time()
        print("Waiting for authentication...", end="", flush=True)

        while time.time() - start < timeout:
            server.handle_request()

            if OAuthCallbackHandler.access_token:
                print(f" {CHECK}")
                return OAuthCallbackHandler.access_token, OAuthCallbackHandler.refresh_token

            if OAuthCallbackHandler.error:
                print(f" {CROSS}")
                raise RuntimeError(f"OAuth failed: {OAuthCallbackHandler.error}")

        # Terminate the "Waiting..." status line like the other failure paths.
        print(f" {CROSS}")
        raise TimeoutError(f"No response within {timeout} seconds")
    finally:
        server.server_close()
348
+
349
+
350
def device_code_login(timeout: int = 600) -> tuple[str, str | None]:
    """Authenticate using state-based flow (no browser/port forwarding needed).

    This is the SSH-friendly auth flow similar to GitHub CLI:
    1. Request a state token from the API
    2. Display the auth URL with state parameter
    3. User visits URL on any device and signs in normally
    4. Poll API until user completes authentication

    Args:
        timeout: Seconds to wait for authentication (default 600 = 10 minutes).
            The effective limit is the smaller of this value and the
            ``expires_in`` the API returns for the state.

    Returns:
        Tuple of (access_token, refresh_token). refresh_token may be None.

    Raises:
        TimeoutError: If user doesn't authenticate within the effective limit
        RuntimeError: If the token endpoint returns a status other than
            200 (success) or 428 (still waiting)
        httpx.HTTPStatusError: If the API rejects the request to start the flow
        httpx.RequestError: If the API is unreachable when starting the flow
    """
    api_url = get_api_url()
    poll_interval = 5  # seconds between polls of the token endpoint

    # One client for the whole flow so polls can reuse the connection.
    with httpx.Client(timeout=10.0) as client:
        # Request state and auth URL
        response = client.post(f"{api_url}/v1/auth/cli-auth/start", json={})
        response.raise_for_status()
        data = response.json()

        state = data["state"]
        auth_url = data["auth_url"]
        expires_in = data["expires_in"]

        # Display instructions to user
        print("\n" + "=" * 60)
        print(" WAFER CLI - Authentication")
        print("=" * 60)
        print(f"\n Visit: {auth_url}")
        print("\n Sign in with GitHub to complete authentication")
        print("\n" + "=" * 60 + "\n")

        # Stop at whichever comes first: the caller's timeout or state expiry.
        wait_limit = min(timeout, expires_in)
        start = time.time()
        last_poll = 0.0

        print("Waiting for authentication", end="", flush=True)

        while time.time() - start < wait_limit:
            if time.time() - last_poll < poll_interval:
                time.sleep(0.5)  # Small sleep to avoid busy loop
                continue

            # Progress dot for each poll
            print(".", end="", flush=True)
            try:
                response = client.post(f"{api_url}/v1/auth/cli-auth/token", json={"state": state})
            except httpx.RequestError:
                # Network error: mark it and retry on the next interval.
                print("!", end="", flush=True)
                last_poll = time.time()
                time.sleep(1)
                continue

            if response.status_code == 200:
                # Success!
                data = response.json()
                print(f" {CHECK}\n")
                return data["access_token"], data.get("refresh_token")

            if response.status_code != 428:
                # 428 means "still waiting"; anything else is a failure.
                print(f" {CROSS}\n")
                raise RuntimeError(f"CLI auth flow failed: {response.status_code} {response.text}")

            last_poll = time.time()
            time.sleep(1)

    print(f" {CROSS}\n")
    raise TimeoutError(f"Authentication not completed within {wait_limit} seconds")