wafer-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """Wafer - Standalone tool for Docker execution on remote GPUs."""
2
+
3
# Package version string; the published wheel is wafer-cli 0.1.0.
__version__: str = "0.1.0"
wafer/api_client.py ADDED
@@ -0,0 +1,201 @@
1
+ """Wafer API client for remote GPU operations.
2
+
3
+ Thin client that calls wafer-api endpoints instead of direct SSH.
4
+ """
5
+
6
+ import base64
7
+ import sys
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+
11
+ import httpx
12
+
13
# Fallback wafer-api endpoint; get_api_url() prefers $WAFER_API_URL when it is set.
DEFAULT_API_URL = "http://localhost:8000"
15
+
16
+
17
@dataclass(frozen=True)
class PushResult:
    """Immutable summary of a directory upload, as reported back by wafer-api."""

    # Identifier of the workspace the files were pushed into.
    workspace_id: str
    # Location of that workspace as reported by the server.
    workspace_path: str
    # Paths of the files the server reports as uploaded.
    files_uploaded: list[str]
24
+
25
+
26
@dataclass(frozen=True)
class ApiConfig:
    """Immutable connection settings for talking to wafer-api."""

    # Root URL of the wafer-api service (module-level default).
    base_url: str = DEFAULT_API_URL
    # Request timeout; presumably seconds, like the httpx timeouts used in this module.
    timeout: float = 60.0
32
+
33
+
34
def get_api_url() -> str:
    """Return the wafer-api base URL.

    Reads ``$WAFER_API_URL`` and falls back to ``DEFAULT_API_URL`` when the
    variable is unset, empty, or only whitespace. Any trailing ``/`` is
    stripped because callers append absolute paths such as ``/v1/gpu/push``;
    without this a configured ``http://host/`` would produce ``http://host//v1/...``.
    """
    import os

    url = os.environ.get("WAFER_API_URL", "").strip() or DEFAULT_API_URL
    return url.rstrip("/")
39
+
40
+
41
def _get_auth_headers() -> dict[str, str]:
    """Return HTTP auth headers built from the locally stored credentials.

    The auth module is imported at call time rather than at module load:
    it imports ``get_api_url`` from this module, so a top-level import here
    would be circular.
    """
    from . import auth

    return auth.get_auth_headers()
46
+
47
+
48
def push_directory(local_path: Path, workspace_name: str | None = None) -> PushResult:
    """Push a local directory to GPU via wafer-api (``POST /v1/gpu/push``).

    Every regular file under ``local_path`` is sent base64-encoded in a
    single JSON request.

    Args:
        local_path: Local directory to upload.
        workspace_name: Optional workspace name. Defaults to the directory's
            name; for paths like ``.`` or ``..`` the resolved directory name
            is used so the server never receives an empty or ``..`` name.

    Returns:
        PushResult with workspace_id, workspace_path and uploaded files.

    Raises:
        FileNotFoundError: If local_path doesn't exist.
        ValueError: If local_path is not a directory.
        httpx.HTTPError: If the API request fails.
    """
    if not local_path.exists():
        raise FileNotFoundError(f"Path not found: {local_path}")
    if not local_path.is_dir():
        raise ValueError(f"Not a directory: {local_path}")

    if not workspace_name:
        # Path(".").name is "" and Path("..").name is ".."; resolve those to a real name.
        workspace_name = local_path.name
        if not workspace_name or workspace_name == "..":
            workspace_name = local_path.resolve().name

    request_body = {
        # Shared with run_command_stream so both upload paths encode files identically.
        "files": _collect_files(local_path),
        "workspace_name": workspace_name,
    }

    api_url = get_api_url()
    headers = _get_auth_headers()
    with httpx.Client(timeout=60.0, headers=headers) as client:
        response = client.post(f"{api_url}/v1/gpu/push", json=request_body)
        response.raise_for_status()
        data = response.json()

    return PushResult(
        workspace_id=data["workspace_id"],
        workspace_path=data["workspace_path"],
        files_uploaded=data["files_uploaded"],
    )
