wafer-cli 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/workspaces.py ADDED
@@ -0,0 +1,852 @@
1
+ """Workspaces CLI - Manage remote GPU workspaces.
2
+
3
+ This module provides the implementation for the `wafer workspaces` subcommand.
4
+ """
5
+
6
+ import json
7
+ from collections.abc import Callable
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+
11
+ import httpx
12
+
13
+ from .api_client import get_api_url
14
+ from .auth import get_auth_headers
15
+
16
# Workspace statuses this client accepts. The response-validation asserts in
# this module require any status reported by the API to be one of these
# (a missing status is treated as "unknown" and tolerated).
VALID_STATUSES = {"creating", "running"}
17
+
18
+
19
def _get_client() -> tuple[str, dict[str, str]]:
    """Return the ``(api_url, auth_headers)`` pair used for API requests.

    Returns:
        Tuple of the configured API base URL and the auth headers dict.

    Raises:
        AssertionError: If the API URL is empty or does not start with "http".
    """
    base_url, auth_headers = get_api_url(), get_auth_headers()

    # Sanity-check the configured endpoint before any request is attempted.
    assert base_url, "API URL must be configured"
    assert base_url.startswith("http"), "API URL must be a valid HTTP(S) URL"

    return base_url, auth_headers
28
+
29
+
30
+ def _friendly_error(status_code: int, response_text: str, workspace_id: str) -> str:
31
+ """Convert API errors to friendly messages with guidance.
32
+
33
+ Args:
34
+ status_code: HTTP status code
35
+ response_text: Response body
36
+ workspace_id: Workspace ID or name for context
37
+
38
+ Returns:
39
+ User-friendly error message with suggested next steps
40
+ """
41
+ if status_code == 401:
42
+ return "Not authenticated. Run: wafer login"
43
+
44
+ if status_code == 402:
45
+ return (
46
+ "Insufficient credits.\n"
47
+ " Check usage: wafer billing\n"
48
+ " Add credits: wafer billing topup"
49
+ )
50
+
51
+ if status_code == 404:
52
+ return (
53
+ f"Workspace '{workspace_id}' not found.\n"
54
+ " List workspaces: wafer workspaces list\n"
55
+ " Create one: wafer workspaces create <name>"
56
+ )
57
+
58
+ if status_code == 503:
59
+ return (
60
+ "No GPU available.\n"
61
+ " The workspace is queued for GPU access. Try again in a moment.\n"
62
+ " Check status: wafer workspaces show " + workspace_id
63
+ )
64
+
65
+ # Parse common error details from response
66
+ detail = ""
67
+ if "not running" in response_text.lower() or "not found" in response_text.lower():
68
+ return (
69
+ f"Workspace '{workspace_id}' not found or not running.\n"
70
+ " Check status: wafer workspaces list\n"
71
+ " Create new: wafer workspaces create <name>"
72
+ )
73
+
74
+ if "timeout" in response_text.lower():
75
+ return (
76
+ "Command timed out.\n"
77
+ ' Increase timeout: wafer workspaces exec <workspace> "cmd" --timeout 600\n'
78
+ " Or set default: wafer config set defaults.exec_timeout 600"
79
+ )
80
+
81
+ if "creating" in response_text.lower():
82
+ return (
83
+ f"Workspace '{workspace_id}' is still creating.\n"
84
+ " Check status: wafer workspaces list"
85
+ )
86
+
87
+ # Generic error with response detail
88
+ try:
89
+ import json
90
+
91
+ data = json.loads(response_text)
92
+ detail = data.get("detail", response_text)
93
+ except (json.JSONDecodeError, KeyError):
94
+ detail = response_text
95
+
96
+ return f"API error ({status_code}): {detail}"
97
+
98
+
99
def _list_workspaces_raw() -> list[dict]:
    """Fetch the user's workspaces and return the decoded JSON list.

    Internal helper: no formatting is applied to the result.

    Raises:
        RuntimeError: On a 401, any other HTTP error status, or a transport failure.
        AssertionError: If the payload is not a list or an entry reports a
            status outside VALID_STATUSES.
    """
    base_url, auth_headers = _get_client()

    try:
        with httpx.Client(timeout=30.0, headers=auth_headers) as http:
            reply = http.get(f"{base_url}/v1/workspaces")
            reply.raise_for_status()
            workspaces = reply.json()
    except httpx.HTTPStatusError as exc:
        code = exc.response.status_code
        if code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from exc
        raise RuntimeError(f"API error: {code} - {exc.response.text}") from exc
    except httpx.RequestError as exc:
        raise RuntimeError(f"Could not reach API: {exc}") from exc

    assert isinstance(workspaces, list), "API must return a list of workspaces"

    for entry in workspaces:
        reported = entry.get("status", "unknown")
        if reported == "unknown":
            # A missing/unknown status is tolerated.
            continue
        assert reported in VALID_STATUSES, (
            f"Workspace {entry.get('id', 'unknown')} has invalid status '{reported}'. "
            f"Valid statuses: {VALID_STATUSES}"
        )

    return workspaces
