wafer-cli 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +118 -0
- wafer/__init__.py +3 -0
- wafer/analytics.py +306 -0
- wafer/api_client.py +195 -0
- wafer/auth.py +432 -0
- wafer/autotuner.py +1080 -0
- wafer/billing.py +233 -0
- wafer/cli.py +7289 -0
- wafer/config.py +105 -0
- wafer/corpus.py +366 -0
- wafer/evaluate.py +4593 -0
- wafer/global_config.py +350 -0
- wafer/gpu_run.py +307 -0
- wafer/inference.py +148 -0
- wafer/kernel_scope.py +552 -0
- wafer/ncu_analyze.py +651 -0
- wafer/nsys_analyze.py +1042 -0
- wafer/nsys_profile.py +510 -0
- wafer/output.py +248 -0
- wafer/problems.py +357 -0
- wafer/rocprof_compute.py +490 -0
- wafer/rocprof_sdk.py +274 -0
- wafer/rocprof_systems.py +520 -0
- wafer/skills/wafer-guide/SKILL.md +129 -0
- wafer/ssh_keys.py +261 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +842 -0
- wafer/targets_ops.py +717 -0
- wafer/templates/__init__.py +0 -0
- wafer/templates/ask_docs.py +61 -0
- wafer/templates/optimize_kernel.py +71 -0
- wafer/templates/optimize_kernelbench.py +137 -0
- wafer/templates/trace_analyze.py +74 -0
- wafer/tracelens.py +218 -0
- wafer/wevin_cli.py +577 -0
- wafer/workspaces.py +852 -0
- wafer_cli-0.2.14.dist-info/METADATA +16 -0
- wafer_cli-0.2.14.dist-info/RECORD +41 -0
- wafer_cli-0.2.14.dist-info/WHEEL +5 -0
- wafer_cli-0.2.14.dist-info/entry_points.txt +2 -0
- wafer_cli-0.2.14.dist-info/top_level.txt +1 -0
wafer/wevin_cli.py
ADDED
|
@@ -0,0 +1,577 @@
|
|
|
1
|
+
"""Wafer Wevin CLI - thin wrapper that calls rollouts in-process.
|
|
2
|
+
|
|
3
|
+
Adds:
|
|
4
|
+
- Wafer auth (proxy token from ~/.wafer/credentials.json)
|
|
5
|
+
- Wafer templates (ask-docs, optimize-kernel, trace-analyze)
|
|
6
|
+
- Corpus path resolution (--corpus cuda -> ~/.cache/wafer/corpora/cuda)
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import sys
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from wafer_core.rollouts import Endpoint, Environment
|
|
19
|
+
from wafer_core.rollouts.dtypes import StreamEvent, ToolCall
|
|
20
|
+
from wafer_core.rollouts.templates import TemplateConfig
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class StreamingChunkFrontend:
    """Frontend that streams NDJSON events to stdout for programmatic consumers.

    Intended for extensions/UIs rather than humans. Emitted events match the
    format expected by wevin-extension handleWevinEvent:

    - {type: 'session_start', session_id: '...', model: '...'}
    - {type: 'text_delta', delta: '...'}
    - {type: 'tool_call_start', tool_name: '...'}
    - {type: 'tool_call_end', tool_name: '...', args: {...}}
    - {type: 'tool_result', is_error: bool}
    - {type: 'session_end'}
    - {type: 'error', error: '...'}
    """

    def __init__(self, session_id: str | None = None, model: str | None = None) -> None:
        # Session identity may be unknown at construction; a fresh session's
        # id is announced later via emit_session_start().
        self._session_id = session_id
        self._model = model
        # Info from the most recent ToolCallStart (id/name).
        self._current_tool_call: dict | None = None

    def _emit(self, payload: dict) -> None:
        """Write one event as a single NDJSON line, flushed immediately."""
        line = json.dumps(payload, ensure_ascii=False)
        print(line, flush=True)

    async def start(self) -> None:
        """Emit session_start when the session id is already known (resume)."""
        if not self._session_id:
            return
        self._emit({
            "type": "session_start",
            "session_id": self._session_id,
            "model": self._model,
        })

    def emit_session_start(self, session_id: str, model: str | None = None) -> None:
        """Emit session_start for a session created during the run."""
        self._emit({
            "type": "session_start",
            "session_id": session_id,
            "model": model or self._model,
        })

    async def stop(self) -> None:
        """Emit the terminal session_end event."""
        self._emit({"type": "session_end"})

    async def handle_event(self, event: StreamEvent) -> None:
        """Translate one rollouts stream event into its JSON wire form."""
        from wafer_core.rollouts.dtypes import (
            StreamDone,
            StreamError,
            TextDelta,
            ThinkingDelta,
            ToolCallEnd,
            ToolCallStart,
            ToolResultReceived,
        )

        if isinstance(event, TextDelta):
            # Forward text immediately for real-time streaming.
            self._emit({"type": "text_delta", "delta": event.delta})
            return
        if isinstance(event, ThinkingDelta):
            # Thinking tokens are intentionally dropped (they clutter output).
            return
        if isinstance(event, ToolCallStart):
            # ToolCallStart carries flat attributes (id + name).
            self._current_tool_call = {
                "id": event.tool_call_id,
                "name": event.tool_name,
            }
            self._emit({"type": "tool_call_start", "tool_name": event.tool_name})
            return
        if isinstance(event, ToolCallEnd):
            call = event.tool_call
            self._emit({
                "type": "tool_call_end",
                "tool_name": call.name,
                "args": call.args if call.args else {},
            })
            return
        if isinstance(event, ToolResultReceived):
            result: dict = {"type": "tool_result", "is_error": event.is_error}
            if event.error:
                result["error"] = event.error
            if event.content:
                # Normalize content to a single string; list items may be
                # dicts carrying a "text" field.
                if isinstance(event.content, list):
                    parts = (
                        str(item) if not isinstance(item, dict) else item.get("text", str(item))
                        for item in event.content
                    )
                    result["content"] = "\n".join(parts)
                else:
                    result["content"] = str(event.content)
            self._emit(result)
            return
        if isinstance(event, StreamDone):
            # Terminal event; stop() emits session_end.
            return
        if isinstance(event, StreamError):
            self._emit({"type": "error", "error": str(event.error)})

    async def get_input(self, prompt: str = "") -> str:
        """Interactive input is unsupported in JSON mode; always raises."""
        raise RuntimeError(
            "StreamingChunkFrontend does not support interactive input. "
            "Use -p to provide input or use -s for single-turn mode."
        )

    async def confirm_tool(self, tool_call: ToolCall) -> bool:
        """Auto-approve every tool call (JSON mode is non-interactive)."""
        return True

    def show_loader(self, text: str) -> None:
        """No-op: loaders are meaningless for machine-readable output."""

    def hide_loader(self) -> None:
        """No-op: loaders are meaningless for machine-readable output."""
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _get_wafer_auth() -> tuple[str | None, str | None]:
    """Resolve API credentials, preferring the wafer proxy over direct Anthropic.

    Precedence:
      1. WAFER_AUTH_TOKEN environment variable
      2. ~/.wafer/credentials.json (with automatic token refresh)
      3. ANTHROPIC_API_KEY environment variable (direct Anthropic API)

    Returns:
        (api_base, api_key), or (None, None) when no credentials are found.
    """
    from .auth import get_valid_token, load_credentials
    from .global_config import get_api_url

    # Environment variable wins over the credentials file.
    token = os.environ.get("WAFER_AUTH_TOKEN", "")
    token_source = "WAFER_AUTH_TOKEN" if token else None

    # Remember whether a credentials file existed so we can warn below if it
    # turns out to be expired/invalid and we fall back to a raw API key.
    had_credentials = False
    if not token:
        try:
            creds = load_credentials()
            had_credentials = creds is not None and bool(creds.access_token)
        except Exception:
            pass
        token = get_valid_token()
        if token:
            token_source = "~/.wafer/credentials.json"

    # Valid wafer token: route through the wafer proxy.
    if token:
        api_url = get_api_url()
        print(f"🔑 Using wafer proxy ({token_source})\n", file=sys.stderr)
        return f"{api_url}/v1/anthropic", token

    # Last resort: talk to Anthropic directly.
    anthropic_key = os.environ.get("ANTHROPIC_API_KEY", "")
    if not anthropic_key:
        return None, None

    if had_credentials:
        print(
            "⚠️ Wafer credentials expired/invalid, falling back to ANTHROPIC_API_KEY\n",
            file=sys.stderr,
        )
    else:
        print("🔑 Using ANTHROPIC_API_KEY\n", file=sys.stderr)
    return "https://api.anthropic.com", anthropic_key
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _get_session_preview(session: object) -> str:
|
|
195
|
+
"""Extract first user message preview from a session."""
|
|
196
|
+
messages = getattr(session, "messages", None)
|
|
197
|
+
if not messages:
|
|
198
|
+
return ""
|
|
199
|
+
for msg in messages:
|
|
200
|
+
if msg.role == "user" and isinstance(msg.content, str):
|
|
201
|
+
preview = msg.content[:50].replace("\n", " ")
|
|
202
|
+
if len(msg.content) > 50:
|
|
203
|
+
preview += "..."
|
|
204
|
+
return preview
|
|
205
|
+
return ""
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _setup_logging() -> None:
|
|
209
|
+
"""Configure logging to file only (no console spam)."""
|
|
210
|
+
import logging.config
|
|
211
|
+
|
|
212
|
+
logging.config.dictConfig({
|
|
213
|
+
"version": 1,
|
|
214
|
+
"disable_existing_loggers": False,
|
|
215
|
+
"formatters": {
|
|
216
|
+
"json": {
|
|
217
|
+
"format": '{"ts": "%(asctime)s", "level": "%(levelname)s", "logger": "%(name)s", "msg": "%(message)s"}',
|
|
218
|
+
},
|
|
219
|
+
},
|
|
220
|
+
"handlers": {
|
|
221
|
+
"file": {
|
|
222
|
+
"class": "logging.handlers.RotatingFileHandler",
|
|
223
|
+
"filename": "/tmp/wevin_debug.log",
|
|
224
|
+
"maxBytes": 10_000_000,
|
|
225
|
+
"backupCount": 3,
|
|
226
|
+
"formatter": "json",
|
|
227
|
+
},
|
|
228
|
+
},
|
|
229
|
+
"root": {"level": "DEBUG", "handlers": ["file"]},
|
|
230
|
+
})
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _unwrap_exception(e: BaseException) -> BaseException:
|
|
234
|
+
"""Unwrap ExceptionGroup from Trio to get the actual error."""
|
|
235
|
+
actual = e
|
|
236
|
+
while isinstance(actual, ExceptionGroup) and actual.exceptions:
|
|
237
|
+
actual = actual.exceptions[0]
|
|
238
|
+
return actual
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _build_endpoint(
|
|
242
|
+
tpl: TemplateConfig,
|
|
243
|
+
model_override: str | None,
|
|
244
|
+
api_base: str,
|
|
245
|
+
api_key: str,
|
|
246
|
+
) -> Endpoint:
|
|
247
|
+
"""Build an Endpoint from template config and auth."""
|
|
248
|
+
from wafer_core.rollouts import Endpoint
|
|
249
|
+
|
|
250
|
+
resolved_model = model_override or tpl.model
|
|
251
|
+
provider, model_id = resolved_model.split("/", 1)
|
|
252
|
+
thinking_config = (
|
|
253
|
+
{"type": "enabled", "budget_tokens": tpl.thinking_budget} if tpl.thinking else None
|
|
254
|
+
)
|
|
255
|
+
return Endpoint(
|
|
256
|
+
provider=provider,
|
|
257
|
+
model=model_id,
|
|
258
|
+
api_base=api_base,
|
|
259
|
+
api_key=api_key,
|
|
260
|
+
thinking=thinking_config,
|
|
261
|
+
max_tokens=tpl.max_tokens,
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _build_environment(
    tpl: TemplateConfig,
    tools_override: list[str] | None,
    corpus_path: str | None,
) -> Environment:
    """Build a CodingEnvironment rooted at the corpus (or the current directory).

    CLI tool overrides take precedence over the template's tool list; the bash
    denylist is always the shared DANGEROUS_BASH_COMMANDS set.
    """
    from wafer_core.environments.coding import CodingEnvironment
    from wafer_core.rollouts.templates import DANGEROUS_BASH_COMMANDS

    # A corpus path pins the working directory; otherwise operate in-place.
    if corpus_path:
        working_dir = Path(corpus_path)
    else:
        working_dir = Path.cwd()

    env: Environment = CodingEnvironment(
        working_dir=working_dir,
        enabled_tools=tools_override or tpl.tools,
        bash_allowlist=tpl.bash_allowlist,
        bash_denylist=DANGEROUS_BASH_COMMANDS,
    )  # type: ignore[assignment]
    return env
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _resolve_session_id(resume: str | None, session_store: object) -> str | None:
|
|
286
|
+
"""Resolve session ID from resume arg. Exits on error."""
|
|
287
|
+
if not resume:
|
|
288
|
+
return None
|
|
289
|
+
session_id = resume if resume != "last" else session_store.get_latest_id_sync() # type: ignore[union-attr]
|
|
290
|
+
if not session_id:
|
|
291
|
+
print("Error: No session to resume", file=sys.stderr)
|
|
292
|
+
sys.exit(1)
|
|
293
|
+
return session_id
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _get_default_template() -> TemplateConfig:
    """Return the built-in agent template used when no --template is given.

    Grants the full file/bash toolset and a system prompt oriented toward
    GPU kernel work via the wafer CLI.
    """
    from wafer_core.rollouts.templates import TemplateConfig

    default_prompt = """You are a GPU kernel development assistant. You help with CUDA/Triton kernel optimization, profiling, and debugging.

