wafer-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/__init__.py +3 -0
- wafer/api_client.py +201 -0
- wafer/auth.py +254 -0
- wafer/cli.py +1536 -0
- wafer/compiler_analyze.py +63 -0
- wafer/config.py +105 -0
- wafer/evaluate.py +911 -0
- wafer/gpu_run.py +303 -0
- wafer/inference.py +148 -0
- wafer/ncu_analyze.py +571 -0
- wafer/targets.py +296 -0
- wafer/wevin_cli.py +897 -0
- wafer_cli-0.1.0.dist-info/METADATA +9 -0
- wafer_cli-0.1.0.dist-info/RECORD +17 -0
- wafer_cli-0.1.0.dist-info/WHEEL +5 -0
- wafer_cli-0.1.0.dist-info/entry_points.txt +2 -0
- wafer_cli-0.1.0.dist-info/top_level.txt +1 -0
wafer/__init__.py
ADDED
wafer/api_client.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""Wafer API client for remote GPU operations.
|
|
2
|
+
|
|
3
|
+
Thin client that calls wafer-api endpoints instead of direct SSH.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import base64
|
|
7
|
+
import sys
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
# Fallback wafer-api base URL (a local development server). get_api_url()
# returns the WAFER_API_URL environment variable instead when it is set.
DEFAULT_API_URL = "http://localhost:8000"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True)
class PushResult:
    """Result of pushing files to GPU.

    Built by push_directory() from the JSON body returned by the
    ``/v1/gpu/push`` endpoint. ``frozen=True`` prevents reassigning fields
    after construction.
    """

    # Remote workspace identifier; run_command_stream() accepts it as its
    # ``workspace_id`` argument ("Workspace ID from push").
    workspace_id: str
    # Location of the workspace on the remote side, as reported by the API.
    workspace_path: str
    # File paths the API reports as uploaded.
    files_uploaded: list[str]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True)
class ApiConfig:
    """API client configuration (base URL and request timeout).

    NOTE(review): ``base_url`` defaults to DEFAULT_API_URL as it was at class
    definition time, so unlike get_api_url() it does not consult the
    WAFER_API_URL environment variable. push_directory() and
    run_command_stream() in this module call get_api_url() and pass their own
    timeouts rather than reading this class.
    """

    base_url: str = DEFAULT_API_URL
    # Presumably seconds: it matches the 60.0 that push_directory() passes to
    # httpx.Client(timeout=...). Confirm with the consumers of this class.
    timeout: float = 60.0
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_api_url() -> str:
    """Return the wafer-api base URL.

    Reads ``WAFER_API_URL`` from the environment. Falls back to
    DEFAULT_API_URL when the variable is unset, empty, or only whitespace; an
    empty value would otherwise produce relative, unusable request URLs.

    Surrounding whitespace and trailing slashes are removed. Callers append
    paths such as ``/v1/gpu/push``, and a trailing slash would yield ``//v1/...``.
    """
    import os

    url = os.environ.get("WAFER_API_URL", "").strip().rstrip("/")
    return url or DEFAULT_API_URL
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _get_auth_headers() -> dict[str, str]:
    """Return request headers derived from the stored CLI credentials.

    The auth module is imported at call time, not at module load.
    ``wafer.auth`` imports get_api_url from this module, so a top-level
    import here would be circular.
    """
    from . import auth

    return auth.get_auth_headers()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def push_directory(local_path: Path, workspace_name: str | None = None) -> PushResult:
