wafer-cli 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +118 -0
- wafer/__init__.py +3 -0
- wafer/analytics.py +306 -0
- wafer/api_client.py +195 -0
- wafer/auth.py +432 -0
- wafer/autotuner.py +1080 -0
- wafer/billing.py +233 -0
- wafer/cli.py +7289 -0
- wafer/config.py +105 -0
- wafer/corpus.py +366 -0
- wafer/evaluate.py +4593 -0
- wafer/global_config.py +350 -0
- wafer/gpu_run.py +307 -0
- wafer/inference.py +148 -0
- wafer/kernel_scope.py +552 -0
- wafer/ncu_analyze.py +651 -0
- wafer/nsys_analyze.py +1042 -0
- wafer/nsys_profile.py +510 -0
- wafer/output.py +248 -0
- wafer/problems.py +357 -0
- wafer/rocprof_compute.py +490 -0
- wafer/rocprof_sdk.py +274 -0
- wafer/rocprof_systems.py +520 -0
- wafer/skills/wafer-guide/SKILL.md +129 -0
- wafer/ssh_keys.py +261 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +842 -0
- wafer/targets_ops.py +717 -0
- wafer/templates/__init__.py +0 -0
- wafer/templates/ask_docs.py +61 -0
- wafer/templates/optimize_kernel.py +71 -0
- wafer/templates/optimize_kernelbench.py +137 -0
- wafer/templates/trace_analyze.py +74 -0
- wafer/tracelens.py +218 -0
- wafer/wevin_cli.py +577 -0
- wafer/workspaces.py +852 -0
- wafer_cli-0.2.14.dist-info/METADATA +16 -0
- wafer_cli-0.2.14.dist-info/RECORD +41 -0
- wafer_cli-0.2.14.dist-info/WHEEL +5 -0
- wafer_cli-0.2.14.dist-info/entry_points.txt +2 -0
- wafer_cli-0.2.14.dist-info/top_level.txt +1 -0
wafer/workspaces.py
ADDED
|
@@ -0,0 +1,852 @@
|
|
|
1
|
+
"""Workspaces CLI - Manage remote GPU workspaces.
|
|
2
|
+
|
|
3
|
+
This module provides the implementation for the `wafer workspaces` subcommand.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
from .api_client import get_api_url
|
|
14
|
+
from .auth import get_auth_headers
|
|
15
|
+
|
|
16
|
+
VALID_STATUSES = {"creating", "running"}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _get_client() -> tuple[str, dict[str, str]]:
    """Return the configured API base URL and the authentication headers."""
    base_url = get_api_url()
    auth = get_auth_headers()

    # Fail fast on a misconfigured endpoint before any request goes out.
    assert base_url, "API URL must be configured"
    assert base_url.startswith("http"), "API URL must be a valid HTTP(S) URL"

    return base_url, auth
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _friendly_error(status_code: int, response_text: str, workspace_id: str) -> str:
|
|
31
|
+
"""Convert API errors to friendly messages with guidance.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
status_code: HTTP status code
|
|
35
|
+
response_text: Response body
|
|
36
|
+
workspace_id: Workspace ID or name for context
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
User-friendly error message with suggested next steps
|
|
40
|
+
"""
|
|
41
|
+
if status_code == 401:
|
|
42
|
+
return "Not authenticated. Run: wafer login"
|
|
43
|
+
|
|
44
|
+
if status_code == 402:
|
|
45
|
+
return (
|
|
46
|
+
"Insufficient credits.\n"
|
|
47
|
+
" Check usage: wafer billing\n"
|
|
48
|
+
" Add credits: wafer billing topup"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
if status_code == 404:
|
|
52
|
+
return (
|
|
53
|
+
f"Workspace '{workspace_id}' not found.\n"
|
|
54
|
+
" List workspaces: wafer workspaces list\n"
|
|
55
|
+
" Create one: wafer workspaces create <name>"
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
if status_code == 503:
|
|
59
|
+
return (
|
|
60
|
+
"No GPU available.\n"
|
|
61
|
+
" The workspace is queued for GPU access. Try again in a moment.\n"
|
|
62
|
+
" Check status: wafer workspaces show " + workspace_id
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
# Parse common error details from response
|
|
66
|
+
detail = ""
|
|
67
|
+
if "not running" in response_text.lower() or "not found" in response_text.lower():
|
|
68
|
+
return (
|
|
69
|
+
f"Workspace '{workspace_id}' not found or not running.\n"
|
|
70
|
+
" Check status: wafer workspaces list\n"
|
|
71
|
+
" Create new: wafer workspaces create <name>"
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
if "timeout" in response_text.lower():
|
|
75
|
+
return (
|
|
76
|
+
"Command timed out.\n"
|
|
77
|
+
' Increase timeout: wafer workspaces exec <workspace> "cmd" --timeout 600\n'
|
|
78
|
+
" Or set default: wafer config set defaults.exec_timeout 600"
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
if "creating" in response_text.lower():
|
|
82
|
+
return (
|
|
83
|
+
f"Workspace '{workspace_id}' is still creating.\n"
|
|
84
|
+
" Check status: wafer workspaces list"
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
# Generic error with response detail
|
|
88
|
+
try:
|
|
89
|
+
import json
|
|
90
|
+
|
|
91
|
+
data = json.loads(response_text)
|
|
92
|
+
detail = data.get("detail", response_text)
|
|
93
|
+
except (json.JSONDecodeError, KeyError):
|
|
94
|
+
detail = response_text
|
|
95
|
+
|
|
96
|
+
return f"API error ({status_code}): {detail}"
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _list_workspaces_raw() -> list[dict]:
    """Fetch the caller's workspaces from the API and return the raw payload."""
    api_url, headers = _get_client()

    try:
        with httpx.Client(timeout=30.0, headers=headers) as client:
            resp = client.get(f"{api_url}/v1/workspaces")
            resp.raise_for_status()
            payload = resp.json()
    except httpx.HTTPStatusError as exc:
        if exc.response.status_code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from exc
        raise RuntimeError(f"API error: {exc.response.status_code} - {exc.response.text}") from exc
    except httpx.RequestError as exc:
        raise RuntimeError(f"Could not reach API: {exc}") from exc

    assert isinstance(payload, list), "API must return a list of workspaces"

    # Sanity-check every record's status against the known set.
    for entry in payload:
        current = entry.get("status", "unknown")
        assert current in VALID_STATUSES or current == "unknown", (
            f"Workspace {entry.get('id', 'unknown')} has invalid status '{current}'. "
            f"Valid statuses: {VALID_STATUSES}"
        )

    return payload
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def resolve_workspace(specified: str | None) -> str:
    """Resolve which workspace to target.

    Resolution order:
      1. An explicitly specified name/ID wins (the API resolves name vs ID).
      2. The configured defaults.workspace, if set.
      3. The user's only workspace, if they have exactly one.
      4. Otherwise fail with guidance.

    Args:
        specified: Workspace name or ID, or None to use default

    Returns:
        Workspace name or ID to use

    Raises:
        RuntimeError: If no workspace can be resolved
    """
    from .global_config import get_defaults

    # Explicit argument always wins; the API handles name-vs-ID resolution.
    if specified:
        return specified

    # Fall back to the configured default, if any.
    default_ws = get_defaults().workspace
    if default_ws:
        return default_ws

    # Last resort: a single existing workspace is unambiguous.
    available = _list_workspaces_raw()
    if not available:
        raise RuntimeError("No workspaces found. Create one with: wafer workspaces create <name>")
    if len(available) == 1:
        return available[0]["id"]

    labels = [ws.get("name", ws["id"]) for ws in available]
    raise RuntimeError(
        f"Multiple workspaces found: {', '.join(labels)}\n"
        "Specify one, or set default: wafer config set defaults.workspace <name>"
    )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def list_workspaces(json_output: bool = False) -> str:
    """List all workspaces for the current user.

    Args:
        json_output: If True, return raw JSON; otherwise return formatted text

    Returns:
        Workspaces list as string (JSON or formatted text)

    Raises:
        RuntimeError: On auth, API, or network failure.
    """
    # The fetch + status validation is identical to _list_workspaces_raw;
    # the original duplicated that HTTP call and the assertions inline here.
    workspaces = _list_workspaces_raw()

    if json_output:
        return json.dumps(workspaces, indent=2)

    if not workspaces:
        return "No workspaces found."

    lines = ["Workspaces:", ""]
    for ws in workspaces:
        status = ws.get("status", "unknown")
        # ● running, ◐ still provisioning, ? anything unexpected
        status_icon = {"running": "●", "creating": "◐"}.get(status, "?")
        lines.append(f" {status_icon} {ws['name']} ({ws['id']})")
        lines.append(f" GPU: {ws.get('gpu_type', 'N/A')} | Image: {ws.get('image', 'N/A')}")
        if ws.get("ssh_host") and ws.get("ssh_port") and ws.get("ssh_user"):
            lines.append(
                f" SSH: ssh -p {ws['ssh_port']} {ws['ssh_user']}@{ws['ssh_host']}"
            )
        else:
            lines.append(" SSH: Not ready (run: wafer workspaces ssh <name>)")
        lines.append("")

    return "\n".join(lines)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def create_workspace(
    name: str,
    gpu_type: str = "B200",
    image: str | None = None,
    wait: bool = False,
    json_output: bool = False,
) -> str:
    """Create a new workspace.

    Args:
        name: Workspace name (must be unique)
        gpu_type: GPU type (default: B200)
        image: Docker image (optional, uses default if not specified)
        wait: If True, stream provisioning progress and return SSH credentials
        json_output: If True, return raw JSON; otherwise return formatted text

    Returns:
        Created workspace info as string

    Raises:
        RuntimeError: If name already exists or API error
    """
    # Validate inputs
    assert name, "Workspace name must be non-empty"
    assert gpu_type, "GPU type must be non-empty"

    api_url, headers = _get_client()

    # Check for duplicate name. Best-effort: httpx failures are swallowed so
    # the create request below can surface auth/connection problems itself.
    # The RuntimeError for a duplicate name is NOT caught by the except
    # clauses below (they only match httpx exceptions), so it propagates.
    try:
        with httpx.Client(timeout=30.0, headers=headers) as client:
            response = client.get(f"{api_url}/v1/workspaces")
            response.raise_for_status()
            existing = response.json()
            existing_names = [ws.get("name") for ws in existing]
            if name in existing_names:
                raise RuntimeError(
                    f"Workspace '{name}' already exists.\n"
                    f" Use a different name, or delete the existing one:\n"
                    f" wafer workspaces delete {name}"
                )
    except httpx.HTTPStatusError:
        pass  # Continue with create, let API handle auth errors
    except httpx.RequestError:
        pass  # Continue with create, let API handle connection errors

    request_body: dict = {
        "name": name,
        "gpu_type": gpu_type,
    }
    # Only include the image key when explicitly requested; the API applies
    # its own default otherwise.
    if image:
        request_body["image"] = image

    # Longer timeout than the list call: creation does more work server-side.
    try:
        with httpx.Client(timeout=60.0, headers=headers) as client:
            response = client.post(f"{api_url}/v1/workspaces", json=request_body)
            response.raise_for_status()
            workspace = response.json()
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from e
        if e.response.status_code == 400:
            raise RuntimeError(f"Bad request: {e.response.text}") from e
        raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
    except httpx.RequestError as e:
        raise RuntimeError(f"Could not reach API: {e}") from e

    # Validate API response has required fields
    assert "id" in workspace, "API response must contain workspace id"
    assert "name" in workspace, "API response must contain workspace name"

    if wait:
        # Block until provisioning finishes and SSH credentials are streamed back.
        ssh_info = _wait_for_provisioning(workspace["id"])
        if json_output:
            payload = {
                "workspace_id": workspace["id"],
                "ssh_host": ssh_info["ssh_host"],
                "ssh_port": ssh_info["ssh_port"],
                "ssh_user": ssh_info["ssh_user"],
            }
            return json.dumps(payload, indent=2)
        return (
            f"Workspace ready: {workspace['name']} ({workspace['id']})\n"
            f"SSH: ssh -p {ssh_info['ssh_port']} {ssh_info['ssh_user']}@{ssh_info['ssh_host']}"
        )

    if json_output:
        return json.dumps(workspace, indent=2)

    return (
        f"Creating workspace: {workspace['name']} ({workspace['id']})\n"
        "Check status with: wafer workspaces list\n"
        "Estimated time: ~30 seconds"
    )
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def _wait_for_provisioning(workspace_id: str) -> dict[str, str | int]:
    """Wait for workspace provisioning to complete via SSE.

    Streams the provision-stream endpoint, echoing [STATUS:...] frames to
    stderr until an [SSH:host:port:user] frame arrives.

    Args:
        workspace_id: Workspace ID to wait on (must be non-empty).

    Returns:
        Dict with "ssh_host" (str), "ssh_port" (int), and "ssh_user" (str).

    Raises:
        RuntimeError: On a non-200 response, a provisioning ERROR status,
            a malformed SSH frame, missing credentials, or network failure.
    """
    import sys

    assert workspace_id, "Workspace ID must be non-empty"
    api_url, headers = _get_client()

    try:
        # timeout=None: provisioning can take an unbounded amount of time;
        # the stream stays open until the server sends SSH info or an error.
        with httpx.Client(timeout=None, headers=headers) as client:
            with client.stream(
                "POST",
                f"{api_url}/v1/workspaces/{workspace_id}/provision-stream",
            ) as response:
                if response.status_code != 200:
                    error_body = response.read().decode("utf-8", errors="replace")
                    raise RuntimeError(
                        _friendly_error(response.status_code, error_body, workspace_id)
                    )

                ssh_info: dict[str, str | int] | None = None
                for line in response.iter_lines():
                    # SSE payload lines look like "data: <content>".
                    if not line or not line.startswith("data: "):
                        continue
                    content = line[6:]
                    if content.startswith("[STATUS:"):
                        # "[STATUS:XYZ]" -> "XYZ"
                        status = content[8:-1]
                        print(f"[wafer] {status.lower()}...", file=sys.stderr)
                        if status == "ERROR":
                            raise RuntimeError(
                                "Workspace provisioning failed. Check status with: wafer workspaces list"
                            )
                    elif content.startswith("[SSH:"):
                        # "[SSH:host:port:user]" -> [host, port, user]
                        parts = content[5:-1].split(":")
                        if len(parts) != 3:
                            raise RuntimeError("Malformed SSH info in provisioning stream")
                        ssh_info = {
                            "ssh_host": parts[0],
                            "ssh_port": int(parts[1]),
                            "ssh_user": parts[2],
                        }
                        break

                if ssh_info is None:
                    # Stream ended without an SSH frame.
                    raise RuntimeError("Provisioning did not return SSH credentials")
                return ssh_info
    except httpx.RequestError as e:
        raise RuntimeError(f"Could not reach API: {e}") from e
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def delete_workspace(workspace_id: str, json_output: bool = False) -> str:
    """Delete a workspace.

    Args:
        workspace_id: Workspace ID to delete
        json_output: If True, return raw JSON; otherwise return formatted text

    Returns:
        Deletion status as string

    Raises:
        RuntimeError: On auth failure, unknown workspace, API or network error
    """
    assert workspace_id, "Workspace ID must be non-empty"

    api_url, headers = _get_client()

    try:
        with httpx.Client(timeout=30.0, headers=headers) as client:
            resp = client.delete(f"{api_url}/v1/workspaces/{workspace_id}")
            resp.raise_for_status()
            payload = resp.json()
    except httpx.HTTPStatusError as exc:
        code = exc.response.status_code
        if code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from exc
        if code == 404:
            raise RuntimeError(f"Workspace not found: {workspace_id}") from exc
        raise RuntimeError(f"API error: {code} - {exc.response.text}") from exc
    except httpx.RequestError as exc:
        raise RuntimeError(f"Could not reach API: {exc}") from exc

    if json_output:
        return json.dumps(payload, indent=2)
    return f"Deleted workspace: {workspace_id}"
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def sync_files(
    workspace_id: str,
    local_path: Path,
    on_progress: Callable[[str], None] | None = None,
) -> tuple[int, str | None]:
    """Sync local files or directories to workspace via rsync over SSH.

    After rsync completes, calls the API to sync files to Modal volume
    so they're available for exec commands.

    Args:
        workspace_id: Workspace ID or name
        local_path: Local file or directory to sync
        on_progress: Optional callback for progress messages

    Returns:
        Tuple of (file_count, warning_message). Warning is None on success.

    Raises:
        RuntimeError: If rsync fails
    """
    import subprocess

    # Progress reporting is optional; emit() is a no-op without a callback.
    def emit(msg: str) -> None:
        if on_progress:
            on_progress(msg)

    assert workspace_id, "Workspace ID must be non-empty"
    assert local_path.exists(), f"Path not found: {local_path}"

    # Resolve name/ID to the full record; we need the UUID and SSH details.
    ws = get_workspace_raw(workspace_id)
    resolved_id = ws["id"]
    workspace_status = ws.get("status")
    assert workspace_status in VALID_STATUSES, (
        f"Workspace {workspace_id} has invalid status '{workspace_status}'. "
        f"Valid statuses: {VALID_STATUSES}"
    )
    if workspace_status != "running":
        # "creating" workspaces have no SSH endpoint yet.
        raise RuntimeError(
            f"Workspace is {workspace_status}. Wait for it to be running before syncing."
        )
    ssh_host = ws.get("ssh_host")
    ssh_port = ws.get("ssh_port")
    ssh_user = ws.get("ssh_user")
    assert ssh_host, "Workspace missing ssh_host"
    assert isinstance(ssh_port, int) and ssh_port > 0, "Workspace missing valid ssh_port"
    assert ssh_user, "Workspace missing ssh_user"

    # Build rsync command
    # -a: archive mode (preserves permissions, etc.)
    # -v: verbose
    # -z: compress during transfer
    if local_path.is_dir():
        # Directory: sync contents (trailing slash)
        source = f"{local_path}/"
    else:
        # Single file: sync the file itself
        source = str(local_path)

    # Build SSH command for rsync
    # If key_path is None (BYOK model), SSH will use default key from ~/.ssh/
    ssh_opts = f"-p {ssh_port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"

    rsync_cmd = [
        "rsync",
        "-avz",
        "-e",
        f"ssh {ssh_opts}",
        source,
        f"{ssh_user}@{ssh_host}:/workspace/",
    ]

    try:
        result = subprocess.run(rsync_cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"rsync failed: {result.stderr}")

        # Count files from rsync output (lines that don't start with special chars)
        # NOTE(review): heuristic — relies on rsync's default -v listing format.
        lines = result.stdout.strip().split("\n")
        file_count = sum(
            1
            for line in lines
            if line and not line.startswith((" ", "sent", "total", "receiving", "building"))
        )

    except FileNotFoundError:
        raise RuntimeError("rsync not found. Install rsync to use sync feature.") from None
    except subprocess.SubprocessError as e:
        raise RuntimeError(f"Sync failed: {e}") from e

    emit(f"Synced {file_count} files to SSH host")

    # Notify API to sync files to Modal volume (so exec can see them)
    # Use resolved UUID, not the name
    emit("Syncing to Modal volume...")
    warning = _init_sync_state(resolved_id)

    if warning:
        emit(f"Modal sync warning: {warning}")
    else:
        emit("Modal sync complete")

    return file_count, warning
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def _init_sync_state(workspace_id: str) -> str | None:
    """Ask the API to mirror bare-metal files onto the Modal volume.

    Must run after rsync finishes so that exec commands can see the
    freshly synced files.

    Returns:
        None on success, warning message on failure (non-fatal)
    """
    api_url, headers = _get_client()

    try:
        with httpx.Client(timeout=120.0, headers=headers) as client:
            resp = client.post(f"{api_url}/v1/workspaces/{workspace_id}/init-sync-state")
            resp.raise_for_status()
    except httpx.RequestError:
        # Network error - sync to bare metal succeeded
        return None
    except httpx.HTTPStatusError as exc:
        # Non-fatal: the rsync to bare metal already succeeded, so the user
        # can still SSH in and use the files — they just won't be visible
        # through exec.
        status = exc.response.status_code
        if status == 404:
            # Workspace not found or no target - sync still worked for SSH
            return None
        # Pull a human-readable detail out of the response when possible.
        detail = ""
        try:
            detail = exc.response.json().get("detail", "")
        except Exception:
            detail = exc.response.text[:200] if exc.response.text else ""
        if detail:
            return f"Files synced to SSH, but Modal sync failed: {detail}"
        return f"Files synced to SSH, but Modal sync failed ({status}). Use SSH or retry sync."
    return None
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
def get_workspace_raw(workspace_id: str) -> dict:
    """Fetch a single workspace record from the API as a raw dict.

    Raises RuntimeError on auth failure, unknown workspace, or network error.
    """
    assert workspace_id, "Workspace ID must be non-empty"

    api_url, headers = _get_client()

    try:
        with httpx.Client(timeout=30.0, headers=headers) as client:
            resp = client.get(f"{api_url}/v1/workspaces/{workspace_id}")
            resp.raise_for_status()
            record = resp.json()
    except httpx.HTTPStatusError as exc:
        code = exc.response.status_code
        if code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from exc
        if code == 404:
            raise RuntimeError(f"Workspace not found: {workspace_id}") from exc
        raise RuntimeError(f"API error: {code} - {exc.response.text}") from exc
    except httpx.RequestError as exc:
        raise RuntimeError(f"Could not reach API: {exc}") from exc

    # The record must minimally identify itself and carry a known status.
    assert "id" in record, "API response must contain workspace id"
    assert "name" in record, "API response must contain workspace name"

    current = record.get("status", "unknown")
    assert current in VALID_STATUSES or current == "unknown", (
        f"Workspace {record['id']} has invalid status '{current}'. "
        f"Valid statuses: {VALID_STATUSES}"
    )

    return record
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def get_workspace(workspace_id: str, json_output: bool = False) -> str:
    """Get details of a specific workspace.

    Args:
        workspace_id: Workspace ID to get
        json_output: If True, return raw JSON; otherwise return formatted text

    Returns:
        Workspace details as string
    """
    ws = get_workspace_raw(workspace_id)

    if json_output:
        return json.dumps(ws, indent=2)

    state = ws.get("status", "unknown")
    out = [
        f"Workspace: {ws['name']} ({ws['id']})",
        "",
        f" Status: {state}",
        f" GPU Type: {ws.get('gpu_type', 'N/A')}",
        f" Image: {ws.get('image', 'N/A')}",
        f" Created: {ws.get('created_at', 'N/A')}",
        f" Last Used: {ws.get('last_used_at', 'N/A')}",
    ]

    if ws.get("ssh_host"):
        out += [
            "",
            "SSH Info:",
            f" Host: {ws['ssh_host']}",
            f" Port: {ws.get('ssh_port', 22)}",
            f" User: {ws.get('ssh_user', 'root')}",
        ]
    elif state == "creating":
        # No SSH endpoint yet; tell the user when it will appear.
        out += ["", "SSH: available once workspace is running"]

    return "\n".join(out)
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
def _handle_sync_event(sync_type: str) -> None:
|
|
625
|
+
"""Handle sync events and print status to stderr.
|
|
626
|
+
|
|
627
|
+
Sync events:
|
|
628
|
+
- FORWARD:START - Starting workspace → GPU sync
|
|
629
|
+
- FORWARD:DONE:N - Synced N files to GPU
|
|
630
|
+
- FORWARD:WARN:msg - Warning during forward sync
|
|
631
|
+
- REVERSE:START - Starting GPU → workspace sync
|
|
632
|
+
- REVERSE:DONE:N - Synced N artifacts back
|
|
633
|
+
"""
|
|
634
|
+
import sys
|
|
635
|
+
|
|
636
|
+
if sync_type == "FORWARD:START":
|
|
637
|
+
print("[sync] Syncing workspace → GPU...", end="", file=sys.stderr, flush=True)
|
|
638
|
+
elif sync_type.startswith("FORWARD:DONE:"):
|
|
639
|
+
count = sync_type.split(":")[-1]
|
|
640
|
+
print(f" done ({count} files)", file=sys.stderr)
|
|
641
|
+
elif sync_type.startswith("FORWARD:WARN:"):
|
|
642
|
+
msg = sync_type[13:] # Remove "FORWARD:WARN:"
|
|
643
|
+
print(f" warning: {msg}", file=sys.stderr)
|
|
644
|
+
elif sync_type == "REVERSE:START":
|
|
645
|
+
print("[sync] Syncing artifacts back...", end="", file=sys.stderr, flush=True)
|
|
646
|
+
elif sync_type.startswith("REVERSE:DONE:"):
|
|
647
|
+
count = sync_type.split(":")[-1]
|
|
648
|
+
print(f" done ({count} files)", file=sys.stderr)
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
@dataclass(frozen=True)
|
|
652
|
+
class SSEEvent:
|
|
653
|
+
"""Parsed SSE event result."""
|
|
654
|
+
|
|
655
|
+
output: str | None # Content to print (None = no output)
|
|
656
|
+
exit_code: int | None # Exit code if stream should end (None = continue)
|
|
657
|
+
is_error: bool # Whether output goes to stderr
|
|
658
|
+
sync_event: str | None = None # Sync event type (e.g., "FORWARD:START")
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
def _parse_sse_content(content: str) -> SSEEvent:
|
|
662
|
+
"""Parse SSE content into structured event.
|
|
663
|
+
|
|
664
|
+
Pure function: content in, event out. No side effects.
|
|
665
|
+
"""
|
|
666
|
+
if content == "[DONE]":
|
|
667
|
+
return SSEEvent(output=None, exit_code=0, is_error=False)
|
|
668
|
+
|
|
669
|
+
if content.startswith("[EXIT:"):
|
|
670
|
+
# Parse exit code: [EXIT:0] or [EXIT:1]
|
|
671
|
+
try:
|
|
672
|
+
code = int(content[6:-1])
|
|
673
|
+
except ValueError:
|
|
674
|
+
code = 0
|
|
675
|
+
return SSEEvent(output=None, exit_code=code, is_error=False)
|
|
676
|
+
|
|
677
|
+
if content.startswith("[ERROR]"):
|
|
678
|
+
return SSEEvent(output=content[8:], exit_code=1, is_error=True)
|
|
679
|
+
|
|
680
|
+
# Sync events: [SYNC:FORWARD:START], [SYNC:FORWARD:DONE:5], etc.
|
|
681
|
+
if content.startswith("[SYNC:"):
|
|
682
|
+
# Extract sync type (e.g., "FORWARD:START" or "REVERSE:DONE:5")
|
|
683
|
+
sync_type = content[6:-1] # Remove [SYNC: and ]
|
|
684
|
+
return SSEEvent(output=None, exit_code=None, is_error=False, sync_event=sync_type)
|
|
685
|
+
|
|
686
|
+
# Status events we can ignore (already handled elsewhere)
|
|
687
|
+
if content.startswith("[STATUS:") or content.startswith("[CONTEXT:"):
|
|
688
|
+
return SSEEvent(output=None, exit_code=None, is_error=False)
|
|
689
|
+
|
|
690
|
+
# Regular output
|
|
691
|
+
return SSEEvent(output=content, exit_code=None, is_error=False)
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def exec_command(
    workspace_id: str,
    command: str,
    timeout_seconds: int | None = None,
    routing: str | None = None,
    pull_image: bool = False,
) -> int:
    """Run a command inside a workspace, streaming its output to the terminal.

    The command is base64-encoded and POSTed to the workspace exec endpoint;
    the response is consumed as an SSE stream. Regular output lines go to
    stdout, error lines to stderr, and sync events are reported via
    ``_handle_sync_event``.

    Args:
        workspace_id: Workspace ID or name.
        command: Shell command to run remotely.
        timeout_seconds: Optional execution timeout (server default applies
            when omitted).
        routing: Optional routing hint - "auto", "gpu", "cpu", or "baremetal".
        pull_image: Whether to force-pull the container image first.

    Returns:
        The remote command's exit code (0 on success).

    Raises:
        RuntimeError: On a non-200 response or a network failure.
    """
    import base64
    import sys

    assert workspace_id, "Workspace ID must be non-empty"
    assert command, "Command must be non-empty"

    api_url, headers = _get_client()

    # Base64-encode the command so shell metacharacters survive transport.
    encoded_cmd = base64.b64encode(command.encode("utf-8")).decode("utf-8")

    payload: dict = {"command_b64": encoded_cmd, "pull_image": pull_image}
    if timeout_seconds:
        payload["timeout_seconds"] = timeout_seconds
    if routing:
        # Routing hint rides along inside the "requirements" object.
        payload["requirements"] = {"routing": routing}

    try:
        # No client-level timeout: the SSE stream stays open for the whole run.
        with httpx.Client(timeout=None, headers=headers) as http, http.stream(
            "POST",
            f"{api_url}/v1/workspaces/{workspace_id}/exec",
            json=payload,
        ) as resp:
            if resp.status_code != 200:
                # Drain the error body so we can surface a friendly message.
                body_text = resp.read().decode("utf-8", errors="replace")
                raise RuntimeError(
                    _friendly_error(resp.status_code, body_text, workspace_id)
                )

            result_code = 0
            for raw_line in resp.iter_lines():
                # Only SSE data frames carry content.
                if not raw_line or not raw_line.startswith("data: "):
                    continue

                evt = _parse_sse_content(raw_line[6:])

                # Sync progress events are informational; report and move on.
                if evt.sync_event:
                    _handle_sync_event(evt.sync_event)
                    continue

                if evt.output is not None:
                    stream = sys.stderr if evt.is_error else sys.stdout
                    print(evt.output, file=stream)

                if evt.exit_code is not None:
                    result_code = evt.exit_code
                    break

            return result_code

    except httpx.HTTPStatusError as e:
        raise RuntimeError(
            _friendly_error(e.response.status_code, e.response.text, workspace_id)
        ) from e
    except httpx.RequestError as e:
        raise RuntimeError(f"Could not reach API: {e}") from e