98
+
99
+
100
+ def _collect_files(local_path: Path) -> list[dict]:
101
+ """Collect files from directory as base64-encoded dicts."""
102
+ files = []
103
+ for file_path in local_path.rglob("*"):
104
+ if file_path.is_file():
105
+ relative_path = file_path.relative_to(local_path)
106
+ content = file_path.read_bytes()
107
+ files.append({
108
+ "path": str(relative_path),
109
+ "content": base64.b64encode(content).decode(),
110
+ })
111
+ return files
112
+
113
+
114
def run_command_stream(
    command: str,
    upload_dir: Path | None = None,
    workspace_id: str | None = None,
    gpu_id: int | None = None,
    docker_image: str | None = None,
    docker_entrypoint: str | None = None,
    pull_image: bool = False,
    require_hardware_counters: bool = False,
    target: str | None = None,
) -> int:
    """Run command on GPU via wafer-api (``POST /v1/gpu/jobs``), streaming output.

    Two modes (mutually exclusive):
    - upload_dir: Upload files and run (stateless, high-level)
    - workspace_id: Use existing workspace (low-level)
    If both are given, upload_dir takes precedence and workspace_id is not sent.

    The response is read as a server-sent-event stream. Each ``data:`` payload
    is printed to stdout (flushed immediately so piped output stays live);
    ``[DONE]`` ends the stream and ``[ERROR] <msg>`` prints ``<msg>`` to stderr
    and ends it with a failure exit code.

    Args:
        command: Command to execute inside container
        upload_dir: Directory to upload (stateless mode)
        workspace_id: Workspace ID from push (low-level mode)
        gpu_id: GPU ID to use (optional)
        docker_image: Docker image override (optional)
        docker_entrypoint: Docker entrypoint override (optional, e.g., "bash")
        pull_image: Pull image if not available (optional, default False)
        require_hardware_counters: Require baremetal for ncu profiling (optional)
        target: Target name to use (optional, defaults to user's default)

    Returns:
        Exit code: 1 if the server sent an ``[ERROR]`` event, otherwise 0.
        NOTE(review): a stream that ends without ``[DONE]`` is also reported
        as 0 — confirm the server always terminates with ``[DONE]``.

    Raises:
        FileNotFoundError: If upload_dir doesn't exist.
        ValueError: If upload_dir is not a directory.
        httpx.HTTPError: If API request fails
    """
    request_body: dict = {"command": command}

    # Add files or workspace_id (mutually exclusive)
    if upload_dir is not None:
        # Validate like push_directory does; otherwise a typo would silently
        # upload nothing and run the command against an empty workspace.
        if not upload_dir.exists():
            raise FileNotFoundError(f"Path not found: {upload_dir}")
        if not upload_dir.is_dir():
            raise ValueError(f"Not a directory: {upload_dir}")
        # Path(".").name is "" and Path("..").name is ".."; resolve those to a real name.
        name = upload_dir.name
        if not name or name == "..":
            name = upload_dir.resolve().name
        request_body["files"] = _collect_files(upload_dir)
        request_body["workspace_name"] = name
    elif workspace_id is not None:
        request_body["workspace_id"] = workspace_id
    # else: no files, no workspace (run command in temp workspace)

    if gpu_id is not None:
        request_body["gpu_id"] = gpu_id
    if docker_image is not None:
        request_body["docker_image"] = docker_image
    if docker_entrypoint is not None:
        request_body["docker_entrypoint"] = docker_entrypoint
    if pull_image:
        request_body["pull_image"] = True
    if require_hardware_counters:
        request_body["require_hardware_counters"] = True
    if target is not None:
        request_body["target"] = target

    api_url = get_api_url()
    headers = _get_auth_headers()
    exit_code = 0

    # No timeout: the job's output streams for as long as it runs.
    with httpx.Client(timeout=None, headers=headers) as client:
        with client.stream("POST", f"{api_url}/v1/gpu/jobs", json=request_body) as response:
            response.raise_for_status()

            for line in response.iter_lines():
                # Only "data:" fields carry payload; blank separators and other
                # SSE fields are skipped. Per the SSE spec a single space after
                # the colon is optional, so strip at most one.
                if not line.startswith("data:"):
                    continue
                content = line[len("data:"):].removeprefix(" ")

                if content == "[DONE]":
                    break
                if content.startswith("[ERROR]"):
                    # Works for both "[ERROR] msg" and "[ERROR]msg".
                    message = content.removeprefix("[ERROR]").removeprefix(" ")
                    print(message, file=sys.stderr)
                    exit_code = 1
                    break
                print(content, flush=True)

    return exit_code
200
+
201
+
wafer/auth.py ADDED
@@ -0,0 +1,254 @@
1
+ """CLI authentication and credential management.
2
+
3
+ Handles storing/loading credentials from ~/.wafer/credentials.json
4
+ and verifying tokens against the wafer-api.
5
+ """
6
+
7
+ import json
8
+ import socket
9
+ import webbrowser
10
+ from dataclasses import dataclass
11
+ from http.server import BaseHTTPRequestHandler, HTTPServer
12
+ from pathlib import Path
13
+ from urllib.parse import parse_qs, urlparse
14
+
15
+ import httpx
16
+
17
+ from .api_client import get_api_url
18
+
19
# Supabase auth endpoint for the GitHub OAuth flow; get_supabase_url() lets
# $SUPABASE_URL override it.
DEFAULT_SUPABASE_URL = "https://auth.wafer.ai"

