wafer-cli 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/workspaces.py CHANGED
@@ -13,15 +13,7 @@ import httpx
13
13
  from .api_client import get_api_url
14
14
  from .auth import get_auth_headers
15
15
 
16
-
17
- @dataclass(frozen=True)
18
- class SSHCredentials:
19
- """SSH credentials for workspace access."""
20
-
21
- host: str
22
- port: int
23
- user: str
24
- key_path: Path
16
+ VALID_STATUSES = {"creating", "running"}
25
17
 
26
18
 
27
19
  def _get_client() -> tuple[str, dict[str, str]]:
@@ -72,10 +64,11 @@ def _friendly_error(status_code: int, response_text: str, workspace_id: str) ->
72
64
 
73
65
  # Parse common error details from response
74
66
  detail = ""
75
- if "stopped" in response_text.lower():
67
+ if "not running" in response_text.lower() or "not found" in response_text.lower():
76
68
  return (
77
- f"Workspace '{workspace_id}' is stopped.\n"
78
- " Attach to start it: wafer workspaces attach " + workspace_id
69
+ f"Workspace '{workspace_id}' not found or not running.\n"
70
+ " Check status: wafer workspaces list\n"
71
+ " Create new: wafer workspaces create <name>"
79
72
  )
80
73
 
81
74
  if "timeout" in response_text.lower():
@@ -85,6 +78,12 @@ def _friendly_error(status_code: int, response_text: str, workspace_id: str) ->
85
78
  " Or set default: wafer config set defaults.exec_timeout 600"
86
79
  )
87
80
 
81
+ if "creating" in response_text.lower():
82
+ return (
83
+ f"Workspace '{workspace_id}' is still creating.\n"
84
+ " Check status: wafer workspaces list"
85
+ )
86
+
88
87
  # Generic error with response detail
89
88
  try:
90
89
  import json
@@ -114,6 +113,14 @@ def _list_workspaces_raw() -> list[dict]:
114
113
  raise RuntimeError(f"Could not reach API: {e}") from e
115
114
 
116
115
  assert isinstance(workspaces, list), "API must return a list of workspaces"
116
+
117
+ for ws in workspaces:
118
+ status = ws.get("status", "unknown")
119
+ assert status in VALID_STATUSES or status == "unknown", (
120
+ f"Workspace {ws.get('id', 'unknown')} has invalid status '{status}'. "
121
+ f"Valid statuses: {VALID_STATUSES}"
122
+ )
123
+
117
124
  return workspaces
118
125
 
119
126
 
@@ -186,9 +193,15 @@ def list_workspaces(json_output: bool = False) -> str:
186
193
  except httpx.RequestError as e:
187
194
  raise RuntimeError(f"Could not reach API: {e}") from e
188
195
 
189
- # Validate API response shape
190
196
  assert isinstance(workspaces, list), "API must return a list of workspaces"
191
197
 
198
+ for ws in workspaces:
199
+ status = ws.get("status", "unknown")
200
+ assert status in VALID_STATUSES or status == "unknown", (
201
+ f"Workspace {ws.get('id', 'unknown')} has invalid status '{status}'. "
202
+ f"Valid statuses: {VALID_STATUSES}"
203
+ )
204
+
192
205
  if json_output:
193
206
  return json.dumps(workspaces, indent=2)
194
207
 
@@ -198,10 +211,14 @@ def list_workspaces(json_output: bool = False) -> str:
198
211
  lines = ["Workspaces:", ""]
199
212
  for ws in workspaces:
200
213
  status = ws.get("status", "unknown")
201
- status_icon = {"running": "●", "stopped": "○", "queued": "◐"}.get(status, "?")
214
+ status_icon = {"running": "●", "creating": "◐"}.get(status, "?")
202
215
  lines.append(f" {status_icon} {ws['name']} ({ws['id']})")
203
216
  lines.append(f" GPU: {ws.get('gpu_type', 'N/A')} | Image: {ws.get('image', 'N/A')}")
204
217
  lines.append(f" Status: {status} | Created: {ws.get('created_at', 'N/A')}")
218
+ if status == "running" and ws.get("ssh_host") and ws.get("ssh_port") and ws.get("ssh_user"):
219
+ lines.append(
220
+ f" SSH: ssh -p {ws['ssh_port']} {ws['ssh_user']}@{ws['ssh_host']}"
221
+ )
205
222
  lines.append("")