125
+
126
+
127
def resolve_workspace(specified: str | None) -> str:
    """Pick the workspace to operate on.

    Resolution order:
        1. An explicitly specified name/ID is returned untouched (the API
           resolves name vs ID).
        2. The ``defaults.workspace`` config value, when set.
        3. The ID of the user's only workspace, when exactly one exists.
        4. Otherwise an error explaining how to disambiguate.

    Args:
        specified: Workspace name or ID, or None to fall back to defaults

    Returns:
        Workspace name or ID to use

    Raises:
        RuntimeError: If no workspace can be resolved
    """
    from .global_config import get_defaults

    # Explicit choice wins.
    if specified:
        return specified

    # Configured default comes next.
    configured = get_defaults()
    if configured.workspace:
        return configured.workspace

    # Fall back to "the only workspace there is".
    available = _list_workspaces_raw()

    if not available:
        raise RuntimeError("No workspaces found. Create one with: wafer workspaces create <name>")

    if len(available) == 1:
        return available[0]["id"]

    # Ambiguous: list the candidates and tell the user how to choose.
    labels = ", ".join(ws.get("name", ws["id"]) for ws in available)
    raise RuntimeError(
        f"Multiple workspaces found: {labels}\n"
        "Specify one, or set default: wafer config set defaults.workspace <name>"
    )
171
+
172
+
173
def list_workspaces(json_output: bool = False) -> str:
    """List all workspaces for the current user.

    Fetching, error mapping and status validation are delegated to
    ``_list_workspaces_raw`` so both entry points share one code path.

    Args:
        json_output: If True, return raw JSON; otherwise return formatted text

    Returns:
        Workspaces list as string (JSON or formatted text)

    Raises:
        RuntimeError: On authentication, API or connectivity errors
    """
    workspaces = _list_workspaces_raw()

    if json_output:
        return json.dumps(workspaces, indent=2)

    if not workspaces:
        return "No workspaces found."

    # Built once; any status not listed here renders as "?".
    status_icons = {"running": "●", "creating": "◐"}

    lines = ["Workspaces:", ""]
    for ws in workspaces:
        status_icon = status_icons.get(ws.get("status", "unknown"), "?")
        lines.append(f" {status_icon} {ws['name']} ({ws['id']})")
        lines.append(f" GPU: {ws.get('gpu_type', 'N/A')} | Image: {ws.get('image', 'N/A')}")
        # Only show a ready-to-paste ssh command when all three parts are present.
        if ws.get("ssh_host") and ws.get("ssh_port") and ws.get("ssh_user"):
            lines.append(
                f" SSH: ssh -p {ws['ssh_port']} {ws['ssh_user']}@{ws['ssh_host']}"
            )
        else:
            lines.append(" SSH: Not ready (run: wafer workspaces ssh <name>)")
        lines.append("")

    return "\n".join(lines)
