wafer-cli 0.2.14__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +42 -0
- wafer/auth.py +7 -0
- wafer/billing.py +6 -6
- wafer/cli.py +905 -131
- wafer/cli_instructions.py +143 -0
- wafer/corpus.py +313 -15
- wafer/evaluate.py +480 -146
- wafer/global_config.py +13 -0
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +22 -6
- wafer/specs_cli.py +157 -0
- wafer/ssh_keys.py +6 -6
- wafer/targets_cli.py +472 -0
- wafer/targets_ops.py +29 -2
- wafer/templates/ask_docs.py +1 -1
- wafer/templates/optimize_kernel.py +3 -1
- wafer/templates/optimize_kernelbench.py +17 -62
- wafer/templates/trace_analyze.py +1 -1
- wafer/tests/test_eval_cli_parity.py +199 -0
- wafer/trace_compare.py +274 -0
- wafer/wevin_cli.py +125 -26
- wafer/workspaces.py +163 -16
- wafer_cli-0.2.30.dist-info/METADATA +107 -0
- wafer_cli-0.2.30.dist-info/RECORD +47 -0
- wafer_cli-0.2.14.dist-info/METADATA +0 -16
- wafer_cli-0.2.14.dist-info/RECORD +0 -41
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/top_level.txt +0 -0
wafer/wevin_cli.py
CHANGED
|
@@ -15,6 +15,8 @@ from pathlib import Path
|
|
|
15
15
|
from typing import TYPE_CHECKING
|
|
16
16
|
|
|
17
17
|
if TYPE_CHECKING:
|
|
18
|
+
from collections.abc import Awaitable, Callable
|
|
19
|
+
|
|
18
20
|
from wafer_core.rollouts import Endpoint, Environment
|
|
19
21
|
from wafer_core.rollouts.dtypes import StreamEvent, ToolCall
|
|
20
22
|
from wafer_core.rollouts.templates import TemplateConfig
|
|
@@ -145,21 +147,60 @@ class StreamingChunkFrontend:
|
|
|
145
147
|
pass
|
|
146
148
|
|
|
147
149
|
|
|
148
|
-
def
|
|
150
|
+
def _make_wafer_token_refresh() -> Callable[[], Awaitable[str | None]]:
    """Create an async callback that refreshes the wafer proxy token via Supabase.

    The returned coroutine function loads the stored credentials, exchanges the
    stored refresh token for a new access/refresh token pair, persists the new
    pair with ``save_credentials``, and returns the new access token.

    It never raises: it returns ``None`` when no refresh token is stored or when
    any step of the refresh fails, so the caller can treat ``None`` as
    "could not refresh".
    """
    import asyncio
    import logging

    from .auth import load_credentials, refresh_access_token, save_credentials

    def _refresh_sync() -> str | None:
        # Synchronous core of the refresh; run off the event loop (see below).
        creds = load_credentials()
        if not creds or not creds.refresh_token:
            return None
        new_access, new_refresh = refresh_access_token(creds.refresh_token)
        save_credentials(new_access, new_refresh, creds.email)
        return new_access

    async def _refresh() -> str | None:
        try:
            # The auth helpers are synchronous (credential file I/O and,
            # presumably, an HTTP round-trip to Supabase). Running them in a
            # worker thread keeps the event loop that drives the agent session
            # responsive while the token is refreshed.
            return await asyncio.to_thread(_refresh_sync)
        except Exception:
            # Best-effort: a failed refresh must not crash the session, but keep
            # the cause in the debug log so auth problems are diagnosable.
            logging.getLogger(__name__).warning("wafer token refresh failed", exc_info=True)
            return None

    return _refresh
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _get_wafer_auth(
|
|
169
|
+
*, no_proxy: bool = False
|
|
170
|
+
) -> tuple[str | None, str | None, Callable[[], Awaitable[str | None]] | None]:
|
|
149
171
|
"""Get wafer auth credentials with fallback chain.
|
|
150
172
|
|
|
151
173
|
Returns:
|
|
152
|
-
(api_base, api_key) or (None, None) if no auth found
|
|
174
|
+
(api_base, api_key, api_key_refresh) or (None, None, None) if no auth found.
|
|
175
|
+
api_key_refresh is an async callback for mid-session token refresh (only set
|
|
176
|
+
when using wafer proxy via credentials file).
|
|
153
177
|
"""
|
|
154
178
|
from .auth import get_valid_token, load_credentials
|
|
155
179
|
from .global_config import get_api_url
|
|
156
180
|
|
|
181
|
+
if no_proxy:
|
|
182
|
+
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
|
183
|
+
if not api_key:
|
|
184
|
+
# Try auth.json stored key
|
|
185
|
+
from wafer_core.auth import get_api_key
|
|
186
|
+
|
|
187
|
+
api_key = get_api_key("anthropic") or ""
|
|
188
|
+
if api_key:
|
|
189
|
+
print("🔑 Using ANTHROPIC_API_KEY (--no-proxy)\n", file=sys.stderr)
|
|
190
|
+
return "https://api.anthropic.com", api_key, None
|
|
191
|
+
print(
|
|
192
|
+
"❌ --no-proxy requires ANTHROPIC_API_KEY env var or `wafer auth login anthropic`\n",
|
|
193
|
+
file=sys.stderr,
|
|
194
|
+
)
|
|
195
|
+
return None, None, None
|
|
196
|
+
|
|
157
197
|
# Check WAFER_AUTH_TOKEN env var first
|
|
158
198
|
wafer_token = os.environ.get("WAFER_AUTH_TOKEN", "")
|
|
159
199
|
token_source = "WAFER_AUTH_TOKEN" if wafer_token else None
|
|
160
200
|
|
|
161
201
|
# Try credentials file with automatic refresh
|
|
162
202
|
had_credentials = False
|
|
203
|
+
uses_credentials_file = False
|
|
163
204
|
if not wafer_token:
|
|
164
205
|
try:
|
|
165
206
|
creds = load_credentials()
|
|
@@ -169,12 +210,16 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
|
|
|
169
210
|
wafer_token = get_valid_token()
|
|
170
211
|
if wafer_token:
|
|
171
212
|
token_source = "~/.wafer/credentials.json"
|
|
213
|
+
uses_credentials_file = True
|
|
172
214
|
|
|
173
215
|
# If we have a valid wafer token, use it
|
|
174
216
|
if wafer_token:
|
|
175
217
|
api_url = get_api_url()
|
|
176
218
|
print(f"🔑 Using wafer proxy ({token_source})\n", file=sys.stderr)
|
|
177
|
-
|
|
219
|
+
# Only provide refresh callback when token came from credentials file
|
|
220
|
+
# (env var tokens are managed externally)
|
|
221
|
+
refresh = _make_wafer_token_refresh() if uses_credentials_file else None
|
|
222
|
+
return f"{api_url}/v1/anthropic", wafer_token, refresh
|
|
178
223
|
|
|
179
224
|
# Fall back to direct anthropic
|
|
180
225
|
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
|
@@ -186,9 +231,9 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
|
|
|
186
231
|
)
|
|
187
232
|
else:
|
|
188
233
|
print("🔑 Using ANTHROPIC_API_KEY\n", file=sys.stderr)
|
|
189
|
-
return "https://api.anthropic.com", api_key
|
|
234
|
+
return "https://api.anthropic.com", api_key, None
|
|
190
235
|
|
|
191
|
-
return None, None
|
|
236
|
+
return None, None, None
|
|
192
237
|
|
|
193
238
|
|
|
194
239
|
def _get_session_preview(session: object) -> str:
|
|
@@ -205,10 +250,22 @@ def _get_session_preview(session: object) -> str:
|
|
|
205
250
|
return ""
|
|
206
251
|
|
|
207
252
|
|
|
253
|
+
def _get_log_file_path() -> Path:
|
|
254
|
+
"""Get user-specific log file path, creating directory if needed.
|
|
255
|
+
|
|
256
|
+
Uses ~/.wafer/logs/ to avoid permission issues with shared /tmp.
|
|
257
|
+
"""
|
|
258
|
+
log_dir = Path.home() / ".wafer" / "logs"
|
|
259
|
+
log_dir.mkdir(parents=True, exist_ok=True)
|
|
260
|
+
return log_dir / "wevin_debug.log"
|
|
261
|
+
|
|
262
|
+
|
|
208
263
|
def _setup_logging() -> None:
|
|
209
264
|
"""Configure logging to file only (no console spam)."""
|
|
210
265
|
import logging.config
|
|
211
266
|
|
|
267
|
+
log_file = _get_log_file_path()
|
|
268
|
+
|
|
212
269
|
logging.config.dictConfig({
|
|
213
270
|
"version": 1,
|
|
214
271
|
"disable_existing_loggers": False,
|
|
@@ -220,7 +277,7 @@ def _setup_logging() -> None:
|
|
|
220
277
|
"handlers": {
|
|
221
278
|
"file": {
|
|
222
279
|
"class": "logging.handlers.RotatingFileHandler",
|
|
223
|
-
"filename":
|
|
280
|
+
"filename": str(log_file),
|
|
224
281
|
"maxBytes": 10_000_000,
|
|
225
282
|
"backupCount": 3,
|
|
226
283
|
"formatter": "json",
|
|
@@ -243,6 +300,7 @@ def _build_endpoint(
|
|
|
243
300
|
model_override: str | None,
|
|
244
301
|
api_base: str,
|
|
245
302
|
api_key: str,
|
|
303
|
+
api_key_refresh: Callable[[], Awaitable[str | None]] | None = None,
|
|
246
304
|
) -> Endpoint:
|
|
247
305
|
"""Build an Endpoint from template config and auth."""
|
|
248
306
|
from wafer_core.rollouts import Endpoint
|
|
@@ -257,6 +315,7 @@ def _build_endpoint(
|
|
|
257
315
|
model=model_id,
|
|
258
316
|
api_base=api_base,
|
|
259
317
|
api_key=api_key,
|
|
318
|
+
api_key_refresh=api_key_refresh,
|
|
260
319
|
thinking=thinking_config,
|
|
261
320
|
max_tokens=tpl.max_tokens,
|
|
262
321
|
)
|
|
@@ -266,18 +325,27 @@ def _build_environment(
|
|
|
266
325
|
tpl: TemplateConfig,
|
|
267
326
|
tools_override: list[str] | None,
|
|
268
327
|
corpus_path: str | None,
|
|
328
|
+
no_sandbox: bool = False,
|
|
269
329
|
) -> Environment:
|
|
270
330
|
"""Build a CodingEnvironment from template config."""
|
|
271
331
|
from wafer_core.environments.coding import CodingEnvironment
|
|
272
332
|
from wafer_core.rollouts.templates import DANGEROUS_BASH_COMMANDS
|
|
333
|
+
from wafer_core.sandbox import SandboxMode
|
|
273
334
|
|
|
274
335
|
working_dir = Path(corpus_path) if corpus_path else Path.cwd()
|
|
275
|
-
resolved_tools = tools_override or tpl.tools
|
|
336
|
+
resolved_tools = list(tools_override or tpl.tools)
|
|
337
|
+
|
|
338
|
+
# Add skill tool if skills are enabled
|
|
339
|
+
if tpl.include_skills and "skill" not in resolved_tools:
|
|
340
|
+
resolved_tools.append("skill")
|
|
341
|
+
|
|
342
|
+
sandbox_mode = SandboxMode.DISABLED if no_sandbox else SandboxMode.ENABLED
|
|
276
343
|
env: Environment = CodingEnvironment(
|
|
277
344
|
working_dir=working_dir,
|
|
278
345
|
enabled_tools=resolved_tools,
|
|
279
346
|
bash_allowlist=tpl.bash_allowlist,
|
|
280
347
|
bash_denylist=DANGEROUS_BASH_COMMANDS,
|
|
348
|
+
sandbox_mode=sandbox_mode,
|
|
281
349
|
) # type: ignore[assignment]
|
|
282
350
|
return env
|
|
283
351
|
|
|
@@ -362,6 +430,8 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
362
430
|
list_sessions: bool = False,
|
|
363
431
|
get_session: str | None = None,
|
|
364
432
|
json_output: bool = False,
|
|
433
|
+
no_sandbox: bool = False,
|
|
434
|
+
no_proxy: bool = False,
|
|
365
435
|
) -> None:
|
|
366
436
|
"""Run wevin agent in-process via rollouts."""
|
|
367
437
|
from dataclasses import asdict
|
|
@@ -373,6 +443,7 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
373
443
|
|
|
374
444
|
# Handle --get-session: load session by ID and print
|
|
375
445
|
if get_session:
|
|
446
|
+
|
|
376
447
|
async def _get_session() -> None:
|
|
377
448
|
try:
|
|
378
449
|
session, err = await session_store.get(get_session)
|
|
@@ -393,16 +464,18 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
393
464
|
error_msg = f"Failed to serialize messages: {e}"
|
|
394
465
|
print(json.dumps({"error": error_msg}))
|
|
395
466
|
sys.exit(1)
|
|
396
|
-
|
|
397
|
-
print(
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
467
|
+
|
|
468
|
+
print(
|
|
469
|
+
json.dumps({
|
|
470
|
+
"session_id": session.session_id,
|
|
471
|
+
"status": session.status.value,
|
|
472
|
+
"model": session.endpoint.model if session.endpoint else None,
|
|
473
|
+
"created_at": session.created_at,
|
|
474
|
+
"updated_at": session.updated_at,
|
|
475
|
+
"messages": messages_data,
|
|
476
|
+
"tags": session.tags,
|
|
477
|
+
})
|
|
478
|
+
)
|
|
406
479
|
else:
|
|
407
480
|
print(f"Session: {session.session_id}")
|
|
408
481
|
print(f"Status: {session.status.value}")
|
|
@@ -474,10 +547,10 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
474
547
|
_setup_logging()
|
|
475
548
|
|
|
476
549
|
# Auth
|
|
477
|
-
api_base, api_key = _get_wafer_auth()
|
|
550
|
+
api_base, api_key, api_key_refresh = _get_wafer_auth(no_proxy=no_proxy)
|
|
478
551
|
if not api_base or not api_key:
|
|
479
552
|
print("Error: No API credentials found", file=sys.stderr)
|
|
480
|
-
print(" Run 'wafer login' or set ANTHROPIC_API_KEY", file=sys.stderr)
|
|
553
|
+
print(" Run 'wafer auth login' or set ANTHROPIC_API_KEY", file=sys.stderr)
|
|
481
554
|
sys.exit(1)
|
|
482
555
|
|
|
483
556
|
assert api_base is not None
|
|
@@ -490,7 +563,7 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
490
563
|
print(f"Error loading template: {err}", file=sys.stderr)
|
|
491
564
|
sys.exit(1)
|
|
492
565
|
tpl = loaded_template
|
|
493
|
-
|
|
566
|
+
base_system_prompt = tpl.interpolate_prompt(template_args or {})
|
|
494
567
|
# Show template info when starting without a prompt
|
|
495
568
|
if not prompt and tpl.description:
|
|
496
569
|
print(f"Template: {tpl.name}", file=sys.stderr)
|
|
@@ -498,14 +571,38 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
498
571
|
print(file=sys.stderr)
|
|
499
572
|
else:
|
|
500
573
|
tpl = _get_default_template()
|
|
501
|
-
|
|
574
|
+
base_system_prompt = tpl.system_prompt
|
|
575
|
+
|
|
576
|
+
# Compose CLI instructions from --help text for allowed wafer commands
|
|
577
|
+
# TODO: The eval path doesn't have the skills layer below. If include_skills
|
|
578
|
+
# is ever enabled for optimize-kernelbench, the eval would need it too for parity.
|
|
579
|
+
# See test_eval_cli_parity.py for coverage notes.
|
|
580
|
+
if tpl.bash_allowlist:
|
|
581
|
+
from wafer.cli_instructions import build_cli_instructions
|
|
582
|
+
|
|
583
|
+
cli_instructions = build_cli_instructions(tpl.bash_allowlist)
|
|
584
|
+
if cli_instructions:
|
|
585
|
+
base_system_prompt = base_system_prompt + "\n\n" + cli_instructions
|
|
586
|
+
|
|
587
|
+
# Append skill metadata if skills are enabled
|
|
588
|
+
if tpl.include_skills:
|
|
589
|
+
from wafer_core.rollouts.skills import discover_skills, format_skill_metadata_for_prompt
|
|
590
|
+
|
|
591
|
+
skill_metadata = discover_skills()
|
|
592
|
+
if skill_metadata:
|
|
593
|
+
skill_section = format_skill_metadata_for_prompt(skill_metadata)
|
|
594
|
+
system_prompt = base_system_prompt + "\n\n" + skill_section
|
|
595
|
+
else:
|
|
596
|
+
system_prompt = base_system_prompt
|
|
597
|
+
else:
|
|
598
|
+
system_prompt = base_system_prompt
|
|
502
599
|
|
|
503
600
|
# CLI args override template values
|
|
504
601
|
resolved_single_turn = single_turn if single_turn is not None else tpl.single_turn
|
|
505
602
|
|
|
506
603
|
# Build endpoint and environment
|
|
507
|
-
endpoint = _build_endpoint(tpl, model, api_base, api_key)
|
|
508
|
-
environment = _build_environment(tpl, tools, corpus_path)
|
|
604
|
+
endpoint = _build_endpoint(tpl, model, api_base, api_key, api_key_refresh)
|
|
605
|
+
environment = _build_environment(tpl, tools, corpus_path, no_sandbox)
|
|
509
606
|
|
|
510
607
|
# Session store
|
|
511
608
|
session_store = FileSessionStore()
|
|
@@ -545,7 +642,7 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
545
642
|
else:
|
|
546
643
|
if json_output:
|
|
547
644
|
# Emit session_start if we have a session_id (from --resume)
|
|
548
|
-
model_name = endpoint.model if hasattr(endpoint,
|
|
645
|
+
model_name = endpoint.model if hasattr(endpoint, "model") else None
|
|
549
646
|
frontend = StreamingChunkFrontend(session_id=session_id, model=model_name)
|
|
550
647
|
else:
|
|
551
648
|
frontend = NoneFrontend(show_tool_calls=True, show_thinking=False)
|
|
@@ -560,9 +657,11 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
560
657
|
# Emit session_start for new sessions (if session_id was None and we got one)
|
|
561
658
|
# Check first state to emit as early as possible
|
|
562
659
|
if json_output and isinstance(frontend, StreamingChunkFrontend):
|
|
563
|
-
first_session_id =
|
|
660
|
+
first_session_id = (
|
|
661
|
+
states[0].session_id if states and states[0].session_id else None
|
|
662
|
+
)
|
|
564
663
|
if first_session_id and not session_id: # New session created
|
|
565
|
-
model_name = endpoint.model if hasattr(endpoint,
|
|
664
|
+
model_name = endpoint.model if hasattr(endpoint, "model") else None
|
|
566
665
|
frontend.emit_session_start(first_session_id, model_name)
|
|
567
666
|
# Print resume command with full wafer agent prefix
|
|
568
667
|
if states and states[-1].session_id:
|
wafer/workspaces.py
CHANGED
|
@@ -13,7 +13,7 @@ import httpx
|
|
|
13
13
|
from .api_client import get_api_url
|
|
14
14
|
from .auth import get_auth_headers
|
|
15
15
|
|
|
16
|
-
VALID_STATUSES = {"creating", "running"}
|
|
16
|
+
# Workspace statuses this client accepts from the API; sync/pull assert that the
# reported status is one of these. "error" means workspace provisioning failed.
VALID_STATUSES = {"creating", "running", "error"}
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def _get_client() -> tuple[str, dict[str, str]]:
|
|
@@ -39,13 +39,13 @@ def _friendly_error(status_code: int, response_text: str, workspace_id: str) ->
|
|
|
39
39
|
User-friendly error message with suggested next steps
|
|
40
40
|
"""
|
|
41
41
|
if status_code == 401:
|
|
42
|
-
return "Not authenticated. Run: wafer login"
|
|
42
|
+
return "Not authenticated. Run: wafer auth login"
|
|
43
43
|
|
|
44
44
|
if status_code == 402:
|
|
45
45
|
return (
|
|
46
46
|
"Insufficient credits.\n"
|
|
47
|
-
" Check usage: wafer billing\n"
|
|
48
|
-
" Add credits: wafer billing topup"
|
|
47
|
+
" Check usage: wafer config billing\n"
|
|
48
|
+
" Add credits: wafer config billing topup"
|
|
49
49
|
)
|
|
50
50
|
|
|
51
51
|
if status_code == 404:
|
|
@@ -107,7 +107,7 @@ def _list_workspaces_raw() -> list[dict]:
|
|
|
107
107
|
workspaces = response.json()
|
|
108
108
|
except httpx.HTTPStatusError as e:
|
|
109
109
|
if e.response.status_code == 401:
|
|
110
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
110
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
111
111
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|
|
112
112
|
except httpx.RequestError as e:
|
|
113
113
|
raise RuntimeError(f"Could not reach API: {e}") from e
|
|
@@ -188,7 +188,7 @@ def list_workspaces(json_output: bool = False) -> str:
|
|
|
188
188
|
workspaces = response.json()
|
|
189
189
|
except httpx.HTTPStatusError as e:
|
|
190
190
|
if e.response.status_code == 401:
|
|
191
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
191
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
192
192
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|
|
193
193
|
except httpx.RequestError as e:
|
|
194
194
|
raise RuntimeError(f"Could not reach API: {e}") from e
|
|
@@ -211,17 +211,39 @@ def list_workspaces(json_output: bool = False) -> str:
|
|
|
211
211
|
lines = ["Workspaces:", ""]
|
|
212
212
|
for ws in workspaces:
|
|
213
213
|
status = ws.get("status", "unknown")
|
|
214
|
-
status_icon = {"running": "●", "creating": "◐"}.get(status, "?")
|
|
214
|
+
status_icon = {"running": "●", "creating": "◐", "error": "✗"}.get(status, "?")
|
|
215
215
|
lines.append(f" {status_icon} {ws['name']} ({ws['id']})")
|
|
216
216
|
lines.append(f" GPU: {ws.get('gpu_type', 'N/A')} | Image: {ws.get('image', 'N/A')}")
|
|
217
|
-
|
|
217
|
+
|
|
218
|
+
if status == "error":
|
|
219
|
+
lines.append(
|
|
220
|
+
f" Status: Provisioning failed. Delete and recreate: wafer workspaces delete {ws['name']}"
|
|
221
|
+
)
|
|
222
|
+
elif ws.get("ssh_host") and ws.get("ssh_port") and ws.get("ssh_user"):
|
|
223
|
+
ssh_line = f" SSH: ssh -p {ws['ssh_port']} {ws['ssh_user']}@{ws['ssh_host']}"
|
|
224
|
+
if status == "creating":
|
|
225
|
+
ssh_line += " (finalizing...)"
|
|
226
|
+
lines.append(ssh_line)
|
|
227
|
+
elif status == "running":
|
|
218
228
|
lines.append(
|
|
219
|
-
f" SSH:
|
|
229
|
+
f" Status: Running but SSH not ready. Try: wafer workspaces delete {ws['name']} && wafer workspaces create {ws['name']} --wait"
|
|
220
230
|
)
|
|
221
231
|
else:
|
|
222
|
-
lines.append(" SSH: Not ready (
|
|
232
|
+
lines.append(" SSH: Not ready (workspace is still creating)")
|
|
223
233
|
lines.append("")
|
|
224
234
|
|
|
235
|
+
# Add SSH tip for users with running workspaces
|
|
236
|
+
has_running_with_ssh = any(
|
|
237
|
+
ws.get("status") == "running" and ws.get("ssh_host")
|
|
238
|
+
for ws in workspaces
|
|
239
|
+
)
|
|
240
|
+
if has_running_with_ssh:
|
|
241
|
+
lines.append("Tip: SSH directly for interactive work. 'exec' is for quick commands only.")
|
|
242
|
+
|
|
243
|
+
has_error = any(ws.get("status") == "error" for ws in workspaces)
|
|
244
|
+
if has_error:
|
|
245
|
+
lines.append("Note: Error workspaces are auto-cleaned after 12 hours.")
|
|
246
|
+
|
|
225
247
|
return "\n".join(lines)
|
|
226
248
|
|
|
227
249
|
|
|
@@ -285,7 +307,7 @@ def create_workspace(
|
|
|
285
307
|
workspace = response.json()
|
|
286
308
|
except httpx.HTTPStatusError as e:
|
|
287
309
|
if e.response.status_code == 401:
|
|
288
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
310
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
289
311
|
if e.response.status_code == 400:
|
|
290
312
|
raise RuntimeError(f"Bad request: {e.response.text}") from e
|
|
291
313
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|
|
@@ -391,7 +413,7 @@ def delete_workspace(workspace_id: str, json_output: bool = False) -> str:
|
|
|
391
413
|
result = response.json()
|
|
392
414
|
except httpx.HTTPStatusError as e:
|
|
393
415
|
if e.response.status_code == 401:
|
|
394
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
416
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
395
417
|
if e.response.status_code == 404:
|
|
396
418
|
raise RuntimeError(f"Workspace not found: {workspace_id}") from e
|
|
397
419
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|
|
@@ -441,6 +463,12 @@ def sync_files(
|
|
|
441
463
|
f"Workspace {workspace_id} has invalid status '{workspace_status}'. "
|
|
442
464
|
f"Valid statuses: {VALID_STATUSES}"
|
|
443
465
|
)
|
|
466
|
+
if workspace_status == "error":
|
|
467
|
+
raise RuntimeError(
|
|
468
|
+
f"Workspace provisioning failed. Delete and recreate:\n"
|
|
469
|
+
f" wafer workspaces delete {workspace_id}\n"
|
|
470
|
+
f" wafer workspaces create {ws.get('name', workspace_id)} --wait"
|
|
471
|
+
)
|
|
444
472
|
if workspace_status != "running":
|
|
445
473
|
raise RuntimeError(
|
|
446
474
|
f"Workspace is {workspace_status}. Wait for it to be running before syncing."
|
|
@@ -448,9 +476,14 @@ def sync_files(
|
|
|
448
476
|
ssh_host = ws.get("ssh_host")
|
|
449
477
|
ssh_port = ws.get("ssh_port")
|
|
450
478
|
ssh_user = ws.get("ssh_user")
|
|
451
|
-
|
|
479
|
+
if not ssh_host or not ssh_port or not ssh_user:
|
|
480
|
+
# Workspace is running but SSH credentials are missing - unusual state
|
|
481
|
+
raise RuntimeError(
|
|
482
|
+
f"Workspace is running but SSH not ready.\n"
|
|
483
|
+
f" Delete and recreate: wafer workspaces delete {workspace_id}\n"
|
|
484
|
+
f" Then: wafer workspaces create {ws.get('name', workspace_id)} --wait"
|
|
485
|
+
)
|
|
452
486
|
assert isinstance(ssh_port, int) and ssh_port > 0, "Workspace missing valid ssh_port"
|
|
453
|
-
assert ssh_user, "Workspace missing ssh_user"
|
|
454
487
|
|
|
455
488
|
# Build rsync command
|
|
456
489
|
# -a: archive mode (preserves permissions, etc.)
|
|
@@ -509,6 +542,102 @@ def sync_files(
|
|
|
509
542
|
return file_count, warning
|
|
510
543
|
|
|
511
544
|
|
|
545
|
+
def pull_files(
    workspace_id: str,
    remote_path: str,
    local_path: Path,
    on_progress: Callable[[str], None] | None = None,
) -> int:
    """Pull files from workspace to local via rsync over SSH.

    Args:
        workspace_id: Workspace ID or name
        remote_path: Remote path in workspace (relative to /workspace or absolute)
        local_path: Local destination path
        on_progress: Optional callback for progress messages

    Returns:
        Number of files transferred, as listed by ``rsync -v``. Directory
        entries and rsync's own header/summary lines are not counted.

    Raises:
        RuntimeError: If the workspace is in the error state, is not running,
            has no SSH details, rsync is not installed, or rsync exits non-zero.
    """
    import re
    import subprocess

    def emit(msg: str) -> None:
        if on_progress:
            on_progress(msg)

    assert workspace_id, "Workspace ID must be non-empty"

    ws = get_workspace_raw(workspace_id)
    workspace_status = ws.get("status")
    assert workspace_status in VALID_STATUSES, (
        f"Workspace {workspace_id} has invalid status '{workspace_status}'. "
        f"Valid statuses: {VALID_STATUSES}"
    )
    if workspace_status == "error":
        raise RuntimeError(
            "Workspace provisioning failed. Delete and recreate:\n"
            f" wafer workspaces delete {workspace_id}\n"
            f" wafer workspaces create {ws.get('name', workspace_id)} --wait"
        )
    if workspace_status != "running":
        raise RuntimeError(
            f"Workspace is {workspace_status}. Wait for it to be running before pulling files."
        )

    ssh_host = ws.get("ssh_host")
    ssh_port = ws.get("ssh_port")
    ssh_user = ws.get("ssh_user")
    if not ssh_host or not ssh_port or not ssh_user:
        raise RuntimeError(
            "Workspace is running but SSH not ready.\n"
            f" Delete and recreate: wafer workspaces delete {workspace_id}\n"
            f" Then: wafer workspaces create {ws.get('name', workspace_id)} --wait"
        )
    assert isinstance(ssh_port, int) and ssh_port > 0, "Workspace missing valid ssh_port"

    # Normalize remote path - if not absolute, assume relative to /workspace
    if not remote_path.startswith("/"):
        remote_path = f"/workspace/{remote_path}"

    # NOTE(review): host-key verification is disabled (nothing is read from or
    # written to known_hosts). This is presumably acceptable because workspace
    # hosts are ephemeral, but the connection is not protected against MITM.
    ssh_opts = f"-p {ssh_port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"

    # Reverse of sync_files: remote source, local destination.
    rsync_cmd = [
        "rsync",
        "-avz",
        "-e",
        f"ssh {ssh_opts}",
        f"{ssh_user}@{ssh_host}:{remote_path}",
        str(local_path),
    ]

    emit(f"Pulling {remote_path} from workspace...")

    try:
        result = subprocess.run(rsync_cmd, capture_output=True, text=True)
    except FileNotFoundError:
        raise RuntimeError("rsync not found. Install rsync to use pull feature.") from None
    except subprocess.SubprocessError as e:
        raise RuntimeError(f"Pull failed: {e}") from e

    if result.returncode != 0:
        raise RuntimeError(f"rsync failed: {result.stderr}")

    # `rsync -v` lists one transferred entry per line, framed by non-file lines:
    # a file-list header ("receiving incremental file list"), an optional
    # "created directory <dest>" when the destination is new, and a trailing
    # "sent N bytes ..." / "total size is ..." summary. Match those lines
    # precisely rather than by loose prefix, so files such as "total.csv" or
    # "building.log" are still counted. Skip directory entries (trailing "/")
    # so only files are reported.
    non_file_line = re.compile(
        r"(?:receiving (?:incremental )?file list|building file list"
        r"|created directory |sent [\d,.]+ bytes|total size is )"
    )
    file_count = sum(
        1
        for line in result.stdout.splitlines()
        if line
        and not line.startswith(" ")
        and not line.endswith("/")
        and not non_file_line.match(line)
    )

    emit(f"Pulled {file_count} files")
    return file_count
|
|
639
|
+
|
|
640
|
+
|
|
512
641
|
def _init_sync_state(workspace_id: str) -> str | None:
|
|
513
642
|
"""Tell API to sync files from bare metal to Modal volume.
|
|
514
643
|
|
|
@@ -562,7 +691,7 @@ def get_workspace_raw(workspace_id: str) -> dict:
|
|
|
562
691
|
workspace = response.json()
|
|
563
692
|
except httpx.HTTPStatusError as e:
|
|
564
693
|
if e.response.status_code == 401:
|
|
565
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
694
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
566
695
|
if e.response.status_code == 404:
|
|
567
696
|
raise RuntimeError(f"Workspace not found: {workspace_id}") from e
|
|
568
697
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|
|
@@ -607,16 +736,34 @@ def get_workspace(workspace_id: str, json_output: bool = False) -> str:
|
|
|
607
736
|
f" Last Used: {workspace.get('last_used_at', 'N/A')}",
|
|
608
737
|
]
|
|
609
738
|
|
|
610
|
-
if
|
|
739
|
+
if status == "error":
|
|
740
|
+
lines.extend([
|
|
741
|
+
"",
|
|
742
|
+
"Provisioning failed. Delete and recreate:",
|
|
743
|
+
f" wafer workspaces delete {workspace['name']}",
|
|
744
|
+
f" wafer workspaces create {workspace['name']} --wait",
|
|
745
|
+
"",
|
|
746
|
+
"Note: Error workspaces are auto-cleaned after 12 hours.",
|
|
747
|
+
])
|
|
748
|
+
elif workspace.get("ssh_host"):
|
|
611
749
|
lines.extend([
|
|
612
750
|
"",
|
|
613
751
|
"SSH Info:",
|
|
614
752
|
f" Host: {workspace['ssh_host']}",
|
|
615
753
|
f" Port: {workspace.get('ssh_port', 22)}",
|
|
616
754
|
f" User: {workspace.get('ssh_user', 'root')}",
|
|
755
|
+
"",
|
|
756
|
+
"Tip: SSH directly for interactive work. 'exec' is for quick commands only.",
|
|
617
757
|
])
|
|
618
758
|
elif status == "creating":
|
|
619
759
|
lines.extend(["", "SSH: available once workspace is running"])
|
|
760
|
+
elif status == "running":
|
|
761
|
+
# Running but no SSH credentials - unusual state
|
|
762
|
+
lines.extend([
|
|
763
|
+
"",
|
|
764
|
+
"Status: Running but SSH not ready.",
|
|
765
|
+
f" Delete and recreate: wafer workspaces delete {workspace['name']}",
|
|
766
|
+
])
|
|
620
767
|
|
|
621
768
|
return "\n".join(lines)
|
|
622
769
|
|