206
223
 
207
224
  return "\n".join(lines)
@@ -211,6 +228,7 @@ def create_workspace(
211
228
  name: str,
212
229
  gpu_type: str = "B200",
213
230
  image: str | None = None,
231
+ wait: bool = False,
214
232
  json_output: bool = False,
215
233
  ) -> str:
216
234
  """Create a new workspace.
@@ -219,6 +237,7 @@ def create_workspace(
219
237
  name: Workspace name (must be unique)
220
238
  gpu_type: GPU type (default: B200)
221
239
  image: Docker image (optional, uses default if not specified)
240
+ wait: If True, stream provisioning progress and return SSH credentials
222
241
  json_output: If True, return raw JSON; otherwise return formatted text
223
242
 
224
243
  Returns:
@@ -276,190 +295,114 @@ def create_workspace(
276
295
  assert "id" in workspace, "API response must contain workspace id"
277
296
  assert "name" in workspace, "API response must contain workspace name"
278
297
 
298
+ if wait:
299
+ ssh_info = _wait_for_provisioning(workspace["id"])
300
+ if json_output:
301
+ payload = {
302
+ "workspace_id": workspace["id"],
303
+ "ssh_host": ssh_info["ssh_host"],
304
+ "ssh_port": ssh_info["ssh_port"],
305
+ "ssh_user": ssh_info["ssh_user"],
306
+ }
307
+ return json.dumps(payload, indent=2)
308
+ return (
309
+ f"Workspace ready: {workspace['name']} ({workspace['id']})\n"
310
+ f"SSH: ssh -p {ssh_info['ssh_port']} {ssh_info['ssh_user']}@{ssh_info['ssh_host']}"
311
+ )
312
+
279
313
  if json_output:
280
314
  return json.dumps(workspace, indent=2)
281
315
 
282
- return f"Created workspace: {workspace['name']} ({workspace['id']})"
283
-
316
+ return (
317
+ f"Creating workspace: {workspace['name']} ({workspace['id']})\n"
318
+ "Check status with: wafer workspaces list\n"
319
+ "Estimated time: ~30 seconds"
320
+ )
284
321
 
285
- def delete_workspace(workspace_id: str, json_output: bool = False) -> str:
286
- """Delete a workspace.
287
322
 
288
- Args:
289
- workspace_id: Workspace ID to delete
290
- json_output: If True, return raw JSON; otherwise return formatted text
323
+ def _wait_for_provisioning(workspace_id: str) -> dict[str, str | int]:
324
+ """Wait for workspace provisioning to complete via SSE."""
325
+ import sys
291
326
 