226
+
227
+
228
def create_workspace(
    name: str,
    gpu_type: str = "B200",
    image: str | None = None,
    wait: bool = False,
    json_output: bool = False,
) -> str:
    """Create a new workspace.

    Args:
        name: Workspace name (must be unique)
        gpu_type: GPU type (default: B200)
        image: Docker image (optional, omitted from the request if not specified)
        wait: If True, stream provisioning progress and return SSH credentials
        json_output: If True, return raw JSON; otherwise return formatted text

    Returns:
        Created workspace info as string

    Raises:
        RuntimeError: If name already exists or API error
    """
    # Validate inputs
    assert name, "Workspace name must be non-empty"
    assert gpu_type, "GPU type must be non-empty"

    api_url, headers = _get_client()

    # Best-effort duplicate-name pre-check. Any failure here (HTTP status,
    # transport, or an undecodable body) is deliberately ignored so that the
    # create request below surfaces the real error.
    try:
        with httpx.Client(timeout=30.0, headers=headers) as client:
            response = client.get(f"{api_url}/v1/workspaces")
            response.raise_for_status()
            existing = response.json()
    except (httpx.HTTPStatusError, httpx.RequestError, ValueError):
        existing = []

    # Tolerate unexpected payload shapes: only dict entries of a list count.
    if isinstance(existing, list) and any(
        isinstance(ws, dict) and ws.get("name") == name for ws in existing
    ):
        raise RuntimeError(
            f"Workspace '{name}' already exists.\n"
            " Use a different name, or delete the existing one:\n"
            f" wafer workspaces delete {name}"
        )

    request_body: dict = {
        "name": name,
        "gpu_type": gpu_type,
    }
    if image:
        request_body["image"] = image

    try:
        with httpx.Client(timeout=60.0, headers=headers) as client:
            response = client.post(f"{api_url}/v1/workspaces", json=request_body)
            response.raise_for_status()
            workspace = response.json()
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from e
        if e.response.status_code == 400:
            raise RuntimeError(f"Bad request: {e.response.text}") from e
        raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
    except httpx.RequestError as e:
        raise RuntimeError(f"Could not reach API: {e}") from e

    # Validate API response has required fields
    assert "id" in workspace, "API response must contain workspace id"
    assert "name" in workspace, "API response must contain workspace name"

    if wait:
        # Blocks until the provisioning stream reports SSH credentials.
        ssh_info = _wait_for_provisioning(workspace["id"])
        if json_output:
            payload = {
                "workspace_id": workspace["id"],
                "ssh_host": ssh_info["ssh_host"],
                "ssh_port": ssh_info["ssh_port"],
                "ssh_user": ssh_info["ssh_user"],
            }
            return json.dumps(payload, indent=2)
        return (
            f"Workspace ready: {workspace['name']} ({workspace['id']})\n"
            f"SSH: ssh -p {ssh_info['ssh_port']} {ssh_info['ssh_user']}@{ssh_info['ssh_host']}"
        )

    if json_output:
        return json.dumps(workspace, indent=2)

    return (
        f"Creating workspace: {workspace['name']} ({workspace['id']})\n"
        "Check status with: wafer workspaces list\n"
        "Estimated time: ~30 seconds"
    )