|
|
774
|
+
|
|
775
|
+
def exec_command_capture(
    workspace_id: str,
    command: str,
    timeout_seconds: int | None = None,
    routing: str | None = None,
    pull_image: bool = False,
) -> tuple[int, str]:
    """Run a command inside a workspace and return its output as a string.

    Same transport as ``exec_command`` (base64-encoded command over an SSE
    stream), but output lines are accumulated and returned rather than
    printed; sync events are silently skipped.

    Args:
        workspace_id: Workspace ID or name.
        command: Shell command to run remotely.
        timeout_seconds: Optional execution timeout (server default applies
            when omitted).
        routing: Optional routing hint - "auto", "gpu", "cpu", or "baremetal".
        pull_image: Whether to force-pull the container image first.

    Returns:
        A ``(exit_code, output)`` tuple; output lines are newline-joined.

    Raises:
        RuntimeError: On a non-200 response or a network failure.
    """
    import base64

    assert workspace_id, "Workspace ID must be non-empty"
    assert command, "Command must be non-empty"

    api_url, headers = _get_client()

    # Base64-encode the command so shell metacharacters survive transport.
    encoded_cmd = base64.b64encode(command.encode("utf-8")).decode("utf-8")

    payload: dict = {"command_b64": encoded_cmd, "pull_image": pull_image}
    if timeout_seconds:
        payload["timeout_seconds"] = timeout_seconds
    if routing:
        # Routing hint rides along inside the "requirements" object.
        payload["requirements"] = {"routing": routing}

    captured: list[str] = []

    try:
        # No client-level timeout: the SSE stream stays open for the whole run.
        with httpx.Client(timeout=None, headers=headers) as http, http.stream(
            "POST",
            f"{api_url}/v1/workspaces/{workspace_id}/exec",
            json=payload,
        ) as resp:
            if resp.status_code != 200:
                # Drain the error body so we can surface a friendly message.
                body_text = resp.read().decode("utf-8", errors="replace")
                raise RuntimeError(
                    _friendly_error(resp.status_code, body_text, workspace_id)
                )

            result_code = 0
            for raw_line in resp.iter_lines():
                # Only SSE data frames carry content.
                if not raw_line or not raw_line.startswith("data: "):
                    continue

                evt = _parse_sse_content(raw_line[6:])

                # Sync events are not part of the command's output.
                if evt.sync_event:
                    continue

                if evt.output is not None:
                    captured.append(evt.output)

                if evt.exit_code is not None:
                    result_code = evt.exit_code
                    break

            return result_code, "\n".join(captured)

    except httpx.HTTPStatusError as e:
        raise RuntimeError(
            _friendly_error(e.response.status_code, e.response.text, workspace_id)
        ) from e
    except httpx.RequestError as e:
        raise RuntimeError(f"Could not reach API: {e}") from e