292
- Returns:
293
- Deletion status as string
294
- """
295
327
  assert workspace_id, "Workspace ID must be non-empty"
296
-
297
328
  api_url, headers = _get_client()
298
329
 
299
330
  try:
300
- with httpx.Client(timeout=30.0, headers=headers) as client:
301
- response = client.delete(f"{api_url}/v1/workspaces/{workspace_id}")
302
- response.raise_for_status()
303
- result = response.json()
304
- except httpx.HTTPStatusError as e:
305
- if e.response.status_code == 401:
306
- raise RuntimeError("Not authenticated. Run: wafer login") from e
307
- if e.response.status_code == 404:
308
- raise RuntimeError(f"Workspace not found: {workspace_id}") from e
309
- raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
310
- except httpx.RequestError as e:
311
- raise RuntimeError(f"Could not reach API: {e}") from e
331
+ with httpx.Client(timeout=None, headers=headers) as client:
332
+ with client.stream(
333
+ "POST",
334
+ f"{api_url}/v1/workspaces/{workspace_id}/provision-stream",
335
+ ) as response:
336
+ if response.status_code != 200:
337
+ error_body = response.read().decode("utf-8", errors="replace")
338
+ raise RuntimeError(
339
+ _friendly_error(response.status_code, error_body, workspace_id)
340
+ )
312
341
 
313
- if json_output:
314
- return json.dumps(result, indent=2)
342
+ ssh_info: dict[str, str | int] | None = None
343
+ for line in response.iter_lines():
344
+ if not line or not line.startswith("data: "):
345
+ continue
346
+ content = line[6:]
347
+ if content.startswith("[STATUS:"):
348
+ status = content[8:-1]
349
+ print(f"[wafer] {status.lower()}...", file=sys.stderr)
350
+ if status == "ERROR":
351
+ raise RuntimeError(
352
+ "Workspace provisioning failed. Check status with: wafer workspaces list"
353
+ )
354
+ elif content.startswith("[SSH:"):
355
+ parts = content[5:-1].split(":")
356
+ if len(parts) != 3:
357
+ raise RuntimeError("Malformed SSH info in provisioning stream")
358
+ ssh_info = {
359
+ "ssh_host": parts[0],
360
+ "ssh_port": int(parts[1]),
361
+ "ssh_user": parts[2],
362
+ }
363
+ break
315
364
 
316
- return f"Deleted workspace: {workspace_id}"
365
+ if ssh_info is None:
366
+ raise RuntimeError("Provisioning did not return SSH credentials")
367
+ return ssh_info
368
+ except httpx.RequestError as e:
369
+ raise RuntimeError(f"Could not reach API: {e}") from e
317
370
 
318
371
 
319
- def attach_workspace(workspace_id: str, json_output: bool = False) -> str:
320
- """Attach to a workspace (get SSH credentials).
372
+ def delete_workspace(workspace_id: str, json_output: bool = False) -> str:
373
+ """Delete a workspace.
321
374
 
322
375
  Args:
323
- workspace_id: Workspace ID to attach to
376
+ workspace_id: Workspace ID to delete
324
377
  json_output: If True, return raw JSON; otherwise return formatted text
325
378
 
326
379
  Returns:
327
- SSH connection info as string
380
+ Deletion status as string
328
381
  """
329
382
  assert workspace_id, "Workspace ID must be non-empty"
330
383
 
331
384
  api_url, headers = _get_client()
332
385
 
333
386
  try:
334
- with httpx.Client(timeout=120.0, headers=headers) as client:
335
- response = client.post(f"{api_url}/v1/workspaces/{workspace_id}/attach")
387
+ with httpx.Client(timeout=30.0, headers=headers) as client:
388
+ response = client.delete(f"{api_url}/v1/workspaces/{workspace_id}")
336
389
  response.raise_for_status()
337
- attach_info = response.json()
390
+ result = response.json()
338
391
  except httpx.HTTPStatusError as e:
339
392
  if e.response.status_code == 401:
340
393
  raise RuntimeError("Not authenticated. Run: wafer login") from e
341
394
  if e.response.status_code == 404:
342
395
  raise RuntimeError(f"Workspace not found: {workspace_id}") from e
343
- if e.response.status_code == 503:
344
- raise RuntimeError("No GPU available. Please try again later.") from e
345
396
  raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
346
397
  except httpx.RequestError as e:
347
398
  raise RuntimeError(f"Could not reach API: {e}") from e
348
399
 
349
- # Validate API response has required SSH fields
350
- assert "ssh_host" in attach_info, "API response must contain ssh_host"
351
- assert "ssh_port" in attach_info, "API response must contain ssh_port"
352
- assert "ssh_user" in attach_info, "API response must contain ssh_user"
353
- assert "private_key_pem" in attach_info, "API response must contain private_key_pem"
354
-
355
400
  if json_output:
356
- return json.dumps(attach_info, indent=2)
357
-
358
- # Write private key to temp file and generate SSH config
359
- ssh_host = attach_info["ssh_host"]
360
- ssh_port = attach_info["ssh_port"]
361
- ssh_user = attach_info["ssh_user"]
362
- private_key = attach_info["private_key_pem"]
363
-
364
- # Validate field values before using them
365
- assert ssh_host, "ssh_host must be non-empty"
366
- assert isinstance(ssh_port, int), "ssh_port must be an integer"
367
- assert ssh_port > 0, "ssh_port must be positive"
368
- assert ssh_user, "ssh_user must be non-empty"
369
- assert private_key, "private_key_pem must be non-empty"
370
-
371
- # Save private key
372
- key_dir = Path.home() / ".wafer" / "keys"
373
- key_dir.mkdir(parents=True, exist_ok=True)
374
- key_path = key_dir / f"{workspace_id}.pem"
375
- key_path.write_text(private_key)
376
- key_path.chmod(0o600)
377
-
378
- # Generate SSH config entry
379
- config_entry = f"""
380
- # Wafer workspace: {workspace_id}
381
- Host wafer-{workspace_id}
382
- HostName {ssh_host}
383
- Port {ssh_port}
384
- User {ssh_user}
385
- IdentityFile {key_path}
386
- StrictHostKeyChecking no
387
- UserKnownHostsFile /dev/null
388
- """
389
-
390
- lines = [
391
- f"Attached to workspace: {workspace_id}",
392
- "",
393
- "SSH Connection:",
394
- f" ssh -i {key_path} -p {ssh_port} {ssh_user}@{ssh_host}",
395
- "",
396
- "Or add to ~/.ssh/config:",
397
- config_entry,
398
- f"Then connect with: ssh wafer-{workspace_id}",
399
- ]
400
-
401
- return "\n".join(lines)
402
-
403
-
404
- def get_ssh_credentials(workspace_id: str) -> tuple[SSHCredentials, str]:
405
- """Get SSH credentials for a workspace.
406
-
407
- Calls attach API, saves key file, returns credentials for SSH.
401
+ return json.dumps(result, indent=2)
408
402
 
409
- Args:
410
- workspace_id: Workspace ID or name
403
+ return f"Deleted workspace: {workspace_id}"
411
404
 
412
- Returns:
413
- Tuple of (SSHCredentials, resolved_workspace_id)
414
405
 
415
- Raises:
416
- RuntimeError: If attach fails
417
- """
418
- assert workspace_id, "Workspace ID must be non-empty"
419
-
420
- api_url, headers = _get_client()
421
-
422
- try:
423
- with httpx.Client(timeout=120.0, headers=headers) as client:
424
- response = client.post(f"{api_url}/v1/workspaces/{workspace_id}/attach")
425
- response.raise_for_status()
426
- attach_info = response.json()
427
- except httpx.HTTPStatusError as e:
428
- if e.response.status_code == 401:
429
- raise RuntimeError("Not authenticated. Run: wafer login") from e
430
- if e.response.status_code == 404:
431
- raise RuntimeError(f"Workspace not found: {workspace_id}") from e
432
- if e.response.status_code == 503:
433
- raise RuntimeError("No GPU available. Please try again later.") from e
434
- raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
435
- except httpx.RequestError as e:
436
- raise RuntimeError(f"Could not reach API: {e}") from e
437
-
438
- # Validate and extract fields
439
- resolved_id = attach_info.get("workspace_id", workspace_id)
440
- ssh_host = attach_info.get("ssh_host")
441
- ssh_port = attach_info.get("ssh_port")
442
- ssh_user = attach_info.get("ssh_user")
443
- private_key = attach_info.get("private_key_pem")
444
-
445
- assert ssh_host, "API response must contain ssh_host"
446
- assert isinstance(ssh_port, int) and ssh_port > 0, "ssh_port must be positive integer"
447
- assert ssh_user, "API response must contain ssh_user"
448
- assert private_key, "API response must contain private_key_pem"
449
-
450
- # Save private key using resolved ID
451
- key_dir = Path.home() / ".wafer" / "keys"
452
- key_dir.mkdir(parents=True, exist_ok=True)
453
- key_path = key_dir / f"{resolved_id}.pem"
454
- key_path.write_text(private_key)
455
- key_path.chmod(0o600)
456
-
457
- return SSHCredentials(
458
- host=ssh_host,
459
- port=ssh_port,
460
- user=ssh_user,
461
- key_path=key_path,
462
- ), resolved_id
463
406
 
464
407
 