322
+
323
+
324
def _wait_for_provisioning(workspace_id: str) -> dict[str, str | int]:
    """Wait for workspace provisioning to complete via SSE.

    Streams ``provision-stream`` events, echoing ``[STATUS:...]`` updates to
    stderr, until an ``[SSH:host:port:user]`` marker arrives.

    Args:
        workspace_id: ID of the workspace being provisioned

    Returns:
        Dict with ``ssh_host`` (str), ``ssh_port`` (int) and ``ssh_user`` (str)

    Raises:
        RuntimeError: On a non-200 response, an ERROR status, a malformed SSH
            marker, a stream that ends without SSH info, or a transport failure
    """
    import sys

    assert workspace_id, "Workspace ID must be non-empty"
    api_url, headers = _get_client()

    try:
        # No timeout: the stream stays open for as long as provisioning takes.
        with httpx.Client(timeout=None, headers=headers) as client:
            with client.stream(
                "POST",
                f"{api_url}/v1/workspaces/{workspace_id}/provision-stream",
            ) as response:
                if response.status_code != 200:
                    error_body = response.read().decode("utf-8", errors="replace")
                    raise RuntimeError(
                        _friendly_error(response.status_code, error_body, workspace_id)
                    )

                ssh_info: dict[str, str | int] | None = None
                for line in response.iter_lines():
                    if not line or not line.startswith("data: "):
                        continue
                    content = line[6:]  # strip the "data: " prefix
                    if content.startswith("[STATUS:"):
                        status = content[8:-1]  # text between "[STATUS:" and "]"
                        print(f"[wafer] {status.lower()}...", file=sys.stderr)
                        if status == "ERROR":
                            raise RuntimeError(
                                "Workspace provisioning failed. Check status with: wafer workspaces list"
                            )
                    elif content.startswith("[SSH:"):
                        parts = content[5:-1].split(":")
                        if len(parts) != 3:
                            raise RuntimeError("Malformed SSH info in provisioning stream")
                        host, port_text, user = parts
                        try:
                            port = int(port_text)
                        except ValueError:
                            # A non-numeric port is just another malformed
                            # marker; report it the same way as a bad part count.
                            raise RuntimeError(
                                "Malformed SSH info in provisioning stream"
                            ) from None
                        ssh_info = {
                            "ssh_host": host,
                            "ssh_port": port,
                            "ssh_user": user,
                        }
                        break

                if ssh_info is None:
                    raise RuntimeError("Provisioning did not return SSH credentials")
                return ssh_info
    except httpx.RequestError as e:
        raise RuntimeError(f"Could not reach API: {e}") from e
371
+
372
+
373
def delete_workspace(workspace_id: str, json_output: bool = False) -> str:
    """Delete a workspace.

    Args:
        workspace_id: Workspace ID to delete
        json_output: If True, return the API's JSON reply (``{}`` when the
            reply has no usable JSON body); otherwise return formatted text

    Returns:
        Deletion status as string

    Raises:
        RuntimeError: On authentication failure, unknown workspace, any other
            API error status, or a transport failure
    """
    assert workspace_id, "Workspace ID must be non-empty"

    api_url, headers = _get_client()

    try:
        with httpx.Client(timeout=30.0, headers=headers) as client:
            response = client.delete(f"{api_url}/v1/workspaces/{workspace_id}")
            response.raise_for_status()
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from e
        if e.response.status_code == 404:
            raise RuntimeError(f"Workspace not found: {workspace_id}") from e
        raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
    except httpx.RequestError as e:
        raise RuntimeError(f"Could not reach API: {e}") from e

    if not json_output:
        return f"Deleted workspace: {workspace_id}"

    # The delete already succeeded at this point. A successful DELETE may carry
    # no body (or a non-JSON one), so never let body decoding turn a completed
    # deletion into a crash.
    try:
        result = response.json() if response.content else {}
    except ValueError:
        result = {}

    return json.dumps(result, indent=2)
405
+
406
+
407
def _count_rsync_files(rsync_stdout: str) -> int:
    """Count the transferred entries listed in ``rsync -v`` output."""
    # Non-file lines: indented text, the transfer summary ("sent ...",
    # "total size ..."), and the file-list banners printed before the listing.
    skipped_prefixes = (
        " ",
        "sent",
        "total",
        "receiving",
        "building",
        "sending incremental file list",
        "created directory",
    )
    return sum(
        1
        for line in rsync_stdout.strip().split("\n")
        # Directory entries are printed with a trailing slash; they are not files.
        if line and not line.endswith("/") and not line.startswith(skipped_prefixes)
    )