You have access to these tools:

**File tools:**
- read: Read file contents
- write: Create new files
- edit: Modify existing files
- glob: Find files by pattern
- grep: Search file contents

**Bash:** Run shell commands including wafer CLI tools:
- `wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json` - Test kernel correctness and performance
- `wafer nvidia ncu analyze <file.ncu-rep>` - Analyze NCU profiling reports
- `wafer nvidia nsys analyze <file.nsys-rep>` - Analyze Nsight Systems traces
- `wafer nvidia perfetto tables <trace.json>` - Query Perfetto traces
- `wafer config targets list` - List available GPU targets

When asked to profile or analyze kernels, use the appropriate wafer commands. Be concise and focus on actionable insights."""

    return TemplateConfig(
        name="default",
        description="GPU kernel development assistant",
        system_prompt=default_prompt,
        tools=["read", "write", "edit", "glob", "grep", "bash"],
    )
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _load_template(
    template_name: str, template_args: dict[str, str] | None = None
) -> tuple[TemplateConfig | None, str | None]:
    """Load a wafer template by name.

    Returns:
        (template, None) on success, or (None, error_message) on any failure
        (unknown template, missing interpolation variables, import errors).
    """
    try:
        from wafer_core.rollouts.templates import load_template
        from wafer_core.rollouts.templates.loader import _get_search_paths

        # wafer-cli ships its own templates; they take precedence over the
        # default search locations.
        search_paths = _get_search_paths()
        bundled = Path(__file__).parent / "templates"
        if bundled.exists():
            search_paths = [bundled, *search_paths]

        template: TemplateConfig = load_template(template_name, search_paths=search_paths)
        # Interpolation validates that every required variable was supplied;
        # the rendered prompt itself is discarded here.
        _ = template.interpolate_prompt(template_args or {})
    except Exception as e:
        return None, str(e)
    return template, None
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def main(  # noqa: PLR0913, PLR0915
    prompt: str | None = None,
    interactive: bool = False,
    single_turn: bool | None = None,  # None = use template default
    model: str | None = None,
    resume: str | None = None,
    from_turn: int | None = None,
    tools: list[str] | None = None,
    allow_spawn: bool = False,
    max_tool_fails: int | None = None,
    max_turns: int | None = None,
    template: str | None = None,
    template_args: dict[str, str] | None = None,
    corpus_path: str | None = None,
    list_sessions: bool = False,
    get_session: str | None = None,
    json_output: bool = False,
) -> None:
    """Run wevin agent in-process via rollouts.

    Handles three modes, in order:
      1. get_session: print one stored session (JSON or text) and return.
      2. list_sessions: print recent session metadata and return.
      3. Run the agent (interactive TUI, JSON streaming, or plain frontend).

    Args:
        prompt: Initial user prompt; None starts without one.
        interactive: Use the TUI frontend instead of a batch run.
        single_turn: Stop after one turn; None defers to the template default.
        model: "provider/model" override for the template's model.
        resume: Session id to resume, or "last" for the most recent one.
        from_turn: Accepted but not used in this function body.
        tools: Tool-name override for the template's tool list.
        allow_spawn: Accepted but not used in this function body.
        max_tool_fails: Accepted but not used in this function body.
        max_turns: Accepted but not used in this function body.
        template: Wafer template name; None uses the built-in default.
        template_args: Variables interpolated into the template prompt.
        corpus_path: Working directory for the coding environment.
        list_sessions: List recent sessions and exit.
        get_session: Session id to dump and exit.
        json_output: Emit machine-readable NDJSON events instead of text.
    """
    from dataclasses import asdict

    import trio
    from wafer_core.rollouts import FileSessionStore

    session_store = FileSessionStore()

    # Handle --get-session: load session by ID and print
    if get_session:
        async def _get_session() -> None:
            try:
                session, err = await session_store.get(get_session)
                if err or not session:
                    if json_output:
                        print(json.dumps({"error": err or f"Session {get_session} not found"}))
                        sys.exit(1)
                    else:
                        print(f"Error: {err or 'Session not found'}", file=sys.stderr)
                        sys.exit(1)

                if json_output:
                    # Serialize messages to dicts
                    try:
                        messages_data = [asdict(msg) for msg in session.messages]
                    except Exception as e:
                        # If serialization fails, return error
                        error_msg = f"Failed to serialize messages: {e}"
                        print(json.dumps({"error": error_msg}))
                        sys.exit(1)

                    # Full session dump: metadata plus every serialized message.
                    print(json.dumps({
                        "session_id": session.session_id,
                        "status": session.status.value,
                        "model": session.endpoint.model if session.endpoint else None,
                        "created_at": session.created_at,
                        "updated_at": session.updated_at,
                        "messages": messages_data,
                        "tags": session.tags,
                    }))
                else:
                    # Human-readable dump: header plus truncated message previews.
                    print(f"Session: {session.session_id}")
                    print(f"Status: {session.status.value}")
                    print(f"Messages: {len(session.messages)}")
                    for i, msg in enumerate(session.messages):
                        # Fail fast if message can't be converted to string - corrupted data is a bug
                        content_preview = str(msg.content)[:100] if msg.content else ""
                        print(f" [{i}] {msg.role}: {content_preview}...")
            except KeyboardInterrupt:
                # User cancelled - exit cleanly
                sys.exit(130)  # Standard exit code for SIGINT
            except Exception as e:
                # Any other error - log and exit with error
                error_msg = f"Failed to load session {get_session}: {e}"
                if json_output:
                    print(json.dumps({"error": error_msg}))
                else:
                    print(f"Error: {error_msg}", file=sys.stderr)
                sys.exit(1)

        # Run the async loader; SystemExit from inside propagates through trio.run.
        try:
            trio.run(_get_session)
        except KeyboardInterrupt:
            sys.exit(130)
        except Exception as e:
            error_msg = f"Failed to run session loader: {e}"
            if json_output:
                print(json.dumps({"error": error_msg}))
            else:
                print(f"Error: {error_msg}", file=sys.stderr)
            sys.exit(1)
        return

    # Handle --list-sessions: show recent sessions and exit
    if list_sessions:
        sessions = session_store.list_sync(limit=50)
        if json_output:
            # Return metadata only - messages loaded on-demand via --get-session
            sessions_data = []
            for s in sessions:
                sessions_data.append({
                    "session_id": s.session_id,
                    "status": s.status.value,
                    "model": s.endpoint.model if s.endpoint else None,
                    "created_at": s.created_at if hasattr(s, "created_at") else None,
                    "updated_at": s.updated_at if hasattr(s, "updated_at") else None,
                    "message_count": len(s.messages),
                    "preview": _get_session_preview(s),
                })
            print(json.dumps({"sessions": sessions_data}))
        else:
            if not sessions:
                print("No sessions found.")
            else:
                print("Recent sessions:")
                for s in sessions:
                    preview = _get_session_preview(s)
                    print(f" {s.session_id} {preview}")
        return

    # Emit early event for JSON mode before heavy imports
    # This gives immediate feedback that the CLI started correctly
    if json_output:
        print(json.dumps({"type": "initializing"}), flush=True)

    from wafer_core.rollouts import Message, Trajectory
    from wafer_core.rollouts.frontends import NoneFrontend, RunnerConfig, run_interactive

    _setup_logging()

    # Auth
    api_base, api_key = _get_wafer_auth()
    if not api_base or not api_key:
        print("Error: No API credentials found", file=sys.stderr)
        print(" Run 'wafer login' or set ANTHROPIC_API_KEY", file=sys.stderr)
        sys.exit(1)

    # Narrow the Optional types for the type checker; the exit above guarantees both.
    assert api_base is not None
    assert api_key is not None

    # Load template or use defaults
    if template:
        loaded_template, err = _load_template(template, template_args)
        if err or loaded_template is None:
            print(f"Error loading template: {err}", file=sys.stderr)
            sys.exit(1)
        tpl = loaded_template
        system_prompt = tpl.interpolate_prompt(template_args or {})
        # Show template info when starting without a prompt
        if not prompt and tpl.description:
            print(f"Template: {tpl.name}", file=sys.stderr)
            print(f" {tpl.description}", file=sys.stderr)
            print(file=sys.stderr)
    else:
        tpl = _get_default_template()
        system_prompt = tpl.system_prompt

    # CLI args override template values
    resolved_single_turn = single_turn if single_turn is not None else tpl.single_turn

    # Build endpoint and environment
    endpoint = _build_endpoint(tpl, model, api_base, api_key)
    environment = _build_environment(tpl, tools, corpus_path)

    # Session store
    # NOTE(review): session_store was already instantiated at the top of main;
    # this second FileSessionStore() looks redundant — confirm it is intentional.
    session_store = FileSessionStore()
    session_id = _resolve_session_id(resume, session_store)

    async def run() -> None:
        # session_id is reassigned nowhere below, but nonlocal keeps parity
        # with the enclosing scope's binding.
        nonlocal session_id

        # Load trajectory - either from resumed session or fresh
        if session_id:
            existing_session, err = await session_store.get(session_id)
            if err:
                print(f"Error loading session: {err}", file=sys.stderr)
                sys.exit(1)
            assert existing_session is not None
            trajectory = Trajectory(messages=existing_session.messages)
        else:
            trajectory = Trajectory(messages=[Message(role="system", content=system_prompt)])

        try:
            if interactive:
                # TUI mode: the interactive agent owns the whole run loop.
                from wafer_core.rollouts.frontends.tui.interactive_agent import (
                    run_interactive_agent,
                )

                await run_interactive_agent(
                    trajectory,
                    endpoint,
                    environment,
                    session_store,
                    session_id,
                    theme_name="minimal",
                    debug=False,
                    debug_layout=False,
                    initial_prompt=prompt,
                )
            else:
                # Batch mode: pick the frontend by output format.
                if json_output:
                    # Emit session_start if we have a session_id (from --resume)
                    model_name = endpoint.model if hasattr(endpoint, 'model') else None
                    frontend = StreamingChunkFrontend(session_id=session_id, model=model_name)
                else:
                    frontend = NoneFrontend(show_tool_calls=True, show_thinking=False)
                config = RunnerConfig(
                    session_store=session_store,
                    session_id=session_id,
                    initial_prompt=prompt,
                    single_turn=resolved_single_turn,
                    hide_session_info=True,  # We print our own resume command
                )
                states = await run_interactive(trajectory, endpoint, frontend, environment, config)
                # Emit session_start for new sessions (if session_id was None and we got one)
                # Check first state to emit as early as possible
                if json_output and isinstance(frontend, StreamingChunkFrontend):
                    first_session_id = states[0].session_id if states and states[0].session_id else None
                    if first_session_id and not session_id:  # New session created
                        model_name = endpoint.model if hasattr(endpoint, 'model') else None
                        frontend.emit_session_start(first_session_id, model_name)
                # Print resume command with full wafer agent prefix
                if states and states[-1].session_id:
                    print(f"\nResume with: wafer agent --resume {states[-1].session_id}")
        except KeyboardInterrupt:
            pass
        except BaseException as e:
            # Trio wraps failures in ExceptionGroup; unwrap to the leaf error
            # for a readable message.
            actual_error = _unwrap_exception(e)
            print(f"\n{type(actual_error).__name__}: {actual_error}", file=sys.stderr)
            sys.exit(1)

    trio.run(run)
|