465
408
  def sync_files(
@@ -492,9 +435,23 @@ def sync_files(
492
435
  assert workspace_id, "Workspace ID must be non-empty"
493
436
  assert local_path.exists(), f"Path not found: {local_path}"
494
437
 
495
- # Get SSH credentials (also wakes up workspace if needed)
496
- # resolved_id is the UUID, workspace_id might be a name
497
- creds, resolved_id = get_ssh_credentials(workspace_id)
438
+ ws = get_workspace_raw(workspace_id)
439
+ resolved_id = ws["id"]
440
+ workspace_status = ws.get("status")
441
+ assert workspace_status in VALID_STATUSES, (
442
+ f"Workspace {workspace_id} has invalid status '{workspace_status}'. "
443
+ f"Valid statuses: {VALID_STATUSES}"
444
+ )
445
+ if workspace_status != "running":
446
+ raise RuntimeError(
447
+ f"Workspace is {workspace_status}. Wait for it to be running before syncing."
448
+ )
449
+ ssh_host = ws.get("ssh_host")
450
+ ssh_port = ws.get("ssh_port")
451
+ ssh_user = ws.get("ssh_user")
452
+ assert ssh_host, "Workspace missing ssh_host"
453
+ assert isinstance(ssh_port, int) and ssh_port > 0, "Workspace missing valid ssh_port"
454
+ assert ssh_user, "Workspace missing ssh_user"
498
455
 
499
456
  # Build rsync command
500
457
  # -a: archive mode (preserves permissions, etc.)
@@ -507,13 +464,17 @@ def sync_files(
507
464
  # Single file: sync the file itself
508
465
  source = str(local_path)
509
466
 
467
+ # Build SSH command for rsync
468
+ # If key_path is None (BYOK model), SSH will use default key from ~/.ssh/
469
+ ssh_opts = f"-p {ssh_port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
470
+
510
471
  rsync_cmd = [
511
472
  "rsync",
512
473
  "-avz",
513
474
  "-e",
514
- f"ssh -i {creds.key_path} -p {creds.port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null",
475
+ f"ssh {ssh_opts}",
515
476
  source,
516
- f"{creds.user}@{creds.host}:/workspace/",
477
+ f"{ssh_user}@{ssh_host}:/workspace/",
517
478
  ]
518
479
 
519
480
  try:
@@ -589,16 +550,8 @@ def _init_sync_state(workspace_id: str) -> str | None:
589
550
  return None
590
551
 
591
552
 
592
- def get_workspace(workspace_id: str, json_output: bool = False) -> str:
593
- """Get details of a specific workspace.
594
-
595
- Args:
596
- workspace_id: Workspace ID to get
597
- json_output: If True, return raw JSON; otherwise return formatted text
598
-
599
- Returns:
600
- Workspace details as string
601
- """
553
+ def get_workspace_raw(workspace_id: str) -> dict:
554
+ """Get workspace details as raw JSON dict."""
602
555
  assert workspace_id, "Workspace ID must be non-empty"
603
556
 
604
557
  api_url, headers = _get_client()
@@ -617,9 +570,29 @@ def get_workspace(workspace_id: str, json_output: bool = False) -> str:
617
570
  except httpx.RequestError as e:
618
571
  raise RuntimeError(f"Could not reach API: {e}") from e
619
572
 
620
- # Validate API response has required fields
621
573
  assert "id" in workspace, "API response must contain workspace id"
622
574
  assert "name" in workspace, "API response must contain workspace name"
575
+
576
+ status = workspace.get("status", "unknown")
577
+ assert status in VALID_STATUSES or status == "unknown", (
578
+ f"Workspace {workspace['id']} has invalid status '{status}'. "
579
+ f"Valid statuses: {VALID_STATUSES}"
580
+ )
581
+
582
+ return workspace
583
+
584
+
585
+ def get_workspace(workspace_id: str, json_output: bool = False) -> str:
586
+ """Get details of a specific workspace.
587
+
588
+ Args:
589
+ workspace_id: Workspace ID to get
590
+ json_output: If True, return raw JSON; otherwise return formatted text
591
+
592
+ Returns:
593
+ Workspace details as string
594
+ """
595
+ workspace = get_workspace_raw(workspace_id)
623
596
 
624
597
  if json_output:
625
598
  return json.dumps(workspace, indent=2)
@@ -643,15 +616,39 @@ def get_workspace(workspace_id: str, json_output: bool = False) -> str:
643
616
  f" Port: {workspace.get('ssh_port', 22)}",
644
617
  f" User: {workspace.get('ssh_user', 'root')}",
645
618
  ])
646
- elif status in ("stopped", "created"):
647
- lines.extend([
648
- "",
649
- "SSH: Run 'wafer workspaces attach' to get SSH credentials",
650
- ])
619
+ elif status == "creating":
620
+ lines.extend(["", "SSH: available once workspace is running"])
651
621
 
652
622
  return "\n".join(lines)
653
623
 
654
624
 
625
+ def _handle_sync_event(sync_type: str) -> None:
626
+ """Handle sync events and print status to stderr.
627
+
628
+ Sync events:
629
+ - FORWARD:START - Starting workspace → GPU sync
630
+ - FORWARD:DONE:N - Synced N files to GPU
631
+ - FORWARD:WARN:msg - Warning during forward sync
632
+ - REVERSE:START - Starting GPU → workspace sync
633
+ - REVERSE:DONE:N - Synced N artifacts back
634
+ """
635
+ import sys
636
+
637
+ if sync_type == "FORWARD:START":
638
+ print("[sync] Syncing workspace → GPU...", end="", file=sys.stderr, flush=True)
639
+ elif sync_type.startswith("FORWARD:DONE:"):
640
+ count = sync_type.split(":")[-1]
641
+ print(f" done ({count} files)", file=sys.stderr)
642
+ elif sync_type.startswith("FORWARD:WARN:"):
643
+ msg = sync_type[13:] # Remove "FORWARD:WARN:"
644
+ print(f" warning: {msg}", file=sys.stderr)
645
+ elif sync_type == "REVERSE:START":
646
+ print("[sync] Syncing artifacts back...", end="", file=sys.stderr, flush=True)
647
+ elif sync_type.startswith("REVERSE:DONE:"):
648
+ count = sync_type.split(":")[-1]
649
+ print(f" done ({count} files)", file=sys.stderr)
650
+
651
+
655
652
  @dataclass(frozen=True)
656
653
  class SSEEvent:
657
654
  """Parsed SSE event result."""
@@ -659,6 +656,7 @@ class SSEEvent:
659
656
  output: str | None # Content to print (None = no output)
660
657
  exit_code: int | None # Exit code if stream should end (None = continue)
661
658
  is_error: bool # Whether output goes to stderr
659
+ sync_event: str | None = None # Sync event type (e.g., "FORWARD:START")
662
660
 
663
661
 
664
662
  def _parse_sse_content(content: str) -> SSEEvent:
@@ -680,6 +678,16 @@ def _parse_sse_content(content: str) -> SSEEvent:
680
678
  if content.startswith("[ERROR]"):
681
679
  return SSEEvent(output=content[8:], exit_code=1, is_error=True)
682
680
 
681
+ # Sync events: [SYNC:FORWARD:START], [SYNC:FORWARD:DONE:5], etc.
682
+ if content.startswith("[SYNC:"):
683
+ # Extract sync type (e.g., "FORWARD:START" or "REVERSE:DONE:5")
684
+ sync_type = content[6:-1] # Remove [SYNC: and ]
685
+ return SSEEvent(output=None, exit_code=None, is_error=False, sync_event=sync_type)
686
+
687
+ # Status events we can ignore (already handled elsewhere)
688
+ if content.startswith("[STATUS:") or content.startswith("[CONTEXT:"):
689
+ return SSEEvent(output=None, exit_code=None, is_error=False)
690
+
683
691
  # Regular output
684
692
  return SSEEvent(output=content, exit_code=None, is_error=False)
685
693
 
@@ -688,13 +696,16 @@ def exec_command(
688
696
  workspace_id: str,
689
697
  command: str,
690
698
  timeout_seconds: int | None = None,
699
+ routing: str | None = None,
700
+ pull_image: bool = False,
691
701
  ) -> int:
692
- """Execute a command in workspace with GPU routing, streaming output.
702
+ """Execute a command in workspace, streaming output.
693
703
 
694
704
  Args:
695
705
  workspace_id: Workspace ID or name
696
706
  command: Command to execute
697
707
  timeout_seconds: Execution timeout (default: 300, from config)
708
+ routing: Routing hint - "auto", "gpu", "cpu", or "baremetal" (default: auto)
698
709
 
699
710
  Returns:
700
711
  Exit code (0 = success, non-zero = failure)
@@ -710,10 +721,14 @@ def exec_command(
710
721
  # Base64 encode command to avoid escaping issues
711
722
  command_b64 = base64.b64encode(command.encode("utf-8")).decode("utf-8")
712
723
 
713
- request_body: dict = {"command_b64": command_b64}
724
+ request_body: dict = {"command_b64": command_b64, "pull_image": pull_image}
714
725
  if timeout_seconds:
715
726
  request_body["timeout_seconds"] = timeout_seconds
716
727
 
728
+ # Add routing hint if specified
729
+ if routing:
730
+ request_body["requirements"] = {"routing": routing}
731
+
717
732
  try:
718
733
  # Use streaming request for SSE output
719
734
  with httpx.Client(timeout=None, headers=headers) as client:
@@ -736,6 +751,11 @@ def exec_command(
736
751
 
737
752
  event = _parse_sse_content(line[6:])
738
753
 
754
+ # Handle sync events - display status to stderr
755
+ if event.sync_event:
756
+ _handle_sync_event(event.sync_event)
757
+ continue
758
+
739
759
  if event.output is not None:
740
760
  print(event.output, file=sys.stderr if event.is_error else sys.stdout)
741
761
 
@@ -751,3 +771,83 @@ def exec_command(
751
771
  ) from e
752
772
  except httpx.RequestError as e:
753
773
  raise RuntimeError(f"Could not reach API: {e}") from e
774
+
775
+
776
+ def exec_command_capture(
777
+ workspace_id: str,
778
+ command: str,
779
+ timeout_seconds: int | None = None,
780
+ routing: str | None = None,
781
+ pull_image: bool = False,
782
+ ) -> tuple[int, str]:
783
+ """Execute a command in workspace and capture output.
784
+
785
+ Similar to exec_command but returns output as string instead of printing.
786
+
787
+ Args:
788
+ workspace_id: Workspace ID or name
789
+ command: Command to execute
790
+ timeout_seconds: Execution timeout (default: 300)
791
+ routing: Routing hint - "auto", "gpu", "cpu", or "baremetal"
792
+
793
+ Returns:
794
+ Tuple of (exit_code, output_string)
795
+ """
796
+ import base64
797
+
798
+ assert workspace_id, "Workspace ID must be non-empty"
799
+ assert command, "Command must be non-empty"
800
+
801
+ api_url, headers = _get_client()
802
+
803
+ # Base64 encode command to avoid escaping issues
804
+ command_b64 = base64.b64encode(command.encode("utf-8")).decode("utf-8")
805
+
806
+ request_body: dict = {"command_b64": command_b64, "pull_image": pull_image}
807
+ if timeout_seconds:
808
+ request_body["timeout_seconds"] = timeout_seconds
809
+
810
+ if routing:
811
+ request_body["requirements"] = {"routing": routing}
812
+
813
+ output_lines: list[str] = []
814
+
815
+ try:
816
+ with httpx.Client(timeout=None, headers=headers) as client:
817
+ with client.stream(
818
+ "POST",
819
+ f"{api_url}/v1/workspaces/{workspace_id}/exec",
820
+ json=request_body,
821
+ ) as response:
822
+ if response.status_code != 200:
823
+ error_body = response.read().decode("utf-8", errors="replace")
824
+ raise RuntimeError(
825
+ _friendly_error(response.status_code, error_body, workspace_id)
826
+ )
827
+
828
+ exit_code = 0
829
+ for line in response.iter_lines():
830
+ if not line or not line.startswith("data: "):
831
+ continue
832
+
833
+ event = _parse_sse_content(line[6:])
834
+
835
+ # Skip sync events
836
+ if event.sync_event:
837
+ continue
838
+
839
+ if event.output is not None:
840
+ output_lines.append(event.output)
841
+
842
+ if event.exit_code is not None:
843
+ exit_code = event.exit_code
844
+ break
845
+
846
+ return exit_code, "\n".join(output_lines)
847
+
848
+ except httpx.HTTPStatusError as e:
849
+ raise RuntimeError(
850
+ _friendly_error(e.response.status_code, e.response.text, workspace_id)
851
+ ) from e
852
+ except httpx.RequestError as e:
853
+ raise RuntimeError(f"Could not reach API: {e}") from e
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wafer-cli
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: CLI tool for running commands on remote GPUs and GPU kernel optimization agent
5
5
  Requires-Python: >=3.11
6
6
  Requires-Dist: typer>=0.12.0