def sync_files(
    workspace_id: str,
    local_path: Path,
    on_progress: Callable[[str], None] | None = None,
) -> tuple[int, str | None]:
    """Sync local files or directories to workspace via rsync over SSH.

    After rsync completes, calls the API to sync files to Modal volume
    so they're available for exec commands.

    Args:
        workspace_id: Workspace ID or name
        local_path: Local file or directory to sync
        on_progress: Optional callback for progress messages

    Returns:
        Tuple of (file_count, warning_message). Warning is None on success.

    Raises:
        RuntimeError: If the workspace is not running, rsync is not installed,
            or rsync fails
    """
    import subprocess

    def emit(msg: str) -> None:
        if on_progress:
            on_progress(msg)

    assert workspace_id, "Workspace ID must be non-empty"
    assert local_path.exists(), f"Path not found: {local_path}"

    ws = get_workspace_raw(workspace_id)
    resolved_id = ws["id"]
    workspace_status = ws.get("status")
    assert workspace_status in VALID_STATUSES, (
        f"Workspace {workspace_id} has invalid status '{workspace_status}'. "
        f"Valid statuses: {VALID_STATUSES}"
    )
    if workspace_status != "running":
        raise RuntimeError(
            f"Workspace is {workspace_status}. Wait for it to be running before syncing."
        )
    ssh_host = ws.get("ssh_host")
    ssh_port = ws.get("ssh_port")
    ssh_user = ws.get("ssh_user")
    assert ssh_host, "Workspace missing ssh_host"
    assert isinstance(ssh_port, int) and ssh_port > 0, "Workspace missing valid ssh_port"
    assert ssh_user, "Workspace missing ssh_user"

    # Build rsync command
    # -a: archive mode (preserves permissions, etc.)
    # -v: verbose
    # -z: compress during transfer
    if local_path.is_dir():
        # Directory: sync contents (trailing slash)
        source = f"{local_path}/"
    else:
        # Single file: sync the file itself
        source = str(local_path)

    # Build SSH command for rsync. No identity file is passed, so ssh falls
    # back to its own default key selection.
    # NOTE(review): host-key verification is disabled here; confirm that is
    # acceptable for these hosts.
    ssh_opts = f"-p {ssh_port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"

    rsync_cmd = [
        "rsync",
        "-avz",
        "-e",
        f"ssh {ssh_opts}",
        source,
        f"{ssh_user}@{ssh_host}:/workspace/",
    ]

    try:
        result = subprocess.run(rsync_cmd, capture_output=True, text=True)
    except FileNotFoundError:
        raise RuntimeError("rsync not found. Install rsync to use sync feature.") from None
    except subprocess.SubprocessError as e:
        raise RuntimeError(f"Sync failed: {e}") from e

    if result.returncode != 0:
        raise RuntimeError(f"rsync failed: {result.stderr}")

    # Count files from the verbose listing, excluding banners and directories.
    file_count = _count_rsync_files(result.stdout)

    emit(f"Synced {file_count} files to SSH host")

    # Notify API to sync files to Modal volume (so exec can see them)
    # Use resolved UUID, not the name
    emit("Syncing to Modal volume...")
    warning = _init_sync_state(resolved_id)

    if warning:
        emit(f"Modal sync warning: {warning}")
    else:
        emit("Modal sync complete")

    return file_count, warning
510
+
511
+
512
def _init_sync_state(workspace_id: str) -> str | None:
    """Tell API to sync files from bare metal to Modal volume.

    This must be called after rsync completes so exec commands
    can access the synced files. Failures never raise: the rsync step has
    already succeeded, so problems are reported as a warning string.

    Args:
        workspace_id: Workspace ID to initialise sync state for

    Returns:
        None on success (or on a 404, which is treated as "nothing to sync
        to"); a warning message on any other failure (non-fatal)
    """
    api_url, headers = _get_client()

    try:
        with httpx.Client(timeout=120.0, headers=headers) as client:
            response = client.post(f"{api_url}/v1/workspaces/{workspace_id}/init-sync-state")
            response.raise_for_status()
            return None
    except httpx.HTTPStatusError as e:
        # Non-fatal: sync to bare metal succeeded, Modal sync failed
        # User can still SSH in and use files, just not via exec
        if e.response.status_code == 404:
            # Workspace not found or no target - sync still worked for SSH
            return None

        # Extract error detail from the response if available; fall back to
        # a truncated raw body when it is not a JSON object.
        raw_snippet = e.response.text[:200] if e.response.text else ""
        try:
            data = e.response.json()
        except ValueError:
            detail = raw_snippet
        else:
            detail = data.get("detail", "") if isinstance(data, dict) else raw_snippet

        # Return warning instead of raising - rsync succeeded
        if detail:
            return f"Files synced to SSH, but Modal sync failed: {detail}"
        return f"Files synced to SSH, but Modal sync failed ({e.response.status_code}). Use SSH or retry sync."
    except httpx.RequestError as e:
        # Network error: the bare-metal sync succeeded but the Modal sync was
        # never confirmed, so surface a warning rather than reporting success.
        return f"Files synced to SSH, but Modal sync could not reach the API: {e}"