|
|
49
|
+
"""Push local directory to GPU via wafer-api.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
local_path: Local directory to upload
|
|
53
|
+
workspace_name: Optional workspace name (defaults to directory name)
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
PushResult with workspace_id and uploaded files
|
|
57
|
+
|
|
58
|
+
Raises:
|
|
59
|
+
FileNotFoundError: If local_path doesn't exist
|
|
60
|
+
ValueError: If local_path is not a directory
|
|
61
|
+
httpx.HTTPError: If API request fails
|
|
62
|
+
"""
|
|
63
|
+
if not local_path.exists():
|
|
64
|
+
raise FileNotFoundError(f"Path not found: {local_path}")
|
|
65
|
+
if not local_path.is_dir():
|
|
66
|
+
raise ValueError(f"Not a directory: {local_path}")
|
|
67
|
+
|
|
68
|
+
# Collect files and encode as base64
|
|
69
|
+
files = []
|
|
70
|
+
for file_path in local_path.rglob("*"):
|
|
71
|
+
if file_path.is_file():
|
|
72
|
+
relative_path = file_path.relative_to(local_path)
|
|
73
|
+
content = file_path.read_bytes()
|
|
74
|
+
files.append({
|
|
75
|
+
"path": str(relative_path),
|
|
76
|
+
"content": base64.b64encode(content).decode(),
|
|
77
|
+
})
|
|
78
|
+
|
|
79
|
+
# Build request
|
|
80
|
+
request_body = {
|
|
81
|
+
"files": files,
|
|
82
|
+
"workspace_name": workspace_name or local_path.name,
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
# Call API
|
|
86
|
+
api_url = get_api_url()
|
|
87
|
+
headers = _get_auth_headers()
|
|
88
|
+
with httpx.Client(timeout=60.0, headers=headers) as client:
|
|
89
|
+
response = client.post(f"{api_url}/v1/gpu/push", json=request_body)
|
|
90
|
+
response.raise_for_status()
|
|
91
|
+
data = response.json()
|
|
92
|
+
|
|
93
|
+
return PushResult(
|
|
94
|
+
workspace_id=data["workspace_id"],
|
|
95
|
+
workspace_path=data["workspace_path"],
|
|
96
|
+
files_uploaded=data["files_uploaded"],
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _collect_files(local_path: Path) -> list[dict]:
|
|
101
|
+
"""Collect files from directory as base64-encoded dicts."""
|
|
102
|
+
files = []
|
|
103
|
+
for file_path in local_path.rglob("*"):
|
|
104
|
+
if file_path.is_file():
|
|
105
|
+
relative_path = file_path.relative_to(local_path)
|
|
106
|
+
content = file_path.read_bytes()
|
|
107
|
+
files.append({
|
|
108
|
+
"path": str(relative_path),
|
|
109
|
+
"content": base64.b64encode(content).decode(),
|
|
110
|
+
})
|
|
111
|
+
return files
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def run_command_stream(
|
|
115
|
+
command: str,
|
|
116
|
+
upload_dir: Path | None = None,
|
|
117
|
+
workspace_id: str | None = None,
|
|
118
|
+
gpu_id: int | None = None,
|
|
119
|
+
docker_image: str | None = None,
|
|
120
|
+
docker_entrypoint: str | None = None,
|
|
121
|
+
pull_image: bool = False,
|
|
122
|
+
require_hardware_counters: bool = False,
|
|
123
|
+
target: str | None = None,
|
|
124
|
+
) -> int:
|
|
125
|
+
"""Run command on GPU via wafer-api, streaming output.
|
|
126
|
+
|
|
127
|
+
Two modes (mutually exclusive):
|
|
128
|
+
- upload_dir: Upload files and run (stateless, high-level)
|
|
129
|
+
- workspace_id: Use existing workspace (low-level)
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
command: Command to execute inside container
|
|
133
|
+
upload_dir: Directory to upload (stateless mode)
|
|
134
|
+
workspace_id: Workspace ID from push (low-level mode)
|
|
135
|
+
gpu_id: GPU ID to use (optional)
|
|
136
|
+
docker_image: Docker image override (optional)
|
|
137
|
+
docker_entrypoint: Docker entrypoint override (optional, e.g., "bash")
|
|
138
|
+
pull_image: Pull image if not available (optional, default False)
|
|
139
|
+
require_hardware_counters: Require baremetal for ncu profiling (optional)
|
|
140
|
+
target: Target name to use (optional, defaults to user's default)
|
|
141
|
+
|
|
142
|
+
Returns:
|
|
143
|
+
Exit code (0 = success, non-zero = failure)
|
|
144
|
+
|
|
145
|
+
Raises:
|
|
146
|
+
httpx.HTTPError: If API request fails
|
|
147
|
+
"""
|
|
148
|
+
request_body: dict = {
|
|
149
|
+
"command": command,
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
# Add files or workspace_id (mutually exclusive)
|
|
153
|
+
if upload_dir is not None:
|
|
154
|
+
files = _collect_files(upload_dir)
|
|
155
|
+
request_body["files"] = files
|
|
156
|
+
request_body["workspace_name"] = upload_dir.name
|
|
157
|
+
elif workspace_id is not None:
|
|
158
|
+
request_body["workspace_id"] = workspace_id
|
|
159
|
+
# else: no files, no workspace (run command in temp workspace)
|
|
160
|
+
|
|
161
|
+
if gpu_id is not None:
|
|
162
|
+
request_body["gpu_id"] = gpu_id
|
|
163
|
+
if docker_image is not None:
|
|
164
|
+
request_body["docker_image"] = docker_image
|
|
165
|
+
if docker_entrypoint is not None:
|
|
166
|
+
request_body["docker_entrypoint"] = docker_entrypoint
|
|
167
|
+
if pull_image:
|
|
168
|
+
request_body["pull_image"] = True
|
|
169
|
+
if require_hardware_counters:
|
|
170
|
+
request_body["require_hardware_counters"] = True
|
|
171
|
+
if target is not None:
|
|
172
|
+
request_body["target"] = target
|
|
173
|
+
|
|
174
|
+
api_url = get_api_url()
|
|
175
|
+
headers = _get_auth_headers()
|
|
176
|
+
exit_code = 0
|
|
177
|
+
|
|
178
|
+
with httpx.Client(timeout=None, headers=headers) as client: # No timeout for streaming
|
|
179
|
+
with client.stream("POST", f"{api_url}/v1/gpu/jobs", json=request_body) as response:
|
|
180
|
+
response.raise_for_status()
|
|
181
|
+
|
|
182
|
+
for line in response.iter_lines():
|
|
183
|
+
if not line:
|
|
184
|
+
continue
|
|
185
|
+
|
|
186
|
+
# Parse SSE format: "data: <content>"
|
|
187
|
+
if line.startswith("data: "):
|
|
188
|
+
content = line[6:] # Strip "data: " prefix
|
|
189
|
+
|
|
190
|
+
if content == "[DONE]":
|
|
191
|
+
break
|
|
192
|
+
elif content.startswith("[ERROR]"):
|
|
193
|
+
print(content[8:], file=sys.stderr) # Strip "[ERROR] "
|
|
194
|
+
exit_code = 1
|
|
195
|
+
break
|
|
196
|
+
else:
|
|
197
|
+
print(content)
|
|
198
|
+
|
|
199
|
+
return exit_code
|
|
200
|
+
|
|
201
|
+
|
wafer/auth.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""CLI authentication and credential management.
|
|
2
|
+
|
|
3
|
+
Handles storing/loading credentials from ~/.wafer/credentials.json
|
|
4
|
+
and verifying tokens against the wafer-api.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import socket
|
|
9
|
+
import webbrowser
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from urllib.parse import parse_qs, urlparse
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
|
|
17
|
+
from .api_client import get_api_url
|
|
18
|
+
|
|
19
|
+
# Fallback Supabase project URL for the OAuth login flow. get_supabase_url()
# returns the SUPABASE_URL environment variable instead when it is set.
DEFAULT_SUPABASE_URL = "https://auth.wafer.ai"

# Credentials are stored in ~/.wafer/credentials.json. Path.home() is evaluated
# once, when this module is imported.
CREDENTIALS_DIR = Path.home() / ".wafer"
CREDENTIALS_FILE = CREDENTIALS_DIR / "credentials.json"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class Credentials:
    """Credentials stored in ~/.wafer/credentials.json."""

    # Token that get_auth_headers() sends as "Authorization: Bearer <token>".
    access_token: str
    # Account email, when one was saved alongside the token.
    email: str | None = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class UserInfo:
    """User info returned by verify_token() from ``/v1/auth/verify``."""

    # The "user_id" field of the verify response.
    user_id: str
    # The "email" field of the verify response; None when the API omits it.
    email: str | None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def save_credentials(token: str, email: str | None = None) -> None:
    """Save credentials to ~/.wafer/credentials.json with owner-only permissions.

    The file is created with mode 0600 from the start. Writing it first and
    chmod-ing afterwards (as before) left a window in which the token was
    readable under the default umask (typically 0644).

    Args:
        token: Access token to store.
        email: Optional account email; omitted from the file when falsy.
    """
    import os

    # mode only applies if the directory is newly created.
    CREDENTIALS_DIR.mkdir(mode=0o700, parents=True, exist_ok=True)

    data = {"access_token": token}
    if email:
        data["email"] = email

    fd = os.open(CREDENTIALS_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        # O_CREAT's mode is ignored for a file that already exists, so tighten
        # a pre-existing file's permissions before the token is written to it.
        CREDENTIALS_FILE.chmod(0o600)
        f.write(json.dumps(data, indent=2))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def load_credentials() -> Credentials | None:
    """Load credentials from ~/.wafer/credentials.json.

    Returns:
        The stored Credentials, or None in either of these cases:
        - the file doesn't exist;
        - the file is invalid: not JSON, not a JSON object, or lacking a
          non-empty string ``access_token``.

    Raises:
        OSError: If the file exists but cannot be read (e.g. permissions).
    """
    # EAFP: reading directly avoids the race between exists() and the read.
    try:
        data = json.loads(CREDENTIALS_FILE.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return None
    except ValueError:
        # Malformed JSON or non-UTF-8 bytes (JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses).
        return None

    # Valid JSON that isn't an object (e.g. a list) used to raise TypeError
    # here, contrary to the "returns None if invalid" contract.
    if not isinstance(data, dict):
        return None

    token = data.get("access_token")
    if not isinstance(token, str) or not token:
        return None

    email = data.get("email")
    return Credentials(
        access_token=token,
        email=email if isinstance(email, str) else None,
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def clear_credentials() -> bool:
    """Remove the credentials file.

    Returns:
        True if a credentials file was removed, False if it didn't exist.
    """
    # EAFP: an exists()/unlink() pair raises FileNotFoundError if the file
    # disappears between the two calls; catching it keeps the contract.
    try:
        CREDENTIALS_FILE.unlink()
    except FileNotFoundError:
        return False
    return True
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_auth_headers() -> dict[str, str]:
    """Build HTTP headers that authenticate a request to wafer-api.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` when load_credentials() finds
        stored credentials; an empty dict otherwise (not logged in).
    """
    stored = load_credentials()
    if stored is None:
        return {}
    return {"Authorization": f"Bearer {stored.access_token}"}
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def verify_token(token: str) -> UserInfo:
    """Ask wafer-api whether ``token`` is valid and return the user it belongs to.

    POSTs ``{"token": token}`` to ``/v1/auth/verify`` with a 10-second timeout.

    Raises:
        httpx.HTTPStatusError: If token is invalid (401) or other HTTP error
        httpx.RequestError: If API is unreachable
    """
    endpoint = f"{get_api_url()}/v1/auth/verify"
    with httpx.Client(timeout=10.0) as http:
        reply = http.post(endpoint, json={"token": token})
        reply.raise_for_status()
        payload = reply.json()

    return UserInfo(user_id=payload["user_id"], email=payload.get("email"))
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _find_free_port() -> int:
|
|
114
|
+
"""Find a free port for the callback server."""
|
|
115
|
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
116
|
+
s.bind(("", 0))
|
|
117
|
+
return s.getsockname()[1]
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def get_supabase_url() -> str:
    """Return the Supabase base URL used for the OAuth login flow.

    Reads ``SUPABASE_URL`` from the environment. Falls back to
    DEFAULT_SUPABASE_URL when the variable is unset, empty, or only
    whitespace.

    Surrounding whitespace and trailing slashes are removed so that
    ``/auth/v1/authorize`` can be appended without producing ``//``.
    """
    import os

    url = os.environ.get("SUPABASE_URL", "").strip().rstrip("/")
    return url or DEFAULT_SUPABASE_URL
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class OAuthCallbackHandler(BaseHTTPRequestHandler):
|
|
128
|
+
"""HTTP handler that catches the OAuth callback with access token."""
|
|
129
|
+
|
|
130
|
+
access_token: str | None = None
|
|
131
|
+
error: str | None = None
|
|
132
|
+
|
|
133
|
+
def log_message(self, format: str, *args: object) -> None:
|
|
134
|
+
"""Suppress default logging."""
|
|
135
|
+
pass
|
|
136
|
+
|
|
137
|
+
def do_GET(self) -> None:
|
|
138
|
+
"""Handle GET request - catch the callback or serve the HTML page."""
|
|
139
|
+
parsed = urlparse(self.path)
|
|
140
|
+
|
|
141
|
+
if parsed.path == "/callback":
|
|
142
|
+
# This is the redirect from Supabase with hash fragment
|
|
143
|
+
# But hash fragments aren't sent to server, so serve a page that extracts it
|
|
144
|
+
html = """<!DOCTYPE html>
|
|
145
|
+
<html>
|
|
146
|
+
<head><title>Wafer CLI Login</title></head>
|
|
147
|
+
<body>
|
|
148
|
+
<h2>Completing login...</h2>
|
|
149
|
+
<script>
|
|
150
|
+
// Extract token from hash fragment
|
|
151
|
+
const hash = window.location.hash.substring(1);
|
|
152
|
+
const params = new URLSearchParams(hash);
|
|
153
|
+
const accessToken = params.get('access_token');
|
|
154
|
+
const error = params.get('error_description') || params.get('error');
|
|
155
|
+
|
|
156
|
+
if (accessToken) {
|
|
157
|
+
// Send token to our local server
|
|
158
|
+
fetch('/token?access_token=' + encodeURIComponent(accessToken))
|
|
159
|
+
.then(() => {
|
|
160
|
+
document.body.innerHTML = '<h2>✓ Login successful!</h2><p>You can close this window.</p>';
|
|
161
|
+
});
|
|
162
|
+
} else if (error) {
|
|
163
|
+
fetch('/token?error=' + encodeURIComponent(error));
|
|
164
|
+
document.body.innerHTML = '<h2>✗ Login failed</h2><p>' + error + '</p>';
|
|
165
|
+
} else {
|
|
166
|
+
document.body.innerHTML = '<h2>✗ No token received</h2>';
|
|
167
|
+
}
|
|
168
|
+
</script>
|
|
169
|
+
</body>
|
|
170
|
+
</html>"""
|
|
171
|
+
self.send_response(200)
|
|
172
|
+
self.send_header("Content-Type", "text/html")
|
|
173
|
+
self.end_headers()
|
|
174
|
+
self.wfile.write(html.encode())
|
|
175
|
+
|
|
176
|
+
elif parsed.path == "/token":
|
|
177
|
+
# JavaScript sends us the token
|
|
178
|
+
params = parse_qs(parsed.query)
|
|
179
|
+
if "access_token" in params:
|
|
180
|
+
OAuthCallbackHandler.access_token = params["access_token"][0]
|
|
181
|
+
elif "error" in params:
|
|
182
|
+
OAuthCallbackHandler.error = params["error"][0]
|
|
183
|
+
|
|
184
|
+
self.send_response(200)
|
|
185
|
+
self.send_header("Content-Type", "text/plain")
|
|
186
|
+
self.end_headers()
|
|
187
|
+
self.wfile.write(b"OK")
|
|
188
|
+
|
|
189
|
+
else:
|
|
190
|
+
self.send_response(404)
|
|
191
|
+
self.end_headers()
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def browser_login(timeout: int = 120) -> str:
    """Open browser for GitHub OAuth and return access token.

    Steps:
    1. Start a local HTTP server on an OS-assigned localhost port.
    2. Open the Supabase ``/auth/v1/authorize`` URL (GitHub provider), with
       the server's ``/callback`` as the redirect target.
    3. Service requests until OAuthCallbackHandler records a token or an
       error, or until the timeout expires.

    Args:
        timeout: Seconds to wait for callback (default 120)

    Returns:
        Access token string

    Raises:
        TimeoutError: If no callback received within timeout
        RuntimeError: If OAuth flow failed
    """
    import time
    from urllib.parse import urlencode

    supabase_url = get_supabase_url()

    # Reset state left over from any earlier login attempt in this process.
    OAuthCallbackHandler.access_token = None
    OAuthCallbackHandler.error = None

    # Bind to port 0 so the OS assigns a free port atomically. Probing with a
    # separate socket and binding later can race with other processes. The
    # with-block closes the server on every exit path, including Ctrl-C.
    with HTTPServer(("localhost", 0), OAuthCallbackHandler) as server:
        server.timeout = 1  # handle_request() returns after 1s so state is re-checked
        port = server.server_address[1]
        redirect_uri = f"http://localhost:{port}/callback"

        # urlencode() percent-encodes redirect_to so it survives as one value.
        auth_url = f"{supabase_url}/auth/v1/authorize?" + urlencode(
            {"provider": "github", "redirect_to": redirect_uri}
        )

        print("Opening browser for GitHub authentication...")
        print(f"If browser doesn't open, visit: {auth_url}")
        webbrowser.open(auth_url)

        print("Waiting for authentication...", end="", flush=True)
        # monotonic() is unaffected by wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            server.handle_request()

            if OAuthCallbackHandler.access_token:
                print(" ✓")
                return OAuthCallbackHandler.access_token

            if OAuthCallbackHandler.error:
                print(" ✗")
                raise RuntimeError(f"OAuth failed: {OAuthCallbackHandler.error}")

    print(" ✗")
    raise TimeoutError(f"No response within {timeout} seconds")
|