# Per-user location of the saved login token.
CREDENTIALS_DIR = Path.home().joinpath(".wafer")
CREDENTIALS_FILE = CREDENTIALS_DIR.joinpath("credentials.json")
24
+
25
+
26
+ @dataclass
27
+ class Credentials:
28
+ """Stored credentials."""
29
+
30
+ access_token: str
31
+ email: str | None = None
32
+
33
+
34
+ @dataclass
35
+ class UserInfo:
36
+ """User info from token verification."""
37
+
38
+ user_id: str
39
+ email: str | None
40
+
41
+
42
+ def save_credentials(token: str, email: str | None = None) -> None:
43
+ """Save credentials to ~/.wafer/credentials.json."""
44
+ CREDENTIALS_DIR.mkdir(parents=True, exist_ok=True)
45
+ data = {"access_token": token}
46
+ if email:
47
+ data["email"] = email
48
+ CREDENTIALS_FILE.write_text(json.dumps(data, indent=2))
49
+ # Set restrictive permissions (owner read/write only)
50
+ CREDENTIALS_FILE.chmod(0o600)
51
+
52
+
53
def load_credentials() -> Credentials | None:
    """Load credentials from ~/.wafer/credentials.json.

    Returns None if the file doesn't exist or is invalid. Invalid means:
    - the file is unreadable or not UTF-8;
    - it is malformed JSON or not a JSON object;
    - ``access_token`` is missing, empty, or not a string.
    A non-string ``email`` is dropped rather than rejecting the whole file.
    """
    try:
        # EAFP: no separate exists() check, so a concurrent delete can't race us.
        data = json.loads(CREDENTIALS_FILE.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers missing/unreadable files; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return None

    if not isinstance(data, dict):
        return None
    token = data.get("access_token")
    if not isinstance(token, str) or not token:
        return None
    email = data.get("email")
    return Credentials(
        access_token=token,
        email=email if isinstance(email, str) else None,
    )
68
+
69
+
70
def clear_credentials() -> bool:
    """Remove the credentials file.

    Returns True if the file was removed, False if it didn't exist. Uses
    EAFP instead of an ``exists()`` pre-check, so a file deleted concurrently
    is reported as False instead of raising FileNotFoundError.
    """
    try:
        CREDENTIALS_FILE.unlink()
    except FileNotFoundError:
        return False
    return True
79
+
80
+
81
def get_auth_headers() -> dict[str, str]:
    """Build a bearer ``Authorization`` header from stored credentials.

    Returns an empty dict when no credentials are stored (not logged in).
    """
    stored = load_credentials()
    if stored is None:
        return {}
    return {"Authorization": f"Bearer {stored.access_token}"}
90
+
91
+
92
def verify_token(token: str) -> UserInfo:
    """Ask wafer-api (``POST /v1/auth/verify``) who ``token`` belongs to.

    Raises:
        httpx.HTTPStatusError: If token is invalid (401) or other HTTP error
        httpx.RequestError: If API is unreachable
    """
    endpoint = f"{get_api_url()}/v1/auth/verify"
    with httpx.Client(timeout=10.0) as client:
        reply = client.post(endpoint, json={"token": token})
        reply.raise_for_status()
        payload = reply.json()
    return UserInfo(user_id=payload["user_id"], email=payload.get("email"))
111
+
112
+
113
+ def _find_free_port() -> int:
114
+ """Find a free port for the callback server."""
115
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
116
+ s.bind(("", 0))
117
+ return s.getsockname()[1]
118
+
119
+
120
def get_supabase_url() -> str:
    """Return the Supabase base URL used for the OAuth flow.

    Reads ``$SUPABASE_URL`` and falls back to ``DEFAULT_SUPABASE_URL`` when
    the variable is unset, empty, or only whitespace. Any trailing ``/`` is
    stripped because the caller appends ``/auth/v1/authorize``.
    """
    import os

    url = os.environ.get("SUPABASE_URL", "").strip() or DEFAULT_SUPABASE_URL
    return url.rstrip("/")
125
+
126
+
127
+ class OAuthCallbackHandler(BaseHTTPRequestHandler):
128
+ """HTTP handler that catches the OAuth callback with access token."""
129
+
130
+ access_token: str | None = None
131
+ error: str | None = None
132
+
133
+ def log_message(self, format: str, *args: object) -> None:
134
+ """Suppress default logging."""
135
+ pass
136
+
137
+ def do_GET(self) -> None:
138
+ """Handle GET request - catch the callback or serve the HTML page."""
139
+ parsed = urlparse(self.path)
140
+
141
+ if parsed.path == "/callback":
142
+ # This is the redirect from Supabase with hash fragment
143
+ # But hash fragments aren't sent to server, so serve a page that extracts it
144
+ html = """<!DOCTYPE html>
145
+ <html>
146
+ <head><title>Wafer CLI Login</title></head>
147
+ <body>
148
+ <h2>Completing login...</h2>
149
+ <script>
150
+ // Extract token from hash fragment
151
+ const hash = window.location.hash.substring(1);
152
+ const params = new URLSearchParams(hash);
153
+ const accessToken = params.get('access_token');
154
+ const error = params.get('error_description') || params.get('error');
155
+
156
+ if (accessToken) {
157
+ // Send token to our local server
158
+ fetch('/token?access_token=' + encodeURIComponent(accessToken))
159
+ .then(() => {
160
+ document.body.innerHTML = '<h2>✓ Login successful!</h2><p>You can close this window.</p>';
161
+ });
162
+ } else if (error) {
163
+ fetch('/token?error=' + encodeURIComponent(error));
164
+ document.body.innerHTML = '<h2>✗ Login failed</h2><p>' + error + '</p>';
165
+ } else {
166
+ document.body.innerHTML = '<h2>✗ No token received</h2>';
167
+ }
168
+ </script>
169
+ </body>
170
+ </html>"""
171
+ self.send_response(200)
172
+ self.send_header("Content-Type", "text/html")
173
+ self.end_headers()
174
+ self.wfile.write(html.encode())
175
+
176
+ elif parsed.path == "/token":
177
+ # JavaScript sends us the token
178
+ params = parse_qs(parsed.query)
179
+ if "access_token" in params:
180
+ OAuthCallbackHandler.access_token = params["access_token"][0]
181
+ elif "error" in params:
182
+ OAuthCallbackHandler.error = params["error"][0]
183
+
184
+ self.send_response(200)
185
+ self.send_header("Content-Type", "text/plain")
186
+ self.end_headers()
187
+ self.wfile.write(b"OK")
188
+
189
+ else:
190
+ self.send_response(404)
191
+ self.end_headers()
192
+
193
+
194
def browser_login(timeout: int = 120) -> str:
    """Open browser for GitHub OAuth and return access token.

    Starts a local HTTP server on a loopback port assigned by the OS and
    opens the browser at Supabase's ``/auth/v1/authorize`` endpoint, with
    that server's ``/callback`` URL as ``redirect_to``. It then polls the
    server until ``OAuthCallbackHandler`` records a token or an error.

    Args:
        timeout: Seconds to wait for callback (default 120)

    Returns:
        Access token string

    Raises:
        TimeoutError: If no callback received within timeout
        RuntimeError: If OAuth flow failed
    """
    import time
    from urllib.parse import urlencode

    # Reset state left over from any earlier attempt in this process.
    OAuthCallbackHandler.access_token = None
    OAuthCallbackHandler.error = None

    # Bind port 0 so the OS assigns a free port atomically. Probing for a free
    # port and binding it later leaves a window for another process to take it.
    server = HTTPServer(("localhost", 0), OAuthCallbackHandler)
    try:
        # handle_request() returns after this many idle seconds so the loop
        # below can re-check the deadline.
        server.timeout = 1
        port = server.server_address[1]
        redirect_uri = f"http://localhost:{port}/callback"

        # Build OAuth URL; urlencode escapes the redirect URI as a query value.
        auth_url = f"{get_supabase_url()}/auth/v1/authorize?" + urlencode(
            {"provider": "github", "redirect_to": redirect_uri}
        )

        print("Opening browser for GitHub authentication...")
        print(f"If browser doesn't open, visit: {auth_url}")
        webbrowser.open(auth_url)

        print("Waiting for authentication...", end="", flush=True)
        # monotonic() is unaffected by wall-clock changes during the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            server.handle_request()

            if OAuthCallbackHandler.access_token:
                print(" ✓")
                return OAuthCallbackHandler.access_token

            if OAuthCallbackHandler.error:
                print(" ✗")
                raise RuntimeError(f"OAuth failed: {OAuthCallbackHandler.error}")

        print(" ✗")
        raise TimeoutError(f"No response within {timeout} seconds")
    finally:
        # Release the port on every exit path, including Ctrl-C.
        server.server_close()