550
+
551
+
552
def get_workspace_raw(workspace_id: str) -> dict:
    """Fetch one workspace and return the decoded JSON dict.

    Args:
        workspace_id: Workspace ID or name to look up

    Returns:
        The workspace record as returned by the API

    Raises:
        RuntimeError: On a 401, a 404, any other HTTP error status, or a
            transport failure
        AssertionError: If the record lacks ``id``/``name`` or reports a
            status outside VALID_STATUSES
    """
    assert workspace_id, "Workspace ID must be non-empty"

    base_url, auth_headers = _get_client()

    try:
        with httpx.Client(timeout=30.0, headers=auth_headers) as http:
            reply = http.get(f"{base_url}/v1/workspaces/{workspace_id}")
            reply.raise_for_status()
            record = reply.json()
    except httpx.HTTPStatusError as exc:
        code = exc.response.status_code
        if code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from exc
        if code == 404:
            raise RuntimeError(f"Workspace not found: {workspace_id}") from exc
        raise RuntimeError(f"API error: {code} - {exc.response.text}") from exc
    except httpx.RequestError as exc:
        raise RuntimeError(f"Could not reach API: {exc}") from exc

    # Both identifying fields are required in the response.
    for required in ("id", "name"):
        assert required in record, f"API response must contain workspace {required}"

    reported = record.get("status", "unknown")
    if reported != "unknown":
        # A missing/unknown status is tolerated; anything else must be known.
        assert reported in VALID_STATUSES, (
            f"Workspace {record['id']} has invalid status '{reported}'. "
            f"Valid statuses: {VALID_STATUSES}"
        )

    return record
582
+
583
+
584
def get_workspace(workspace_id: str, json_output: bool = False) -> str:
    """Describe a single workspace.

    Args:
        workspace_id: Workspace ID to get
        json_output: If True, return raw JSON; otherwise return formatted text

    Returns:
        Workspace details as string
    """
    details = get_workspace_raw(workspace_id)

    if json_output:
        return json.dumps(details, indent=2)

    state = details.get("status", "unknown")
    report = [f"Workspace: {details['name']} ({details['id']})", ""]
    report.append(f" Status: {state}")
    report.append(f" GPU Type: {details.get('gpu_type', 'N/A')}")
    report.append(f" Image: {details.get('image', 'N/A')}")
    report.append(f" Created: {details.get('created_at', 'N/A')}")
    report.append(f" Last Used: {details.get('last_used_at', 'N/A')}")

    if details.get("ssh_host"):
        # SSH section; port and user fall back to 22 / root when absent.
        report += [
            "",
            "SSH Info:",
            f" Host: {details['ssh_host']}",
            f" Port: {details.get('ssh_port', 22)}",
            f" User: {details.get('ssh_user', 'root')}",
        ]
    elif state == "creating":
        report += ["", "SSH: available once workspace is running"]

    return "\n".join(report)
