wafer_cli-0.2.24-py3-none-any.whl → wafer_cli-0.2.26-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/tests/test_eval_cli_parity.py ADDED
@@ -0,0 +1,199 @@
1
+ """Test that eval and CLI agent configs are in sync.
2
+
3
+ Run: python -m pytest apps/wafer-cli/wafer/tests/test_eval_cli_parity.py -v
4
+ or: python apps/wafer-cli/wafer/tests/test_eval_cli_parity.py (standalone)
5
+ or: python apps/wafer-cli/wafer/tests/test_eval_cli_parity.py --dump (side-by-side)
6
+
7
+ Checks:
8
+ - Same bash allowlist (identity check — must be the same object)
9
+ - Same enabled tools
10
+ - Same system prompt text (modulo runtime variables like pool name)
11
+ - CLI instructions generated from the same allowlist
12
+
13
+ Coverage notes:
14
+ These tests verify config-level parity. After composition, both paths
15
+ flow through the same codepath (Actor → run_agent_step → rollout →
16
+ anthropic.py provider) with no further system prompt modifications.
17
+
18
+ Known infra-level differences NOT tested here:
19
+ - Claude Code identity prefix: anthropic.py prepends "You are Claude Code..."
20
+ when using OAuth/Claude Code API keys. Eval typically uses raw API keys,
21
+ so eval agents may not get this prefix. This is an auth concern, not a
22
+ prompt content concern.
23
+ - Skills layer: CLI path (wevin_cli.py) can append skill metadata if
24
+ template.include_skills is True. Currently False for optimize-kernelbench,
25
+ so this is a no-op. Eval path doesn't have this layer at all.
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import difflib
31
+ import string
32
+ import sys
33
+
34
+
35
+ def _normalize_prompt(prompt: str) -> str:
36
+ """Normalize runtime variables so eval and CLI prompts are comparable."""
37
+ # Replace pool/target specifics with placeholders
38
+ prompt = prompt.replace("--pool mi300x-pool", "--pool POOL")
39
+ prompt = prompt.replace("--pool kernelbench-pool", "--pool POOL")
40
+ prompt = prompt.replace("--target mi300x", "--target TARGET")
41
+ return prompt
42
+
43
+
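For illustration, a minimal usage sketch of the normalizer above (the text around the flags is hypothetical; the flag fragments are the ones replaced above):

    text = "Run `wafer bench --pool mi300x-pool --target mi300x` to benchmark."
    assert _normalize_prompt(text) == "Run `wafer bench --pool POOL --target TARGET` to benchmark."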
44
+ def test_bash_allowlist_parity() -> None:
45
+ from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST
46
+ from wafer.templates.optimize_kernelbench import template
47
+
48
+ # Template should use the shared allowlist (same object)
49
+ assert template.bash_allowlist is KERNELBENCH_BASH_ALLOWLIST, (
50
+ "Template bash_allowlist is not the shared KERNELBENCH_BASH_ALLOWLIST. "
51
+ "Import it from wafer.agent_defaults instead of defining a local copy."
52
+ )
53
+
54
+ # Eval should also alias to the shared allowlist
55
+ from optimize_kernelbench_eval.base_config import BASH_ALLOWLIST
56
+
57
+ assert BASH_ALLOWLIST is KERNELBENCH_BASH_ALLOWLIST, (
58
+ "Eval BASH_ALLOWLIST is not the shared KERNELBENCH_BASH_ALLOWLIST. "
59
+ "Import it from wafer.agent_defaults instead of defining a local copy."
60
+ )
61
+
62
+
63
+ def test_enabled_tools_parity() -> None:
64
+ from wafer.agent_defaults import ENABLED_TOOLS
65
+ from wafer.templates.optimize_kernelbench import template
66
+
67
+ assert template.tools == ENABLED_TOOLS, (
68
+ f"Template tools {template.tools} != shared ENABLED_TOOLS {ENABLED_TOOLS}"
69
+ )
70
+
71
+
72
+ def test_system_prompt_parity() -> None:
73
+ """The task-specific system prompt should be identical between eval and CLI
74
+ (after normalizing runtime variables like pool name)."""
75
+ from optimize_kernelbench_eval.base_config import SYSTEM_PROMPT as EVAL_PROMPT
76
+
77
+ from wafer.templates.optimize_kernelbench import template
78
+
79
+ # Format eval prompt with HIP defaults (most common)
80
+ eval_formatted = EVAL_PROMPT.format(
81
+ backend="HIP",
82
+ backend_lower="hip",
83
+ target_flag="--pool POOL",
84
+ reference_path="<reference_file>",
85
+ )
86
+
87
+ # Format CLI template prompt with matching defaults
88
+ params = dict(template.defaults)
89
+ params["target_flag"] = "--pool POOL"
90
+ cli_formatted = string.Template(template.system_prompt).safe_substitute(**params)
91
+
92
+ eval_normalized = _normalize_prompt(eval_formatted)
93
+ cli_normalized = _normalize_prompt(cli_formatted)
94
+
95
+ if eval_normalized != cli_normalized:
96
+ diff = "\n".join(
97
+ difflib.unified_diff(
98
+ eval_normalized.splitlines(),
99
+ cli_normalized.splitlines(),
100
+ fromfile="eval (base_config.py)",
101
+ tofile="cli (optimize_kernelbench.py template)",
102
+ lineterm="",
103
+ n=2,
104
+ )
105
+ )
106
+ raise AssertionError(
107
+ f"System prompts differ between eval and CLI template:\n\n{diff}\n\n"
108
+ "Both should define the same task instructions. "
109
+ "Edit one to match the other."
110
+ )
111
+
112
+
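Note on the two rendering mechanisms used above: EVAL_PROMPT uses str.format-style {braces}, while the CLI template uses string.Template $placeholders, so each side is rendered with its own mechanism before normalizing. A small illustration with a made-up prompt string:

    eval_style = "Optimize the {backend} kernel".format(backend="HIP")
    cli_style = string.Template("Optimize the $backend kernel").safe_substitute(backend="HIP")
    assert eval_style == cli_style == "Optimize the HIP kernel"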
113
+ def test_cli_instructions_identical() -> None:
114
+ """Both paths should generate the same CLI instructions
115
+ (since they use the same bash_allowlist)."""
116
+ from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST
117
+ from wafer.cli_instructions import build_cli_instructions
118
+ from wafer.templates.optimize_kernelbench import template
119
+
120
+ eval_instructions = build_cli_instructions(KERNELBENCH_BASH_ALLOWLIST)
121
+ cli_instructions = build_cli_instructions(template.bash_allowlist)
122
+
123
+ assert eval_instructions == cli_instructions, (
124
+ "CLI instructions differ — this means the bash allowlists diverged."
125
+ )
126
+ assert len(eval_instructions) > 0, "CLI instructions should not be empty"
127
+
128
+
129
+ def _dump_full_prompts() -> None:
130
+ """Standalone: dump both composed prompts for manual comparison."""
131
+ from optimize_kernelbench_eval.base_config import SYSTEM_PROMPT as EVAL_PROMPT
132
+
133
+ from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST
134
+ from wafer.cli_instructions import build_cli_instructions
135
+ from wafer.templates.optimize_kernelbench import template
136
+
137
+ cli_instructions = build_cli_instructions(KERNELBENCH_BASH_ALLOWLIST)
138
+
139
+ # Eval
140
+ eval_sys = EVAL_PROMPT.format(
141
+ backend="HIP",
142
+ backend_lower="hip",
143
+ target_flag="--pool mi300x-pool",
144
+ reference_path="<reference_file>",
145
+ )
146
+ eval_sys += "\n\n" + cli_instructions
147
+
148
+ # CLI
149
+ params = dict(template.defaults)
150
+ cli_sys = string.Template(template.system_prompt).safe_substitute(**params)
151
+ cli_sys += "\n\n" + build_cli_instructions(template.bash_allowlist)
152
+
153
+ print("=" * 60)
154
+ print("EVAL SYSTEM PROMPT")
155
+ print("=" * 60)
156
+ print(eval_sys)
157
+ print()
158
+ print("=" * 60)
159
+ print("CLI SYSTEM PROMPT")
160
+ print("=" * 60)
161
+ print(cli_sys)
162
+ print()
163
+
164
+ # Diff
165
+ eval_norm = _normalize_prompt(eval_sys)
166
+ cli_norm = _normalize_prompt(cli_sys)
167
+ diff = list(
168
+ difflib.unified_diff(
169
+ eval_norm.splitlines(),
170
+ cli_norm.splitlines(),
171
+ fromfile="eval",
172
+ tofile="cli",
173
+ lineterm="",
174
+ n=1,
175
+ )
176
+ )
177
+ if diff:
178
+ print("=" * 60)
179
+ print("DIFFERENCES (after normalizing pool names)")
180
+ print("=" * 60)
181
+ for line in diff:
182
+ print(line)
183
+ else:
184
+ print("IDENTICAL (after normalizing pool names)")
185
+
186
+
187
+ if __name__ == "__main__":
188
+ if "--dump" in sys.argv:
189
+ _dump_full_prompts()
190
+ else:
191
+ test_bash_allowlist_parity()
192
+ print("PASS: bash_allowlist_parity")
193
+ test_enabled_tools_parity()
194
+ print("PASS: enabled_tools_parity")
195
+ test_system_prompt_parity()
196
+ print("PASS: system_prompt_parity")
197
+ test_cli_instructions_identical()
198
+ print("PASS: cli_instructions_identical")
199
+ print("\nAll parity checks passed.")
wafer/trace_compare.py ADDED
@@ -0,0 +1,274 @@
1
+ """CLI wrapper for trace comparison commands.
2
+
3
+ This module provides the CLI interface for the `wafer compare` commands.
4
+ All core logic is in wafer_core.lib.trace_compare.
5
+ """
6
+
7
+ import json
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ import typer
12
+
16
+ from wafer_core.lib.trace_compare import (
17
+ analyze_trace_pair,
18
+ format_csv,
19
+ format_json,
20
+ format_text,
21
+ ArchitectureType,
22
+ detect_architecture,
23
+ )
24
+ from wafer_core.lib.trace_compare.loader import StreamingMetadata
25
+
26
+
27
+ def compare_traces(
28
+ trace1: Path,
29
+ trace2: Path,
30
+ output: Path | None = None,
31
+ output_format: str = "text",
32
+ phase: str = "all",
33
+ show_layers: bool = False,
34
+ show_all: bool = False,
35
+ show_stack_traces: bool = False,
36
+ recommendations: bool = False,
37
+ ) -> None:
38
+ """Compare two GPU traces and generate performance report.
39
+
40
+ Args:
41
+ trace1: Path to first trace file (AMD or NVIDIA)
42
+ trace2: Path to second trace file (AMD or NVIDIA)
43
+ output: Optional output file path (default: stdout)
44
+ output_format: Output format ('text', 'text-layers', 'csv', 'csv-layers', or 'json')
45
+ phase: Filter by phase ('all', 'prefill', or 'decode')
46
+ show_layers: Show layer-wise performance breakdown (text format only)
47
+ show_all: Show all items without truncation (applies to layers, operations, kernels)
48
+ show_stack_traces: Show Python stack traces for operations
+ recommendations: Not currently used by this command
49
+ """
50
+ # Validate files exist
51
+ if not trace1.exists():
52
+ typer.secho(f"❌ File not found: {trace1}", fg=typer.colors.RED, err=True)
53
+ raise typer.Exit(1)
54
+
55
+ if not trace2.exists():
56
+ typer.secho(f"❌ File not found: {trace2}", fg=typer.colors.RED, err=True)
57
+ raise typer.Exit(1)
58
+
59
+ # Progress callback for JSON format (emits NDJSON to stdout)
60
+ def progress_callback(stage: str, fraction: float) -> None:
61
+ if output_format == 'json':
62
+ progress_msg = json.dumps({"type": "progress", "stage": stage, "fraction": fraction})
63
+ print(progress_msg, file=sys.stdout, flush=True)
64
+ else:
65
+ percent = int(fraction * 100)
66
+ typer.echo(f"📊 {stage}: {percent}%", err=True)
67
+
68
+ # Metadata callback for JSON format (emits NDJSON with early GPU info)
69
+ def metadata_callback(meta1: StreamingMetadata, meta2: StreamingMetadata) -> None:
70
+ if output_format == 'json':
71
+ metadata_msg = json.dumps({
72
+ "type": "metadata",
73
+ "trace1": {
74
+ "platform": meta1.platform,
75
+ "gpu": meta1.gpu_name,
76
+ "file_size_mb": round(meta1.file_size_mb, 1),
77
+ },
78
+ "trace2": {
79
+ "platform": meta2.platform,
80
+ "gpu": meta2.gpu_name,
81
+ "file_size_mb": round(meta2.file_size_mb, 1),
82
+ },
83
+ })
84
+ print(metadata_msg, file=sys.stdout, flush=True)
85
+ else:
86
+ typer.echo(f"📊 Trace 1: {meta1.platform} - {meta1.gpu_name} ({meta1.file_size_mb:.1f}MB)", err=True)
87
+ typer.echo(f"📊 Trace 2: {meta2.platform} - {meta2.gpu_name} ({meta2.file_size_mb:.1f}MB)", err=True)
88
+
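With output_format='json', the two callbacks above each emit one NDJSON control message per event on stdout before the final report; illustrative lines follow (values are made up, only the keys come from the json.dumps calls above):

    {"type": "metadata", "trace1": {"platform": "AMD", "gpu": "MI300X", "file_size_mb": 412.3}, "trace2": {"platform": "NVIDIA", "gpu": "H100", "file_size_mb": 398.7}}
    {"type": "progress", "stage": "Parsing events", "fraction": 0.25}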
89
+ # Analyze traces using unified API
90
+ if output_format != 'json':
91
+ typer.echo("📊 Loading traces...")
92
+
93
+ try:
94
+ result_obj = analyze_trace_pair(
95
+ trace1,
96
+ trace2,
97
+ phase=phase,
98
+ include_stacks=True,
99
+ on_progress=progress_callback,
100
+ on_metadata=metadata_callback,
101
+ )
102
+
103
+ results = {
104
+ "metadata": result_obj.metadata,
105
+ "operations": result_obj.operations,
106
+ "layers": result_obj.layers,
107
+ "warnings": [{"code": w.code, "severity": w.severity, "message": w.message, "suggestion": w.suggestion} for w in result_obj.warnings],
108
+ "architecture": result_obj.architecture.value,
109
+ "layer_alignments": result_obj.layer_alignments,
110
+ "fusion_analysis": result_obj.fusion_analysis,
111
+ "same_kernel_analysis": result_obj.same_kernel_analysis,
112
+ }
113
+ except ValueError as e:
114
+ typer.secho(f"❌ {e}", fg=typer.colors.RED, err=True)
115
+ raise typer.Exit(1)
116
+ except Exception as e:
117
+ typer.secho(f"❌ Error analyzing traces: {e}", fg=typer.colors.RED, err=True)
118
+ raise typer.Exit(1)
119
+
120
+ if output_format != 'json':
121
+ meta = results["metadata"]
122
+ if meta['trace1_platform'] == 'AMD':
123
+ amd_gpu, nvidia_gpu = meta['trace1_gpu'], meta['trace2_gpu']
124
+ else:
125
+ amd_gpu, nvidia_gpu = meta['trace2_gpu'], meta['trace1_gpu']
126
+ typer.echo(f"✅ Loaded: AMD ({amd_gpu}) vs NVIDIA ({nvidia_gpu})")
127
+
128
+ # Display warnings
129
+ warnings = results.get("warnings", [])
130
+ if warnings:
131
+ typer.echo()
132
+ for warning in warnings:
133
+ icon = "❌" if warning["severity"] == "error" else "⚠️" if warning["severity"] == "warning" else "ℹ️"
134
+ typer.secho(f"{icon} {warning['message']}", fg=typer.colors.YELLOW if warning["severity"] == "warning" else typer.colors.BLUE)
135
+ if warning.get("suggestion"):
136
+ typer.secho(f" Suggestion: {warning['suggestion']}", fg=typer.colors.BLUE)
137
+ typer.echo()
138
+
139
+
140
+ # Generate output based on format
141
+ if output_format == "text":
142
+ output_str = format_text(results, show_layers=show_layers, show_all=show_all, show_stack_traces=show_stack_traces)
143
+ elif output_format == "text-layers":
144
+ output_str = format_text(results, show_layers=True, show_all=show_all, show_stack_traces=show_stack_traces)
145
+ elif output_format == "csv":
146
+ output_str = format_csv(results, report_type="operations")
147
+ elif output_format == "csv-layers":
148
+ output_str = format_csv(results, report_type="layers")
149
+ elif output_format == "json":
150
+ output_str = format_json(results)
151
+ else:
152
+ typer.secho(f"❌ Unknown format: {output_format}", fg=typer.colors.RED, err=True)
153
+ raise typer.Exit(1)
154
+
155
+ # Write output
156
+ if output:
157
+ output.write_text(output_str)
158
+ typer.secho(f"✅ Report saved to {output}", fg=typer.colors.GREEN)
159
+ else:
160
+ typer.echo(output_str)
161
+
162
+
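For reference, a direct call equivalent to what the CLI wiring would do (the trace file names here are hypothetical):

    from pathlib import Path

    from wafer.trace_compare import compare_traces

    compare_traces(Path("amd_trace.json"), Path("nvidia_trace.json"), output_format="json")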
163
+ def compare_align(
164
+ trace1: Path,
165
+ trace2: Path,
166
+ output: Path | None = None,
167
+ output_format: str = "json",
168
+ phase: str = "all",
169
+ layer: int | None = None,
170
+ ) -> None:
171
+ """Align kernels at layer level for exact kernel-to-kernel comparison.
172
+
173
+ Args:
174
+ trace1: Path to first trace file (AMD or NVIDIA)
175
+ trace2: Path to second trace file (AMD or NVIDIA)
176
+ output: Optional output file path (default: stdout)
177
+ output_format: Output format ('json' only for now)
178
+ phase: Filter by phase ('all', 'prefill', or 'decode')
179
+ layer: Focus on specific layer number (optional)
180
+ """
181
+ # Validate files exist
182
+ if not trace1.exists():
183
+ typer.secho(f"❌ File not found: {trace1}", fg=typer.colors.RED, err=True)
184
+ raise typer.Exit(1)
185
+
186
+ if not trace2.exists():
187
+ typer.secho(f"❌ File not found: {trace2}", fg=typer.colors.RED, err=True)
188
+ raise typer.Exit(1)
189
+
190
+ # Progress callback for JSON format (emits NDJSON to stdout)
191
+ def progress_callback(stage: str, fraction: float) -> None:
192
+ if output_format == 'json':
193
+ progress_msg = json.dumps({"type": "progress", "stage": stage, "fraction": fraction})
194
+ print(progress_msg, file=sys.stdout, flush=True)
195
+ else:
196
+ percent = int(fraction * 100)
197
+ typer.echo(f"📊 {stage}: {percent}%", err=True)
198
+
199
+ # Metadata callback for JSON format
200
+ def metadata_callback(meta1: StreamingMetadata, meta2: StreamingMetadata) -> None:
201
+ if output_format == 'json':
202
+ metadata_msg = json.dumps({
203
+ "type": "metadata",
204
+ "trace1": {
205
+ "platform": meta1.platform,
206
+ "gpu": meta1.gpu_name,
207
+ "file_size_mb": round(meta1.file_size_mb, 1),
208
+ },
209
+ "trace2": {
210
+ "platform": meta2.platform,
211
+ "gpu": meta2.gpu_name,
212
+ "file_size_mb": round(meta2.file_size_mb, 1),
213
+ },
214
+ })
215
+ print(metadata_msg, file=sys.stdout, flush=True)
216
+ else:
217
+ typer.echo(f"📊 Trace 1: {meta1.platform} - {meta1.gpu_name} ({meta1.file_size_mb:.1f}MB)", err=True)
218
+ typer.echo(f"📊 Trace 2: {meta2.platform} - {meta2.gpu_name} ({meta2.file_size_mb:.1f}MB)", err=True)
219
+
220
+ # Analyze traces using unified API
221
+ if output_format != 'json':
222
+ typer.echo("📊 Loading traces...")
223
+
224
+ try:
225
+ result_obj = analyze_trace_pair(
226
+ trace1,
227
+ trace2,
228
+ phase=phase,
229
+ include_stacks=True,
230
+ on_progress=progress_callback,
231
+ on_metadata=metadata_callback,
232
+ )
233
+
234
+ results = {
235
+ "metadata": result_obj.metadata,
236
+ "layer_alignments": result_obj.layer_alignments or [],
237
+ "fusion_analysis": result_obj.fusion_analysis or {},
238
+ "same_kernel_analysis": result_obj.same_kernel_analysis or {},
239
+ "operations": result_obj.operations,
240
+ "layers": result_obj.layers,
241
+ "warnings": [{"code": w.code, "severity": w.severity, "message": w.message, "suggestion": w.suggestion} for w in result_obj.warnings],
242
+ "architecture": result_obj.architecture.value,
243
+ }
244
+
245
+ if layer is not None:
246
+ results["layer_alignments"] = [
247
+ la for la in results["layer_alignments"] if la.get("layer") == layer
248
+ ]
249
+ except ValueError as e:
250
+ typer.secho(f"❌ {e}", fg=typer.colors.RED, err=True)
251
+ raise typer.Exit(1)
252
+ except Exception as e:
253
+ typer.secho(f"❌ Error analyzing traces: {e}", fg=typer.colors.RED, err=True)
254
+ import traceback
255
+ traceback.print_exc()
256
+ raise typer.Exit(1)
257
+
258
+ if output_format != 'json':
259
+ meta = results["metadata"]
260
+ typer.echo(f"✅ Loaded: {meta.get('amd_gpu', 'Unknown')} vs {meta.get('nvidia_gpu', 'Unknown')}")
261
+ typer.echo(f"✅ Found {len(results['layer_alignments'])} layers")
262
+ typer.echo()
263
+
264
+ if output_format == "json":
265
+ output_str = format_json(results)
266
+ else:
267
+ typer.secho(f"❌ Format {output_format} not yet supported for align command. Use 'json'.", fg=typer.colors.RED, err=True)
268
+ raise typer.Exit(1)
269
+
270
+ if output:
271
+ output.write_text(output_str)
272
+ typer.secho(f"✅ Report saved to {output}", fg=typer.colors.GREEN)
273
+ else:
274
+ typer.echo(output_str)
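The module docstring says these functions back the `wafer compare` commands; the Typer registration itself is not part of this diff, but a minimal hypothetical wiring would look like:

    import typer

    from wafer.trace_compare import compare_align, compare_traces

    compare_app = typer.Typer(help="Compare GPU traces")  # hypothetical sub-app and command names
    compare_app.command("traces")(compare_traces)
    compare_app.command("align")(compare_align)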
wafer/wevin_cli.py CHANGED
@@ -15,6 +15,8 @@ from pathlib import Path
15
15
  from typing import TYPE_CHECKING
16
16
 
17
17
  if TYPE_CHECKING:
18
+ from collections.abc import Awaitable, Callable
19
+
18
20
  from wafer_core.rollouts import Endpoint, Environment
19
21
  from wafer_core.rollouts.dtypes import StreamEvent, ToolCall
20
22
  from wafer_core.rollouts.templates import TemplateConfig
@@ -145,21 +147,60 @@ class StreamingChunkFrontend:
145
147
  pass
146
148
 
147
149
 
148
- def _get_wafer_auth() -> tuple[str | None, str | None]:
150
+ def _make_wafer_token_refresh() -> Callable[[], Awaitable[str | None]]:
151
+ """Create an async callback that refreshes the wafer proxy token via Supabase."""
152
+ from .auth import load_credentials, refresh_access_token, save_credentials
153
+
154
+ async def _refresh() -> str | None:
155
+ creds = load_credentials()
156
+ if not creds or not creds.refresh_token:
157
+ return None
158
+ try:
159
+ new_access, new_refresh = refresh_access_token(creds.refresh_token)
160
+ save_credentials(new_access, new_refresh, creds.email)
161
+ return new_access
162
+ except Exception:
163
+ return None
164
+
165
+ return _refresh
166
+
167
+
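A minimal standalone exercise of the factory above (illustrative; assumes a wafer credentials file with a refresh token exists on disk):

    import asyncio

    refresh = _make_wafer_token_refresh()
    token = asyncio.run(refresh())  # None when there are no credentials or the refresh fails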
168
+ def _get_wafer_auth(
169
+ *, no_proxy: bool = False
170
+ ) -> tuple[str | None, str | None, Callable[[], Awaitable[str | None]] | None]:
149
171
  """Get wafer auth credentials with fallback chain.
150
172
 
151
173
  Returns:
152
- (api_base, api_key) or (None, None) if no auth found
174
+ (api_base, api_key, api_key_refresh) or (None, None, None) if no auth found.
175
+ api_key_refresh is an async callback for mid-session token refresh (only set
176
+ when using wafer proxy via credentials file).
153
177
  """
154
178
  from .auth import get_valid_token, load_credentials
155
179
  from .global_config import get_api_url
156
180
 
181
+ if no_proxy:
182
+ api_key = os.environ.get("ANTHROPIC_API_KEY", "")
183
+ if not api_key:
184
+ # Try auth.json stored key
185
+ from wafer_core.auth import get_api_key
186
+
187
+ api_key = get_api_key("anthropic") or ""
188
+ if api_key:
189
+ print("🔑 Using ANTHROPIC_API_KEY (--no-proxy)\n", file=sys.stderr)
190
+ return "https://api.anthropic.com", api_key, None
191
+ print(
192
+ "❌ --no-proxy requires ANTHROPIC_API_KEY env var or `wafer auth login anthropic`\n",
193
+ file=sys.stderr,
194
+ )
195
+ return None, None, None
196
+
157
197
  # Check WAFER_AUTH_TOKEN env var first
158
198
  wafer_token = os.environ.get("WAFER_AUTH_TOKEN", "")
159
199
  token_source = "WAFER_AUTH_TOKEN" if wafer_token else None
160
200
 
161
201
  # Try credentials file with automatic refresh
162
202
  had_credentials = False
203
+ uses_credentials_file = False
163
204
  if not wafer_token:
164
205
  try:
165
206
  creds = load_credentials()
@@ -169,12 +210,16 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
169
210
  wafer_token = get_valid_token()
170
211
  if wafer_token:
171
212
  token_source = "~/.wafer/credentials.json"
213
+ uses_credentials_file = True
172
214
 
173
215
  # If we have a valid wafer token, use it
174
216
  if wafer_token:
175
217
  api_url = get_api_url()
176
218
  print(f"🔑 Using wafer proxy ({token_source})\n", file=sys.stderr)
177
- return f"{api_url}/v1/anthropic", wafer_token
219
+ # Only provide refresh callback when token came from credentials file
220
+ # (env var tokens are managed externally)
221
+ refresh = _make_wafer_token_refresh() if uses_credentials_file else None
222
+ return f"{api_url}/v1/anthropic", wafer_token, refresh
178
223
 
179
224
  # Fall back to direct anthropic
180
225
  api_key = os.environ.get("ANTHROPIC_API_KEY", "")
@@ -186,9 +231,9 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
186
231
  )
187
232
  else:
188
233
  print("🔑 Using ANTHROPIC_API_KEY\n", file=sys.stderr)
189
- return "https://api.anthropic.com", api_key
234
+ return "https://api.anthropic.com", api_key, None
190
235
 
191
- return None, None
236
+ return None, None, None
192
237
 
193
238
 
194
239
  def _get_session_preview(session: object) -> str:
@@ -207,7 +252,7 @@ def _get_session_preview(session: object) -> str:
207
252
 
208
253
  def _get_log_file_path() -> Path:
209
254
  """Get user-specific log file path, creating directory if needed.
210
-
255
+
211
256
  Uses ~/.wafer/logs/ to avoid permission issues with shared /tmp.
212
257
  """
213
258
  log_dir = Path.home() / ".wafer" / "logs"
@@ -255,6 +300,7 @@ def _build_endpoint(
255
300
  model_override: str | None,
256
301
  api_base: str,
257
302
  api_key: str,
303
+ api_key_refresh: Callable[[], Awaitable[str | None]] | None = None,
258
304
  ) -> Endpoint:
259
305
  """Build an Endpoint from template config and auth."""
260
306
  from wafer_core.rollouts import Endpoint
@@ -269,6 +315,7 @@ def _build_endpoint(
269
315
  model=model_id,
270
316
  api_base=api_base,
271
317
  api_key=api_key,
318
+ api_key_refresh=api_key_refresh,
272
319
  thinking=thinking_config,
273
320
  max_tokens=tpl.max_tokens,
274
321
  )
@@ -384,6 +431,7 @@ def main( # noqa: PLR0913, PLR0915
384
431
  get_session: str | None = None,
385
432
  json_output: bool = False,
386
433
  no_sandbox: bool = False,
434
+ no_proxy: bool = False,
387
435
  ) -> None:
388
436
  """Run wevin agent in-process via rollouts."""
389
437
  from dataclasses import asdict
@@ -499,10 +547,10 @@ def main( # noqa: PLR0913, PLR0915
499
547
  _setup_logging()
500
548
 
501
549
  # Auth
502
- api_base, api_key = _get_wafer_auth()
550
+ api_base, api_key, api_key_refresh = _get_wafer_auth(no_proxy=no_proxy)
503
551
  if not api_base or not api_key:
504
552
  print("Error: No API credentials found", file=sys.stderr)
505
- print(" Run 'wafer login' or set ANTHROPIC_API_KEY", file=sys.stderr)
553
+ print(" Run 'wafer auth login' or set ANTHROPIC_API_KEY", file=sys.stderr)
506
554
  sys.exit(1)
507
555
 
508
556
  assert api_base is not None
@@ -525,6 +573,17 @@ def main( # noqa: PLR0913, PLR0915
525
573
  tpl = _get_default_template()
526
574
  base_system_prompt = tpl.system_prompt
527
575
 
576
+ # Compose CLI instructions from --help text for allowed wafer commands
577
+ # TODO: The eval path doesn't have the skills layer below. If include_skills
578
+ # is ever enabled for optimize-kernelbench, the eval would need it too for parity.
579
+ # See test_eval_cli_parity.py for coverage notes.
580
+ if tpl.bash_allowlist:
581
+ from wafer.cli_instructions import build_cli_instructions
582
+
583
+ cli_instructions = build_cli_instructions(tpl.bash_allowlist)
584
+ if cli_instructions:
585
+ base_system_prompt = base_system_prompt + "\n\n" + cli_instructions
586
+
528
587
  # Append skill metadata if skills are enabled
529
588
  if tpl.include_skills:
530
589
  from wafer_core.rollouts.skills import discover_skills, format_skill_metadata_for_prompt
@@ -542,7 +601,7 @@ def main( # noqa: PLR0913, PLR0915
542
601
  resolved_single_turn = single_turn if single_turn is not None else tpl.single_turn
543
602
 
544
603
  # Build endpoint and environment
545
- endpoint = _build_endpoint(tpl, model, api_base, api_key)
604
+ endpoint = _build_endpoint(tpl, model, api_base, api_key, api_key_refresh)
546
605
  environment = _build_environment(tpl, tools, corpus_path, no_sandbox)
547
606
 
548
607
  # Session store