wafer-cli 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/wevin_cli.py ADDED
@@ -0,0 +1,577 @@
1
+ """Wafer Wevin CLI - thin wrapper that calls rollouts in-process.
2
+
3
+ Adds:
4
+ - Wafer auth (proxy token from ~/.wafer/credentials.json)
5
+ - Wafer templates (ask-docs, optimize-kernel, trace-analyze)
6
+ - Corpus path resolution (--corpus cuda -> ~/.cache/wafer/corpora/cuda)
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ import os
13
+ import sys
14
+ from pathlib import Path
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ from wafer_core.rollouts import Endpoint, Environment
19
+ from wafer_core.rollouts.dtypes import StreamEvent, ToolCall
20
+ from wafer_core.rollouts.templates import TemplateConfig
21
+
22
+
23
class StreamingChunkFrontend:
    """Non-interactive frontend that streams events to stdout as NDJSON.

    Each event is one JSON object on its own line, intended for
    extensions/UIs that consume the output programmatically. The shapes
    follow what wevin-extension's handleWevinEvent expects:
    - {type: 'session_start', session_id: '...', model: '...'}
    - {type: 'text_delta', delta: '...'}
    - {type: 'tool_call_start', tool_name: '...'}
    - {type: 'tool_call_end', tool_name: '...', args: {...}}
    - {type: 'tool_result', is_error: bool}
    - {type: 'session_end'}
    - {type: 'error', error: '...'}
    """

    def __init__(self, session_id: str | None = None, model: str | None = None) -> None:
        self._session_id = session_id
        self._model = model
        # Id and name of the most recently started tool call; recorded by
        # handle_event but not read anywhere else in this class.
        self._current_tool_call: dict | None = None

    def _emit(self, obj: dict) -> None:
        """Write ``obj`` to stdout as one flushed NDJSON line."""
        line = json.dumps(obj, ensure_ascii=False)
        print(line, flush=True)

    async def start(self) -> None:
        """Emit session_start up front, but only when the session id is already known."""
        if not self._session_id:
            return
        self._emit({
            "type": "session_start",
            "session_id": self._session_id,
            "model": self._model,
        })

    def emit_session_start(self, session_id: str, model: str | None = None) -> None:
        """Emit session_start for a session created during the run.

        Falls back to the model given at construction when ``model`` is falsy.
        """
        announced_model = model or self._model
        self._emit({
            "type": "session_start",
            "session_id": session_id,
            "model": announced_model,
        })

    async def stop(self) -> None:
        """Emit the closing session_end event."""
        self._emit({"type": "session_end"})

    async def handle_event(self, event: StreamEvent) -> None:
        """Translate one streaming event into its JSON line (if it has one)."""
        from wafer_core.rollouts.dtypes import (
            StreamDone,
            StreamError,
            TextDelta,
            ThinkingDelta,
            ToolCallEnd,
            ToolCallStart,
            ToolResultReceived,
        )

        # The checks below keep the original precedence order.
        if isinstance(event, TextDelta):
            # Forwarded immediately so consumers see text in real time.
            self._emit({"type": "text_delta", "delta": event.delta})
            return

        if isinstance(event, ThinkingDelta):
            # Thinking tokens are intentionally not forwarded.
            return

        if isinstance(event, ToolCallStart):
            # ToolCallStart exposes the id and name as flat attributes.
            self._current_tool_call = {
                "id": event.tool_call_id,
                "name": event.tool_name,
            }
            self._emit({"type": "tool_call_start", "tool_name": event.tool_name})
            return

        if isinstance(event, ToolCallEnd):
            finished = event.tool_call
            self._emit({
                "type": "tool_call_end",
                "tool_name": finished.name,
                "args": finished.args if finished.args else {},
            })
            return

        if isinstance(event, ToolResultReceived):
            payload = {"type": "tool_result", "is_error": event.is_error}
            if event.error:
                payload["error"] = event.error
            if event.content:
                if isinstance(event.content, list):
                    # Flatten list content: dict items contribute their
                    # "text" entry (or their repr-like str), others str().
                    pieces = []
                    for item in event.content:
                        if isinstance(item, dict):
                            pieces.append(item.get("text", str(item)))
                        else:
                            pieces.append(str(item))
                    payload["content"] = "\n".join(pieces)
                else:
                    payload["content"] = str(event.content)
            self._emit(payload)
            return

        if isinstance(event, StreamDone):
            # Nothing to emit here; stop() produces session_end.
            return

        if isinstance(event, StreamError):
            self._emit({"type": "error", "error": str(event.error)})

    async def get_input(self, prompt: str = "") -> str:
        """Always raises: interactive input is unavailable in JSON mode."""
        message = (
            "StreamingChunkFrontend does not support interactive input. "
            "Use -p to provide input or use -s for single-turn mode."
        )
        raise RuntimeError(message)

    async def confirm_tool(self, tool_call: ToolCall) -> bool:
        """Approve every tool call without prompting."""
        return True

    def show_loader(self, text: str) -> None:
        """Loader display is a no-op in JSON mode."""
        return None

    def hide_loader(self) -> None:
        """Loader display is a no-op in JSON mode."""
        return None
146
+
147
+
148
def _get_wafer_auth() -> tuple[str | None, str | None]:
    """Get wafer auth credentials with fallback chain.

    Sources are tried in this order:
      1. The ``WAFER_AUTH_TOKEN`` environment variable.
      2. The token returned by ``get_valid_token()`` (stored wafer credentials).
      3. The ``ANTHROPIC_API_KEY`` environment variable (direct Anthropic access).

    A one-line notice naming the chosen source is printed to stderr.

    Returns:
        (api_base, api_key) or (None, None) if no auth found
    """
    from .auth import get_valid_token, load_credentials
    from .global_config import get_api_url

    # Environment variable takes precedence over stored credentials.
    wafer_token = os.environ.get("WAFER_AUTH_TOKEN", "")
    token_source = "WAFER_AUTH_TOKEN" if wafer_token else None

    had_credentials = False
    if not wafer_token:
        try:
            creds = load_credentials()
            had_credentials = creds is not None and bool(creds.access_token)
        except Exception:
            # Deliberate best effort: this lookup only decides which notice
            # is shown on fallback below, so an unreadable credentials file
            # must not abort auth resolution.
            had_credentials = False
        wafer_token = get_valid_token()
        if wafer_token:
            token_source = "~/.wafer/credentials.json"

    if wafer_token:
        # Strip trailing slashes so a configured URL such as
        # "https://host/" does not yield "https://host//v1/anthropic".
        api_url = get_api_url().rstrip("/")
        print(f"🔑 Using wafer proxy ({token_source})\n", file=sys.stderr)
        return f"{api_url}/v1/anthropic", wafer_token

    # Fall back to talking to Anthropic directly.
    api_key = os.environ.get("ANTHROPIC_API_KEY", "")
    if api_key:
        if had_credentials:
            # Stored credentials existed but produced no usable token.
            print(
                "⚠️ Wafer credentials expired/invalid, falling back to ANTHROPIC_API_KEY\n",
                file=sys.stderr,
            )
        else:
            print("🔑 Using ANTHROPIC_API_KEY\n", file=sys.stderr)
        return "https://api.anthropic.com", api_key

    return None, None
192
+
193
+
194
def _get_session_preview(session: object) -> str:
    """Return a short one-line preview of the session's first user message.

    The first message whose role is "user" and whose content is a plain
    string is used: its first 50 characters, with newlines turned into
    spaces, followed by "..." when the content was longer than 50
    characters. Returns "" when the session has no messages or no such
    message.
    """
    limit = 50
    for message in getattr(session, "messages", None) or ():
        if message.role != "user" or not isinstance(message.content, str):
            continue
        text = message.content
        snippet = text[:limit].replace("\n", " ")
        return snippet + "..." if len(text) > limit else snippet
    return ""
206
+
207
+
208
def _setup_logging(log_path: str | os.PathLike[str] | None = None) -> None:
    """Configure logging to file only (no console spam).

    Installs a single rotating file handler (10 MB per file, 3 backups) on
    the root logger at DEBUG level, leaving existing loggers enabled.

    Args:
        log_path: Where to write the log. Defaults to ``wevin_debug.log``
            inside the platform temp directory, so the default also works
            on systems that have no ``/tmp``.
    """
    import logging.config
    import logging.handlers
    import tempfile

    if log_path is None:
        log_path = Path(tempfile.gettempdir()) / "wevin_debug.log"

    logging.config.dictConfig({
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "json": {
                # NOTE: JSON-shaped text only; a message containing a double
                # quote is not escaped, so lines are not guaranteed to parse.
                "format": '{"ts": "%(asctime)s", "level": "%(levelname)s", "logger": "%(name)s", "msg": "%(message)s"}',
            },
        },
        "handlers": {
            "file": {
                "class": "logging.handlers.RotatingFileHandler",
                "filename": str(log_path),
                "maxBytes": 10_000_000,
                "backupCount": 3,
                # Explicit encoding so non-ASCII messages do not depend on
                # the locale's default encoding.
                "encoding": "utf-8",
                "formatter": "json",
            },
        },
        "root": {"level": "DEBUG", "handlers": ["file"]},
    })
231
+
232
+
233
+ def _unwrap_exception(e: BaseException) -> BaseException:
234
+ """Unwrap ExceptionGroup from Trio to get the actual error."""
235
+ actual = e
236
+ while isinstance(actual, ExceptionGroup) and actual.exceptions:
237
+ actual = actual.exceptions[0]
238
+ return actual
239
+
240
+
241
def _build_endpoint(
    tpl: TemplateConfig,
    model_override: str | None,
    api_base: str,
    api_key: str,
) -> Endpoint:
    """Build an Endpoint from template config and auth.

    Args:
        tpl: Template supplying the default model, thinking settings and
            max_tokens.
        model_override: Model in ``provider/model`` form; takes precedence
            over ``tpl.model`` when truthy.
        api_base: Base URL passed to the endpoint.
        api_key: Credential passed to the endpoint.

    Returns:
        The configured Endpoint.

    Raises:
        ValueError: If the resolved model is not of the form
            ``provider/model`` with both parts non-empty.
    """
    resolved_model = model_override or tpl.model
    # Validate before anything else so a bad --model value produces a clear
    # message instead of "not enough values to unpack".
    provider, separator, model_id = resolved_model.partition("/")
    if not (separator and provider and model_id):
        raise ValueError(
            f"Invalid model {resolved_model!r}: expected '<provider>/<model>'"
        )

    from wafer_core.rollouts import Endpoint

    thinking_config = (
        {"type": "enabled", "budget_tokens": tpl.thinking_budget} if tpl.thinking else None
    )
    return Endpoint(
        provider=provider,
        model=model_id,
        api_base=api_base,
        api_key=api_key,
        thinking=thinking_config,
        max_tokens=tpl.max_tokens,
    )
263
+
264
+
265
def _build_environment(
    tpl: TemplateConfig,
    tools_override: list[str] | None,
    corpus_path: str | None,
) -> Environment:
    """Create the CodingEnvironment the agent runs in.

    The working directory is ``corpus_path`` when given, otherwise the
    current directory. ``tools_override`` replaces the template's tool list
    when truthy. The bash allowlist comes from the template and the denylist
    is always DANGEROUS_BASH_COMMANDS.
    """
    from wafer_core.environments.coding import CodingEnvironment
    from wafer_core.rollouts.templates import DANGEROUS_BASH_COMMANDS

    if corpus_path:
        root_dir = Path(corpus_path)
    else:
        root_dir = Path.cwd()
    tool_names = tools_override or tpl.tools

    coding_env: Environment = CodingEnvironment(
        working_dir=root_dir,
        enabled_tools=tool_names,
        bash_allowlist=tpl.bash_allowlist,
        bash_denylist=DANGEROUS_BASH_COMMANDS,
    )  # type: ignore[assignment]
    return coding_env
283
+
284
+
285
def _resolve_session_id(resume: str | None, session_store: object) -> str | None:
    """Turn the --resume argument into a concrete session id.

    Returns None when ``resume`` is empty. The special value "last" is
    looked up through ``session_store.get_latest_id_sync()``; any other
    value is used as the id directly. Prints an error to stderr and exits
    with status 1 when no id can be determined.
    """
    if not resume:
        return None

    if resume == "last":
        resolved = session_store.get_latest_id_sync()  # type: ignore[union-attr]
    else:
        resolved = resume

    if resolved:
        return resolved

    print("Error: No session to resume", file=sys.stderr)
    sys.exit(1)
294
+
295
+
296
def _get_default_template() -> TemplateConfig:
    """Build the template used when no --template is given.

    It is named "default", enables the read/write/edit/glob/grep/bash tools,
    and carries a system prompt describing those tools and the wafer CLI
    commands available through bash.
    """
    from wafer_core.rollouts.templates import TemplateConfig

    default_prompt = """You are a GPU kernel development assistant. You help with CUDA/Triton kernel optimization, profiling, and debugging.