622
+
623
+
624
+ def _handle_sync_event(sync_type: str) -> None:
625
+ """Handle sync events and print status to stderr.
626
+
627
+ Sync events:
628
+ - FORWARD:START - Starting workspace → GPU sync
629
+ - FORWARD:DONE:N - Synced N files to GPU
630
+ - FORWARD:WARN:msg - Warning during forward sync
631
+ - REVERSE:START - Starting GPU → workspace sync
632
+ - REVERSE:DONE:N - Synced N artifacts back
633
+ """
634
+ import sys
635
+
636
+ if sync_type == "FORWARD:START":
637
+ print("[sync] Syncing workspace → GPU...", end="", file=sys.stderr, flush=True)
638
+ elif sync_type.startswith("FORWARD:DONE:"):
639
+ count = sync_type.split(":")[-1]
640
+ print(f" done ({count} files)", file=sys.stderr)
641
+ elif sync_type.startswith("FORWARD:WARN:"):
642
+ msg = sync_type[13:] # Remove "FORWARD:WARN:"
643
+ print(f" warning: {msg}", file=sys.stderr)
644
+ elif sync_type == "REVERSE:START":
645
+ print("[sync] Syncing artifacts back...", end="", file=sys.stderr, flush=True)
646
+ elif sync_type.startswith("REVERSE:DONE:"):
647
+ count = sync_type.split(":")[-1]
648
+ print(f" done ({count} files)", file=sys.stderr)
649
+
650
+
651
+ @dataclass(frozen=True)
652
+ class SSEEvent:
653
+ """Parsed SSE event result."""
654
+
655
+ output: str | None # Content to print (None = no output)
656
+ exit_code: int | None # Exit code if stream should end (None = continue)
657
+ is_error: bool # Whether output goes to stderr
658
+ sync_event: str | None = None # Sync event type (e.g., "FORWARD:START")
659
+
660
+
661
def _parse_sse_content(content: str) -> SSEEvent:
    """Turn one SSE payload string into an SSEEvent.

    Pure function: content in, event out. No side effects.

    Markers handled, in order: ``[DONE]``, ``[EXIT:n]``, ``[ERROR]``,
    ``[SYNC:...]``, then ``[STATUS:``/``[CONTEXT:`` (ignored). Everything
    else is regular output.
    """
    if content == "[DONE]":
        return SSEEvent(output=None, exit_code=0, is_error=False)

    if content.startswith("[EXIT:"):
        # "[EXIT:0]" -> 0; an unparsable code falls back to 0.
        raw_code = content[6:-1]
        try:
            parsed_code = int(raw_code)
        except ValueError:
            parsed_code = 0
        return SSEEvent(output=None, exit_code=parsed_code, is_error=False)

    if content.startswith("[ERROR]"):
        # Message starts at index 8: the 7-char marker plus one more character.
        message = content[8:]
        return SSEEvent(output=message, exit_code=1, is_error=True)

    if content.startswith("[SYNC:"):
        # e.g. "[SYNC:FORWARD:DONE:5]" -> "FORWARD:DONE:5"
        marker = content[6:-1]
        return SSEEvent(output=None, exit_code=None, is_error=False, sync_event=marker)

    # Status/context events carry nothing to print here.
    if content.startswith(("[STATUS:", "[CONTEXT:")):
        return SSEEvent(output=None, exit_code=None, is_error=False)

    # Anything else is regular command output.
    return SSEEvent(output=content, exit_code=None, is_error=False)
