wafer-cli 0.2.8-py3-none-any.whl → 0.2.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +18 -7
- wafer/api_client.py +4 -0
- wafer/auth.py +85 -0
- wafer/cli.py +2339 -404
- wafer/corpus.py +158 -32
- wafer/evaluate.py +1232 -201
- wafer/gpu_run.py +5 -1
- wafer/kernel_scope.py +554 -0
- wafer/nsys_analyze.py +903 -73
- wafer/nsys_profile.py +511 -0
- wafer/output.py +241 -0
- wafer/problems.py +357 -0
- wafer/skills/wafer-guide/SKILL.md +13 -0
- wafer/ssh_keys.py +261 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +490 -0
- wafer/targets_ops.py +718 -0
- wafer/wevin_cli.py +129 -18
- wafer/workspaces.py +282 -182
- {wafer_cli-0.2.8.dist-info → wafer_cli-0.2.10.dist-info}/METADATA +1 -1
- wafer_cli-0.2.10.dist-info/RECORD +40 -0
- wafer_cli-0.2.8.dist-info/RECORD +0 -33
- {wafer_cli-0.2.8.dist-info → wafer_cli-0.2.10.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.8.dist-info → wafer_cli-0.2.10.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.8.dist-info → wafer_cli-0.2.10.dist-info}/top_level.txt +0 -0
wafer/workspaces.py
CHANGED
@@ -13,15 +13,7 @@ import httpx
 from .api_client import get_api_url
 from .auth import get_auth_headers
 
-
-@dataclass(frozen=True)
-class SSHCredentials:
-    """SSH credentials for workspace access."""
-
-    host: str
-    port: int
-    user: str
-    key_path: Path
+VALID_STATUSES = {"creating", "running"}
 
 
 def _get_client() -> tuple[str, dict[str, str]]:
@@ -72,10 +64,11 @@ def _friendly_error(status_code: int, response_text: str, workspace_id: str) ->
 
     # Parse common error details from response
     detail = ""
-    if "
+    if "not running" in response_text.lower() or "not found" in response_text.lower():
         return (
-            f"Workspace '{workspace_id}'
-            "
+            f"Workspace '{workspace_id}' not found or not running.\n"
+            " Check status: wafer workspaces list\n"
+            " Create new: wafer workspaces create <name>"
         )
 
     if "timeout" in response_text.lower():
@@ -85,6 +78,12 @@ def _friendly_error(status_code: int, response_text: str, workspace_id: str) ->
             " Or set default: wafer config set defaults.exec_timeout 600"
         )
 
+    if "creating" in response_text.lower():
+        return (
+            f"Workspace '{workspace_id}' is still creating.\n"
+            " Check status: wafer workspaces list"
+        )
+
     # Generic error with response detail
     try:
         import json
@@ -114,6 +113,14 @@ def _list_workspaces_raw() -> list[dict]:
         raise RuntimeError(f"Could not reach API: {e}") from e
 
     assert isinstance(workspaces, list), "API must return a list of workspaces"
+
+    for ws in workspaces:
+        status = ws.get("status", "unknown")
+        assert status in VALID_STATUSES or status == "unknown", (
+            f"Workspace {ws.get('id', 'unknown')} has invalid status '{status}'. "
+            f"Valid statuses: {VALID_STATUSES}"
+        )
+
     return workspaces
 
 
@@ -186,9 +193,15 @@ def list_workspaces(json_output: bool = False) -> str:
     except httpx.RequestError as e:
         raise RuntimeError(f"Could not reach API: {e}") from e
 
-    # Validate API response shape
     assert isinstance(workspaces, list), "API must return a list of workspaces"
 
+    for ws in workspaces:
+        status = ws.get("status", "unknown")
+        assert status in VALID_STATUSES or status == "unknown", (
+            f"Workspace {ws.get('id', 'unknown')} has invalid status '{status}'. "
+            f"Valid statuses: {VALID_STATUSES}"
+        )
+
     if json_output:
         return json.dumps(workspaces, indent=2)
 
@@ -198,10 +211,14 @@ def list_workspaces(json_output: bool = False) -> str:
     lines = ["Workspaces:", ""]
     for ws in workspaces:
         status = ws.get("status", "unknown")
-        status_icon = {"running": "●", "
+        status_icon = {"running": "●", "creating": "◐"}.get(status, "?")
         lines.append(f" {status_icon} {ws['name']} ({ws['id']})")
         lines.append(f" GPU: {ws.get('gpu_type', 'N/A')} | Image: {ws.get('image', 'N/A')}")
         lines.append(f" Status: {status} | Created: {ws.get('created_at', 'N/A')}")
+        if status == "running" and ws.get("ssh_host") and ws.get("ssh_port") and ws.get("ssh_user"):
+            lines.append(
+                f" SSH: ssh -p {ws['ssh_port']} {ws['ssh_user']}@{ws['ssh_host']}"
+            )
         lines.append("")
 
     return "\n".join(lines)
@@ -211,6 +228,7 @@ def create_workspace(
     name: str,
     gpu_type: str = "B200",
     image: str | None = None,
+    wait: bool = False,
     json_output: bool = False,
 ) -> str:
     """Create a new workspace.
@@ -219,6 +237,7 @@ def create_workspace(
         name: Workspace name (must be unique)
         gpu_type: GPU type (default: B200)
         image: Docker image (optional, uses default if not specified)
+        wait: If True, stream provisioning progress and return SSH credentials
        json_output: If True, return raw JSON; otherwise return formatted text
 
     Returns:
@@ -276,190 +295,114 @@ def create_workspace(
     assert "id" in workspace, "API response must contain workspace id"
     assert "name" in workspace, "API response must contain workspace name"
 
+    if wait:
+        ssh_info = _wait_for_provisioning(workspace["id"])
+        if json_output:
+            payload = {
+                "workspace_id": workspace["id"],
+                "ssh_host": ssh_info["ssh_host"],
+                "ssh_port": ssh_info["ssh_port"],
+                "ssh_user": ssh_info["ssh_user"],
+            }
+            return json.dumps(payload, indent=2)
+        return (
+            f"Workspace ready: {workspace['name']} ({workspace['id']})\n"
+            f"SSH: ssh -p {ssh_info['ssh_port']} {ssh_info['ssh_user']}@{ssh_info['ssh_host']}"
+        )
+
     if json_output:
         return json.dumps(workspace, indent=2)
 
-    return
-
+    return (
+        f"Creating workspace: {workspace['name']} ({workspace['id']})\n"
+        "Check status with: wafer workspaces list\n"
+        "Estimated time: ~30 seconds"
+    )
 
-def delete_workspace(workspace_id: str, json_output: bool = False) -> str:
-    """Delete a workspace.
 
-
-
-
+def _wait_for_provisioning(workspace_id: str) -> dict[str, str | int]:
+    """Wait for workspace provisioning to complete via SSE."""
+    import sys
 
-    Returns:
-        Deletion status as string
-    """
     assert workspace_id, "Workspace ID must be non-empty"
-
     api_url, headers = _get_client()
 
     try:
-        with httpx.Client(timeout=
-
-
-
-
-
-
-
-
-
-    except httpx.RequestError as e:
-        raise RuntimeError(f"Could not reach API: {e}") from e
+        with httpx.Client(timeout=None, headers=headers) as client:
+            with client.stream(
+                "POST",
+                f"{api_url}/v1/workspaces/{workspace_id}/provision-stream",
+            ) as response:
+                if response.status_code != 200:
+                    error_body = response.read().decode("utf-8", errors="replace")
+                    raise RuntimeError(
+                        _friendly_error(response.status_code, error_body, workspace_id)
+                    )
 
-
-
+                ssh_info: dict[str, str | int] | None = None
+                for line in response.iter_lines():
+                    if not line or not line.startswith("data: "):
+                        continue
+                    content = line[6:]
+                    if content.startswith("[STATUS:"):
+                        status = content[8:-1]
+                        print(f"[wafer] {status.lower()}...", file=sys.stderr)
+                        if status == "ERROR":
+                            raise RuntimeError(
+                                "Workspace provisioning failed. Check status with: wafer workspaces list"
+                            )
+                    elif content.startswith("[SSH:"):
+                        parts = content[5:-1].split(":")
+                        if len(parts) != 3:
+                            raise RuntimeError("Malformed SSH info in provisioning stream")
+                        ssh_info = {
+                            "ssh_host": parts[0],
+                            "ssh_port": int(parts[1]),
+                            "ssh_user": parts[2],
+                        }
+                        break
 
-
+                if ssh_info is None:
+                    raise RuntimeError("Provisioning did not return SSH credentials")
+                return ssh_info
+    except httpx.RequestError as e:
+        raise RuntimeError(f"Could not reach API: {e}") from e
 
 
-def
-    """
+def delete_workspace(workspace_id: str, json_output: bool = False) -> str:
+    """Delete a workspace.
 
     Args:
-        workspace_id: Workspace ID to
+        workspace_id: Workspace ID to delete
         json_output: If True, return raw JSON; otherwise return formatted text
 
     Returns:
-
+        Deletion status as string
     """
     assert workspace_id, "Workspace ID must be non-empty"
 
     api_url, headers = _get_client()
 
     try:
-        with httpx.Client(timeout=
-            response = client.
+        with httpx.Client(timeout=30.0, headers=headers) as client:
+            response = client.delete(f"{api_url}/v1/workspaces/{workspace_id}")
             response.raise_for_status()
-
+            result = response.json()
     except httpx.HTTPStatusError as e:
         if e.response.status_code == 401:
             raise RuntimeError("Not authenticated. Run: wafer login") from e
         if e.response.status_code == 404:
             raise RuntimeError(f"Workspace not found: {workspace_id}") from e
-        if e.response.status_code == 503:
-            raise RuntimeError("No GPU available. Please try again later.") from e
         raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
     except httpx.RequestError as e:
         raise RuntimeError(f"Could not reach API: {e}") from e
 
-    # Validate API response has required SSH fields
-    assert "ssh_host" in attach_info, "API response must contain ssh_host"
-    assert "ssh_port" in attach_info, "API response must contain ssh_port"
-    assert "ssh_user" in attach_info, "API response must contain ssh_user"
-    assert "private_key_pem" in attach_info, "API response must contain private_key_pem"
-
     if json_output:
-        return json.dumps(
-
-    # Write private key to temp file and generate SSH config
-    ssh_host = attach_info["ssh_host"]
-    ssh_port = attach_info["ssh_port"]
-    ssh_user = attach_info["ssh_user"]
-    private_key = attach_info["private_key_pem"]
-
-    # Validate field values before using them
-    assert ssh_host, "ssh_host must be non-empty"
-    assert isinstance(ssh_port, int), "ssh_port must be an integer"
-    assert ssh_port > 0, "ssh_port must be positive"
-    assert ssh_user, "ssh_user must be non-empty"
-    assert private_key, "private_key_pem must be non-empty"
-
-    # Save private key
-    key_dir = Path.home() / ".wafer" / "keys"
-    key_dir.mkdir(parents=True, exist_ok=True)
-    key_path = key_dir / f"{workspace_id}.pem"
-    key_path.write_text(private_key)
-    key_path.chmod(0o600)
-
-    # Generate SSH config entry
-    config_entry = f"""
-# Wafer workspace: {workspace_id}
-Host wafer-{workspace_id}
-HostName {ssh_host}
-Port {ssh_port}
-User {ssh_user}
-IdentityFile {key_path}
-StrictHostKeyChecking no
-UserKnownHostsFile /dev/null
-"""
-
-    lines = [
-        f"Attached to workspace: {workspace_id}",
-        "",
-        "SSH Connection:",
-        f" ssh -i {key_path} -p {ssh_port} {ssh_user}@{ssh_host}",
-        "",
-        "Or add to ~/.ssh/config:",
-        config_entry,
-        f"Then connect with: ssh wafer-{workspace_id}",
-    ]
-
-    return "\n".join(lines)
-
-
-def get_ssh_credentials(workspace_id: str) -> tuple[SSHCredentials, str]:
-    """Get SSH credentials for a workspace.
-
-    Calls attach API, saves key file, returns credentials for SSH.
+        return json.dumps(result, indent=2)
 
-
-        workspace_id: Workspace ID or name
+    return f"Deleted workspace: {workspace_id}"
 
-    Returns:
-        Tuple of (SSHCredentials, resolved_workspace_id)
 
-    Raises:
-        RuntimeError: If attach fails
-    """
-    assert workspace_id, "Workspace ID must be non-empty"
-
-    api_url, headers = _get_client()
-
-    try:
-        with httpx.Client(timeout=120.0, headers=headers) as client:
-            response = client.post(f"{api_url}/v1/workspaces/{workspace_id}/attach")
-            response.raise_for_status()
-            attach_info = response.json()
-    except httpx.HTTPStatusError as e:
-        if e.response.status_code == 401:
-            raise RuntimeError("Not authenticated. Run: wafer login") from e
-        if e.response.status_code == 404:
-            raise RuntimeError(f"Workspace not found: {workspace_id}") from e
-        if e.response.status_code == 503:
-            raise RuntimeError("No GPU available. Please try again later.") from e
-        raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
-    except httpx.RequestError as e:
-        raise RuntimeError(f"Could not reach API: {e}") from e
-
-    # Validate and extract fields
-    resolved_id = attach_info.get("workspace_id", workspace_id)
-    ssh_host = attach_info.get("ssh_host")
-    ssh_port = attach_info.get("ssh_port")
-    ssh_user = attach_info.get("ssh_user")
-    private_key = attach_info.get("private_key_pem")
-
-    assert ssh_host, "API response must contain ssh_host"
-    assert isinstance(ssh_port, int) and ssh_port > 0, "ssh_port must be positive integer"
-    assert ssh_user, "API response must contain ssh_user"
-    assert private_key, "API response must contain private_key_pem"
-
-    # Save private key using resolved ID
-    key_dir = Path.home() / ".wafer" / "keys"
-    key_dir.mkdir(parents=True, exist_ok=True)
-    key_path = key_dir / f"{resolved_id}.pem"
-    key_path.write_text(private_key)
-    key_path.chmod(0o600)
-
-    return SSHCredentials(
-        host=ssh_host,
-        port=ssh_port,
-        user=ssh_user,
-        key_path=key_path,
-    ), resolved_id
 
 
 def sync_files(
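The new `wait` flow consumes a server-sent-event stream from the provision endpoint. A minimal sketch of the data lines `_wait_for_provisioning` expects and how a caller would use the flag (the hostname, port, and output below are illustrative, not taken from the package):

```python
# Illustrative SSE payloads parsed by _wait_for_provisioning (values are made up):
#   data: [STATUS:PROVISIONING]            -> progress line echoed to stderr
#   data: [STATUS:ERROR]                   -> raises RuntimeError
#   data: [SSH:gpu.example.com:2222:root]  -> split into ssh_host / ssh_port / ssh_user
#
# Caller side, using the new wait flag:
from wafer.workspaces import create_workspace

print(create_workspace("demo", gpu_type="B200", wait=True))
# Workspace ready: demo (<workspace id>)
# SSH: ssh -p 2222 root@gpu.example.com
```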
@@ -492,9 +435,23 @@ def sync_files(
     assert workspace_id, "Workspace ID must be non-empty"
     assert local_path.exists(), f"Path not found: {local_path}"
 
-
-
-
+    ws = get_workspace_raw(workspace_id)
+    resolved_id = ws["id"]
+    workspace_status = ws.get("status")
+    assert workspace_status in VALID_STATUSES, (
+        f"Workspace {workspace_id} has invalid status '{workspace_status}'. "
+        f"Valid statuses: {VALID_STATUSES}"
+    )
+    if workspace_status != "running":
+        raise RuntimeError(
+            f"Workspace is {workspace_status}. Wait for it to be running before syncing."
+        )
+    ssh_host = ws.get("ssh_host")
+    ssh_port = ws.get("ssh_port")
+    ssh_user = ws.get("ssh_user")
+    assert ssh_host, "Workspace missing ssh_host"
+    assert isinstance(ssh_port, int) and ssh_port > 0, "Workspace missing valid ssh_port"
+    assert ssh_user, "Workspace missing ssh_user"
 
     # Build rsync command
     # -a: archive mode (preserves permissions, etc.)
@@ -507,13 +464,17 @@ def sync_files(
         # Single file: sync the file itself
         source = str(local_path)
 
+    # Build SSH command for rsync
+    # If key_path is None (BYOK model), SSH will use default key from ~/.ssh/
+    ssh_opts = f"-p {ssh_port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
+
     rsync_cmd = [
         "rsync",
         "-avz",
         "-e",
-        f"ssh
+        f"ssh {ssh_opts}",
         source,
-        f"{
+        f"{ssh_user}@{ssh_host}:/workspace/",
     ]
 
     try:
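For illustration, with assumed values (`ssh_host="gpu.example.com"`, `ssh_port=2222`, `ssh_user="root"`; none of these come from the diff), the command `sync_files` assembles for a directory upload is equivalent to:

```python
# Hypothetical values, for illustration only
ssh_host, ssh_port, ssh_user = "gpu.example.com", 2222, "root"
ssh_opts = f"-p {ssh_port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
rsync_cmd = ["rsync", "-avz", "-e", f"ssh {ssh_opts}", "./my_kernel/", f"{ssh_user}@{ssh_host}:/workspace/"]
print(" ".join(rsync_cmd))
# rsync -avz -e ssh -p 2222 -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ./my_kernel/ root@gpu.example.com:/workspace/
```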
@@ -589,16 +550,8 @@ def _init_sync_state(workspace_id: str) -> str | None:
     return None
 
 
-def
-    """Get details
-
-    Args:
-        workspace_id: Workspace ID to get
-        json_output: If True, return raw JSON; otherwise return formatted text
-
-    Returns:
-        Workspace details as string
-    """
+def get_workspace_raw(workspace_id: str) -> dict:
+    """Get workspace details as raw JSON dict."""
     assert workspace_id, "Workspace ID must be non-empty"
 
     api_url, headers = _get_client()
@@ -617,9 +570,29 @@ def get_workspace(workspace_id: str, json_output: bool = False) -> str:
     except httpx.RequestError as e:
         raise RuntimeError(f"Could not reach API: {e}") from e
 
-    # Validate API response has required fields
     assert "id" in workspace, "API response must contain workspace id"
     assert "name" in workspace, "API response must contain workspace name"
+
+    status = workspace.get("status", "unknown")
+    assert status in VALID_STATUSES or status == "unknown", (
+        f"Workspace {workspace['id']} has invalid status '{status}'. "
+        f"Valid statuses: {VALID_STATUSES}"
+    )
+
+    return workspace
+
+
+def get_workspace(workspace_id: str, json_output: bool = False) -> str:
+    """Get details of a specific workspace.
+
+    Args:
+        workspace_id: Workspace ID to get
+        json_output: If True, return raw JSON; otherwise return formatted text
+
+    Returns:
+        Workspace details as string
+    """
+    workspace = get_workspace_raw(workspace_id)
 
     if json_output:
         return json.dumps(workspace, indent=2)
@@ -643,15 +616,39 @@ def get_workspace(workspace_id: str, json_output: bool = False) -> str:
             f" Port: {workspace.get('ssh_port', 22)}",
             f" User: {workspace.get('ssh_user', 'root')}",
         ])
-    elif status
-        lines.extend([
-            "",
-            "SSH: Run 'wafer workspaces attach' to get SSH credentials",
-        ])
+    elif status == "creating":
+        lines.extend(["", "SSH: available once workspace is running"])
 
     return "\n".join(lines)
 
 
+def _handle_sync_event(sync_type: str) -> None:
+    """Handle sync events and print status to stderr.
+
+    Sync events:
+    - FORWARD:START - Starting workspace → GPU sync
+    - FORWARD:DONE:N - Synced N files to GPU
+    - FORWARD:WARN:msg - Warning during forward sync
+    - REVERSE:START - Starting GPU → workspace sync
+    - REVERSE:DONE:N - Synced N artifacts back
+    """
+    import sys
+
+    if sync_type == "FORWARD:START":
+        print("[sync] Syncing workspace → GPU...", end="", file=sys.stderr, flush=True)
+    elif sync_type.startswith("FORWARD:DONE:"):
+        count = sync_type.split(":")[-1]
+        print(f" done ({count} files)", file=sys.stderr)
+    elif sync_type.startswith("FORWARD:WARN:"):
+        msg = sync_type[13:]  # Remove "FORWARD:WARN:"
+        print(f" warning: {msg}", file=sys.stderr)
+    elif sync_type == "REVERSE:START":
+        print("[sync] Syncing artifacts back...", end="", file=sys.stderr, flush=True)
+    elif sync_type.startswith("REVERSE:DONE:"):
+        count = sync_type.split(":")[-1]
+        print(f" done ({count} files)", file=sys.stderr)
+
+
 @dataclass(frozen=True)
 class SSEEvent:
     """Parsed SSE event result."""
@@ -659,6 +656,7 @@ class SSEEvent:
     output: str | None  # Content to print (None = no output)
     exit_code: int | None  # Exit code if stream should end (None = continue)
     is_error: bool  # Whether output goes to stderr
+    sync_event: str | None = None  # Sync event type (e.g., "FORWARD:START")
 
 
 def _parse_sse_content(content: str) -> SSEEvent:
@@ -680,6 +678,16 @@ def _parse_sse_content(content: str) -> SSEEvent:
     if content.startswith("[ERROR]"):
         return SSEEvent(output=content[8:], exit_code=1, is_error=True)
 
+    # Sync events: [SYNC:FORWARD:START], [SYNC:FORWARD:DONE:5], etc.
+    if content.startswith("[SYNC:"):
+        # Extract sync type (e.g., "FORWARD:START" or "REVERSE:DONE:5")
+        sync_type = content[6:-1]  # Remove [SYNC: and ]
+        return SSEEvent(output=None, exit_code=None, is_error=False, sync_event=sync_type)
+
+    # Status events we can ignore (already handled elsewhere)
+    if content.startswith("[STATUS:") or content.startswith("[CONTEXT:"):
+        return SSEEvent(output=None, exit_code=None, is_error=False)
+
     # Regular output
     return SSEEvent(output=content, exit_code=None, is_error=False)
 
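A small worked example of how the new `[SYNC:...]` branch classifies SSE payloads (the sample strings are illustrative; prefixes handled by parts of `_parse_sse_content` not shown in this hunk, such as exit markers, are omitted):

```python
from wafer.workspaces import _parse_sse_content

for payload in ("[SYNC:FORWARD:START]", "[SYNC:REVERSE:DONE:3]", "[STATUS:RUNNING]", "hello from the GPU"):
    event = _parse_sse_content(payload)
    print(repr(payload), "->", event.sync_event, repr(event.output))
# Expected, per the handling added above:
# '[SYNC:FORWARD:START]'  -> FORWARD:START None
# '[SYNC:REVERSE:DONE:3]' -> REVERSE:DONE:3 None
# '[STATUS:RUNNING]'      -> None None
# 'hello from the GPU'    -> None 'hello from the GPU'
```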
@@ -688,13 +696,16 @@ def exec_command(
     workspace_id: str,
     command: str,
     timeout_seconds: int | None = None,
+    routing: str | None = None,
+    pull_image: bool = False,
 ) -> int:
-    """Execute a command in workspace
+    """Execute a command in workspace, streaming output.
 
     Args:
         workspace_id: Workspace ID or name
         command: Command to execute
         timeout_seconds: Execution timeout (default: 300, from config)
+        routing: Routing hint - "auto", "gpu", "cpu", or "baremetal" (default: auto)
 
     Returns:
         Exit code (0 = success, non-zero = failure)
@@ -710,10 +721,14 @@
     # Base64 encode command to avoid escaping issues
     command_b64 = base64.b64encode(command.encode("utf-8")).decode("utf-8")
 
-    request_body: dict = {"command_b64": command_b64}
+    request_body: dict = {"command_b64": command_b64, "pull_image": pull_image}
     if timeout_seconds:
         request_body["timeout_seconds"] = timeout_seconds
 
+    # Add routing hint if specified
+    if routing:
+        request_body["requirements"] = {"routing": routing}
+
     try:
         # Use streaming request for SSE output
         with httpx.Client(timeout=None, headers=headers) as client:
@@ -736,6 +751,11 @@
 
                     event = _parse_sse_content(line[6:])
 
+                    # Handle sync events - display status to stderr
+                    if event.sync_event:
+                        _handle_sync_event(event.sync_event)
+                        continue
+
                     if event.output is not None:
                         print(event.output, file=sys.stderr if event.is_error else sys.stdout)
 
@@ -751,3 +771,83 @@
         ) from e
     except httpx.RequestError as e:
         raise RuntimeError(f"Could not reach API: {e}") from e
+
+
+def exec_command_capture(
+    workspace_id: str,
+    command: str,
+    timeout_seconds: int | None = None,
+    routing: str | None = None,
+    pull_image: bool = False,
+) -> tuple[int, str]:
+    """Execute a command in workspace and capture output.
+
+    Similar to exec_command but returns output as string instead of printing.
+
+    Args:
+        workspace_id: Workspace ID or name
+        command: Command to execute
+        timeout_seconds: Execution timeout (default: 300)
+        routing: Routing hint - "auto", "gpu", "cpu", or "baremetal"
+
+    Returns:
+        Tuple of (exit_code, output_string)
+    """
+    import base64
+
+    assert workspace_id, "Workspace ID must be non-empty"
+    assert command, "Command must be non-empty"
+
+    api_url, headers = _get_client()
+
+    # Base64 encode command to avoid escaping issues
+    command_b64 = base64.b64encode(command.encode("utf-8")).decode("utf-8")
+
+    request_body: dict = {"command_b64": command_b64, "pull_image": pull_image}
+    if timeout_seconds:
+        request_body["timeout_seconds"] = timeout_seconds
+
+    if routing:
+        request_body["requirements"] = {"routing": routing}
+
+    output_lines: list[str] = []
+
+    try:
+        with httpx.Client(timeout=None, headers=headers) as client:
+            with client.stream(
+                "POST",
+                f"{api_url}/v1/workspaces/{workspace_id}/exec",
+                json=request_body,
+            ) as response:
+                if response.status_code != 200:
+                    error_body = response.read().decode("utf-8", errors="replace")
+                    raise RuntimeError(
+                        _friendly_error(response.status_code, error_body, workspace_id)
+                    )
+
+                exit_code = 0
+                for line in response.iter_lines():
+                    if not line or not line.startswith("data: "):
+                        continue
+
+                    event = _parse_sse_content(line[6:])
+
+                    # Skip sync events
+                    if event.sync_event:
+                        continue
+
+                    if event.output is not None:
+                        output_lines.append(event.output)
+
+                    if event.exit_code is not None:
+                        exit_code = event.exit_code
+                        break
+
+                return exit_code, "\n".join(output_lines)
+
+    except httpx.HTTPStatusError as e:
+        raise RuntimeError(
+            _friendly_error(e.response.status_code, e.response.text, workspace_id)
+        ) from e
+    except httpx.RequestError as e:
+        raise RuntimeError(f"Could not reach API: {e}") from e