You have access to these tools:

**File tools:**
- read: Read file contents
- write: Create new files
- edit: Modify existing files
- glob: Find files by pattern
- grep: Search file contents

**Bash:** Run shell commands including wafer CLI tools:
- `wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json` - Test kernel correctness and performance
- `wafer nvidia ncu analyze <file.ncu-rep>` - Analyze NCU profiling reports
- `wafer nvidia nsys analyze <file.nsys-rep>` - Analyze Nsight Systems traces
- `wafer nvidia perfetto tables <trace.json>` - Query Perfetto traces
- `wafer config targets list` - List available GPU targets

When asked to profile or analyze kernels, use the appropriate wafer commands. Be concise and focus on actionable insights."""
    enabled_tools = ["read", "write", "edit", "glob", "grep", "bash"]

    return TemplateConfig(
        name="default",
        description="GPU kernel development assistant",
        system_prompt=default_prompt,
        tools=enabled_tools,
    )
324
+
325
+
326
def _load_template(
    template_name: str, template_args: dict[str, str] | None = None
) -> tuple[TemplateConfig | None, str | None]:
    """Load a wafer template. Returns (template, error).

    Templates bundled next to this module (``templates/``) are searched
    before the loader's default search paths. The prompt is interpolated
    once with ``template_args`` purely as validation; the returned template
    is the full, un-interpolated config.

    Args:
        template_name: Name of the template to load.
        template_args: Values for the template's prompt variables.

    Returns:
        ``(template, None)`` on success, or ``(None, message)`` on any
        failure. ``message`` is never empty, so callers may safely test it
        for truthiness.
    """
    try:
        from wafer_core.rollouts.templates import load_template
        from wafer_core.rollouts.templates.loader import _get_search_paths

        # Prepend wafer-cli bundled templates to default search paths
        bundled_templates = Path(__file__).parent / "templates"
        search_paths = _get_search_paths()
        if bundled_templates.exists():
            search_paths = [bundled_templates, *search_paths]

        template: TemplateConfig = load_template(template_name, search_paths=search_paths)
        # Result is discarded: the call is made so that a problem with the
        # supplied variables surfaces here as an error.
        template.interpolate_prompt(template_args or {})
        return template, None
    except Exception as e:
        # Some exceptions stringify to "" (e.g. raised without a message);
        # fall back to the type name so the error slot is never blank.
        return None, str(e) or type(e).__name__
346
+
347
+
348
def main(  # noqa: PLR0913, PLR0915
    prompt: str | None = None,
    interactive: bool = False,
    single_turn: bool | None = None,  # None = use template default
    model: str | None = None,
    resume: str | None = None,
    from_turn: int | None = None,
    tools: list[str] | None = None,
    allow_spawn: bool = False,
    max_tool_fails: int | None = None,
    max_turns: int | None = None,
    template: str | None = None,
    template_args: dict[str, str] | None = None,
    corpus_path: str | None = None,
    list_sessions: bool = False,
    get_session: str | None = None,
    json_output: bool = False,
) -> None:
    """Run wevin agent in-process via rollouts.

    Modes, checked in this order:
      * ``get_session``: print one stored session and return.
      * ``list_sessions``: print up to 50 stored sessions and return.
      * otherwise: resolve auth, template, endpoint and environment, then
        run the agent (the TUI when ``interactive``, else a single
        ``run_interactive`` call with a JSON or plain frontend).

    Args:
        prompt: Initial user prompt.
        interactive: Run the TUI agent instead of the headless runner.
        single_turn: Overrides the template's single_turn when not None.
        model: ``provider/model`` override for the template's model.
        resume: Session id to resume, or "last".
        from_turn: Accepted but not used by this function.
        tools: Overrides the template's tool list when truthy.
        allow_spawn: Accepted but not used by this function.
        max_tool_fails: Accepted but not used by this function.
        max_turns: Accepted but not used by this function.
        template: Name of a template to load instead of the default.
        template_args: Values for the template's prompt variables.
        corpus_path: Working directory for the environment (defaults to cwd).
        list_sessions: List stored sessions and return.
        get_session: Id of a stored session to print.
        json_output: Emit machine-readable JSON on stdout.

    Exits the process with status 1 on errors and 130 on Ctrl-C while
    loading a session.
    """
    from dataclasses import asdict

    import trio
    from wafer_core.rollouts import FileSessionStore

    session_store = FileSessionStore()

    # Handle --get-session: load session by ID and print
    if get_session:

        async def _get_session() -> None:
            try:
                session, err = await session_store.get(get_session)
                if err or not session:
                    if json_output:
                        print(json.dumps({"error": err or f"Session {get_session} not found"}))
                    else:
                        print(f"Error: {err or 'Session not found'}", file=sys.stderr)
                    sys.exit(1)

                if json_output:
                    # Serialize messages to dicts
                    try:
                        messages_data = [asdict(msg) for msg in session.messages]
                    except Exception as e:
                        # Report serialization failures as a JSON error object
                        error_msg = f"Failed to serialize messages: {e}"
                        print(json.dumps({"error": error_msg}))
                        sys.exit(1)

                    print(json.dumps({
                        "session_id": session.session_id,
                        "status": session.status.value,
                        "model": session.endpoint.model if session.endpoint else None,
                        "created_at": session.created_at,
                        "updated_at": session.updated_at,
                        "messages": messages_data,
                        "tags": session.tags,
                    }))
                else:
                    print(f"Session: {session.session_id}")
                    print(f"Status: {session.status.value}")
                    print(f"Messages: {len(session.messages)}")
                    for i, msg in enumerate(session.messages):
                        # str() failures propagate to the handler below
                        content_preview = str(msg.content)[:100] if msg.content else ""
                        print(f" [{i}] {msg.role}: {content_preview}...")
            except KeyboardInterrupt:
                # User cancelled - exit cleanly
                sys.exit(130)  # Standard exit code for SIGINT
            except Exception as e:
                # sys.exit() above raises SystemExit, which is not an
                # Exception subclass and therefore passes through here.
                error_msg = f"Failed to load session {get_session}: {e}"
                if json_output:
                    print(json.dumps({"error": error_msg}))
                else:
                    print(f"Error: {error_msg}", file=sys.stderr)
                sys.exit(1)

        try:
            trio.run(_get_session)
        except KeyboardInterrupt:
            sys.exit(130)
        except Exception as e:
            error_msg = f"Failed to run session loader: {e}"
            if json_output:
                print(json.dumps({"error": error_msg}))
            else:
                print(f"Error: {error_msg}", file=sys.stderr)
            sys.exit(1)
        return

    # Handle --list-sessions: show recent sessions and exit
    if list_sessions:
        sessions = session_store.list_sync(limit=50)
        if json_output:
            # Return metadata only - messages loaded on-demand via --get-session
            sessions_data = [
                {
                    "session_id": s.session_id,
                    "status": s.status.value,
                    "model": s.endpoint.model if s.endpoint else None,
                    "created_at": getattr(s, "created_at", None),
                    "updated_at": getattr(s, "updated_at", None),
                    "message_count": len(s.messages),
                    "preview": _get_session_preview(s),
                }
                for s in sessions
            ]
            print(json.dumps({"sessions": sessions_data}))
        elif not sessions:
            print("No sessions found.")
        else:
            print("Recent sessions:")
            for s in sessions:
                preview = _get_session_preview(s)
                print(f" {s.session_id} {preview}")
        return

    # Emit early event for JSON mode before heavy imports
    # This gives immediate feedback that the CLI started correctly
    if json_output:
        print(json.dumps({"type": "initializing"}), flush=True)

    from wafer_core.rollouts import Message, Trajectory
    from wafer_core.rollouts.frontends import NoneFrontend, RunnerConfig, run_interactive

    _setup_logging()

    # Auth
    api_base, api_key = _get_wafer_auth()
    if not api_base or not api_key:
        print("Error: No API credentials found", file=sys.stderr)
        print(" Run 'wafer login' or set ANTHROPIC_API_KEY", file=sys.stderr)
        sys.exit(1)

    # Load template or use defaults
    if template:
        loaded_template, err = _load_template(template, template_args)
        if err or loaded_template is None:
            print(f"Error loading template: {err}", file=sys.stderr)
            sys.exit(1)
        tpl = loaded_template
        system_prompt = tpl.interpolate_prompt(template_args or {})
        # Show template info when starting without a prompt
        if not prompt and tpl.description:
            print(f"Template: {tpl.name}", file=sys.stderr)
            print(f" {tpl.description}", file=sys.stderr)
            print(file=sys.stderr)
    else:
        tpl = _get_default_template()
        system_prompt = tpl.system_prompt

    # CLI args override template values
    resolved_single_turn = single_turn if single_turn is not None else tpl.single_turn

    # Build endpoint and environment
    endpoint = _build_endpoint(tpl, model, api_base, api_key)
    environment = _build_environment(tpl, tools, corpus_path)

    # The session store created at the top of the function is reused here.
    session_id = _resolve_session_id(resume, session_store)

    async def run() -> None:
        # Load trajectory - either from resumed session or fresh
        if session_id:
            existing_session, err = await session_store.get(session_id)
            # Explicit check rather than assert: asserts are stripped under
            # -O, and a missing session must always be reported.
            if err or existing_session is None:
                print(f"Error loading session: {err or 'Session not found'}", file=sys.stderr)
                sys.exit(1)
            trajectory = Trajectory(messages=existing_session.messages)
        else:
            trajectory = Trajectory(messages=[Message(role="system", content=system_prompt)])

        try:
            if interactive:
                from wafer_core.rollouts.frontends.tui.interactive_agent import (
                    run_interactive_agent,
                )

                await run_interactive_agent(
                    trajectory,
                    endpoint,
                    environment,
                    session_store,
                    session_id,
                    theme_name="minimal",
                    debug=False,
                    debug_layout=False,
                    initial_prompt=prompt,
                )
            else:
                model_name = getattr(endpoint, "model", None)
                if json_output:
                    # session_start is emitted by start() when session_id is
                    # already known (from --resume)
                    frontend = StreamingChunkFrontend(session_id=session_id, model=model_name)
                else:
                    frontend = NoneFrontend(show_tool_calls=True, show_thinking=False)
                config = RunnerConfig(
                    session_store=session_store,
                    session_id=session_id,
                    initial_prompt=prompt,
                    single_turn=resolved_single_turn,
                    hide_session_info=True,  # We print our own resume command
                )
                states = await run_interactive(trajectory, endpoint, frontend, environment, config)
                # Emit session_start for new sessions (session_id was None
                # and the first returned state carries one)
                if json_output and isinstance(frontend, StreamingChunkFrontend):
                    first_session_id = states[0].session_id if states and states[0].session_id else None
                    if first_session_id and not session_id:  # New session created
                        frontend.emit_session_start(first_session_id, model_name)
                # Print resume command with full wafer agent prefix
                if states and states[-1].session_id:
                    # In JSON mode stdout carries one JSON object per line,
                    # so the human-readable hint goes to stderr instead of
                    # corrupting the stream.
                    hint_stream = sys.stderr if json_output else sys.stdout
                    print(
                        f"\nResume with: wafer agent --resume {states[-1].session_id}",
                        file=hint_stream,
                    )
        except KeyboardInterrupt:
            pass
        except BaseException as e:
            actual_error = _unwrap_exception(e)
            print(f"\n{type(actual_error).__name__}: {actual_error}", file=sys.stderr)
            sys.exit(1)

    trio.run(run)