692
+
693
+
694
def exec_command(
    workspace_id: str,
    command: str,
    timeout_seconds: int | None = None,
    routing: str | None = None,
    pull_image: bool = False,
) -> int:
    """Run a command in a workspace and stream its output to the terminal.

    Regular output goes to stdout, error output to stderr, and sync events
    are reported on stderr via ``_handle_sync_event``.

    Args:
        workspace_id: Workspace ID or name
        command: Command to execute
        timeout_seconds: Execution timeout; only sent to the API when truthy
        routing: Routing hint - "auto", "gpu", "cpu", or "baremetal"; only
            sent to the API when truthy
        pull_image: Forwarded to the API as the ``pull_image`` request field

    Returns:
        Exit code (0 = success, non-zero = failure)

    Raises:
        RuntimeError: On a non-200 response or a transport failure
    """
    import base64
    import sys

    assert workspace_id, "Workspace ID must be non-empty"
    assert command, "Command must be non-empty"

    base_url, auth_headers = _get_client()

    # The command travels base64-encoded to avoid escaping issues.
    payload: dict = {
        "command_b64": base64.b64encode(command.encode("utf-8")).decode("utf-8"),
        "pull_image": pull_image,
    }
    if timeout_seconds:
        payload["timeout_seconds"] = timeout_seconds
    if routing:
        payload["requirements"] = {"routing": routing}

    exec_url = f"{base_url}/v1/workspaces/{workspace_id}/exec"

    try:
        # Streaming request with no client timeout: output arrives as SSE lines.
        with httpx.Client(timeout=None, headers=auth_headers) as http, http.stream(
            "POST", exec_url, json=payload
        ) as reply:
            if reply.status_code != 200:
                # Read the error body and turn it into a friendly message.
                failure_text = reply.read().decode("utf-8", errors="replace")
                raise RuntimeError(
                    _friendly_error(reply.status_code, failure_text, workspace_id)
                )

            for raw_line in reply.iter_lines():
                if not (raw_line and raw_line.startswith("data: ")):
                    continue

                event = _parse_sse_content(raw_line[6:])

                # Sync events are status-only: show them and move on.
                if event.sync_event:
                    _handle_sync_event(event.sync_event)
                    continue

                if event.output is not None:
                    target = sys.stderr if event.is_error else sys.stdout
                    print(event.output, file=target)

                if event.exit_code is not None:
                    return event.exit_code

            # Stream ended without an explicit exit marker.
            return 0

    except httpx.HTTPStatusError as exc:
        raise RuntimeError(
            _friendly_error(exc.response.status_code, exc.response.text, workspace_id)
        ) from exc
    except httpx.RequestError as exc:
        raise RuntimeError(f"Could not reach API: {exc}") from exc
773
+
774
+
775
def exec_command_capture(
    workspace_id: str,
    command: str,
    timeout_seconds: int | None = None,
    routing: str | None = None,
    pull_image: bool = False,
) -> tuple[int, str]:
    """Run a command in a workspace and collect its output.

    Same request as ``exec_command``, but output lines (regular and error
    alike) are gathered and returned instead of printed; sync events are
    skipped.

    Args:
        workspace_id: Workspace ID or name
        command: Command to execute
        timeout_seconds: Execution timeout; only sent to the API when truthy
        routing: Routing hint - "auto", "gpu", "cpu", or "baremetal"; only
            sent to the API when truthy
        pull_image: Forwarded to the API as the ``pull_image`` request field

    Returns:
        Tuple of (exit_code, output_string), output lines joined by newlines

    Raises:
        RuntimeError: On a non-200 response or a transport failure
    """
    import base64

    assert workspace_id, "Workspace ID must be non-empty"
    assert command, "Command must be non-empty"

    base_url, auth_headers = _get_client()

    # The command travels base64-encoded to avoid escaping issues.
    payload: dict = {
        "command_b64": base64.b64encode(command.encode("utf-8")).decode("utf-8"),
        "pull_image": pull_image,
    }
    if timeout_seconds:
        payload["timeout_seconds"] = timeout_seconds
    if routing:
        payload["requirements"] = {"routing": routing}

    captured: list[str] = []
    exec_url = f"{base_url}/v1/workspaces/{workspace_id}/exec"

    try:
        with httpx.Client(timeout=None, headers=auth_headers) as http, http.stream(
            "POST", exec_url, json=payload
        ) as reply:
            if reply.status_code != 200:
                failure_text = reply.read().decode("utf-8", errors="replace")
                raise RuntimeError(
                    _friendly_error(reply.status_code, failure_text, workspace_id)
                )

            final_code = 0
            for raw_line in reply.iter_lines():
                if not (raw_line and raw_line.startswith("data: ")):
                    continue

                event = _parse_sse_content(raw_line[6:])

                # Sync events carry no command output.
                if event.sync_event:
                    continue

                if event.output is not None:
                    captured.append(event.output)

                if event.exit_code is not None:
                    final_code = event.exit_code
                    break

            return final_code, "\n".join(captured)

    except httpx.HTTPStatusError as exc:
        raise RuntimeError(
            _friendly_error(exc.response.status_code, exc.response.text, workspace_id)
        ) from exc
    except httpx.RequestError as exc:
        raise RuntimeError(f"Could not reach API: {exc}") from exc