wafer-cli 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +42 -0
- wafer/billing.py +6 -6
- wafer/cli.py +502 -85
- wafer/cli_instructions.py +143 -0
- wafer/corpus.py +7 -1
- wafer/evaluate.py +13 -15
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +22 -6
- wafer/ssh_keys.py +6 -6
- wafer/templates/ask_docs.py +1 -1
- wafer/templates/optimize_kernel.py +1 -1
- wafer/templates/optimize_kernelbench.py +17 -62
- wafer/templates/trace_analyze.py +1 -1
- wafer/tests/test_eval_cli_parity.py +199 -0
- wafer/trace_compare.py +183 -0
- wafer/wevin_cli.py +80 -9
- wafer/workspaces.py +104 -8
- wafer_cli-0.2.25.dist-info/METADATA +107 -0
- wafer_cli-0.2.25.dist-info/RECORD +45 -0
- wafer_cli-0.2.23.dist-info/METADATA +0 -16
- wafer_cli-0.2.23.dist-info/RECORD +0 -41
- {wafer_cli-0.2.23.dist-info → wafer_cli-0.2.25.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.23.dist-info → wafer_cli-0.2.25.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.23.dist-info → wafer_cli-0.2.25.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
"""Test that eval and CLI agent configs are in sync.
|
|
2
|
+
|
|
3
|
+
Run: python -m pytest apps/wafer-cli/wafer/tests/test_eval_cli_parity.py -v
|
|
4
|
+
or: python apps/wafer-cli/wafer/tests/test_eval_cli_parity.py (standalone)
|
|
5
|
+
or: python apps/wafer-cli/wafer/tests/test_eval_cli_parity.py --dump (side-by-side)
|
|
6
|
+
|
|
7
|
+
Checks:
|
|
8
|
+
- Same bash allowlist (identity check — must be the same object)
|
|
9
|
+
- Same enabled tools
|
|
10
|
+
- Same system prompt text (modulo runtime variables like pool name)
|
|
11
|
+
- CLI instructions generated from same allowlist
|
|
12
|
+
|
|
13
|
+
Coverage notes:
|
|
14
|
+
These tests verify config-level parity. After composition, both paths
|
|
15
|
+
flow through the same codepath (Actor → run_agent_step → rollout →
|
|
16
|
+
anthropic.py provider) with no further system prompt modifications.
|
|
17
|
+
|
|
18
|
+
Known infra-level differences NOT tested here:
|
|
19
|
+
- Claude Code identity prefix: anthropic.py prepends "You are Claude Code..."
|
|
20
|
+
when using OAuth/Claude Code API keys. Eval typically uses raw API keys,
|
|
21
|
+
so eval agents may not get this prefix. This is an auth concern, not a
|
|
22
|
+
prompt content concern.
|
|
23
|
+
- Skills layer: CLI path (wevin_cli.py) can append skill metadata if
|
|
24
|
+
template.include_skills is True. Currently False for optimize-kernelbench,
|
|
25
|
+
so this is a no-op. Eval path doesn't have this layer at all.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
import difflib
|
|
31
|
+
import string
|
|
32
|
+
import sys
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _normalize_prompt(prompt: str) -> str:
|
|
36
|
+
"""Normalize runtime variables so eval and CLI prompts are comparable."""
|
|
37
|
+
# Replace pool/target specifics with placeholders
|
|
38
|
+
prompt = prompt.replace("--pool mi300x-pool", "--pool POOL")
|
|
39
|
+
prompt = prompt.replace("--pool kernelbench-pool", "--pool POOL")
|
|
40
|
+
prompt = prompt.replace("--target mi300x", "--target TARGET")
|
|
41
|
+
return prompt
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_bash_allowlist_parity() -> None:
    """Both the CLI template and the eval config must alias the shared allowlist.

    Identity (``is``) is checked rather than equality, so that a locally
    defined copy is rejected even if its contents currently match.
    """
    from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST as shared_allowlist
    from wafer.templates.optimize_kernelbench import template

    fix_hint = "Import it from wafer.agent_defaults instead of defining a local copy."

    # CLI side: the template must hold the very same list object.
    assert template.bash_allowlist is shared_allowlist, (
        "Template bash_allowlist is not the shared KERNELBENCH_BASH_ALLOWLIST. " + fix_hint
    )

    # Eval side: imported only after the CLI check has passed.
    from optimize_kernelbench_eval.base_config import BASH_ALLOWLIST as eval_allowlist

    assert eval_allowlist is shared_allowlist, (
        "Eval BASH_ALLOWLIST is not the shared KERNELBENCH_BASH_ALLOWLIST. " + fix_hint
    )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_enabled_tools_parity() -> None:
    """The CLI template must enable the same tools as the shared ENABLED_TOOLS."""
    from wafer.agent_defaults import ENABLED_TOOLS as shared_tools
    from wafer.templates.optimize_kernelbench import template

    template_tools = template.tools
    assert template_tools == shared_tools, (
        f"Template tools {template_tools} != shared ENABLED_TOOLS {shared_tools}"
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def test_system_prompt_parity() -> None:
    """Eval and CLI must carry the same task-specific system prompt.

    Both prompts are rendered with matching HIP defaults and one shared
    placeholder pool flag, then passed through ``_normalize_prompt``. Only
    genuine wording differences should fail this check.
    """
    from optimize_kernelbench_eval.base_config import SYSTEM_PROMPT as EVAL_PROMPT

    from wafer.templates.optimize_kernelbench import template

    placeholder_flag = "--pool POOL"

    # Eval side uses str.format placeholders; fill them with HIP values (most common).
    rendered_eval = EVAL_PROMPT.format(
        backend="HIP",
        backend_lower="hip",
        target_flag=placeholder_flag,
        reference_path="<reference_file>",
    )

    # CLI side uses string.Template ($-style) placeholders: start from the
    # template defaults and override the target flag to match the eval side.
    substitutions = dict(template.defaults)
    substitutions["target_flag"] = placeholder_flag
    rendered_cli = string.Template(template.system_prompt).safe_substitute(**substitutions)

    expected = _normalize_prompt(rendered_eval)
    actual = _normalize_prompt(rendered_cli)
    if expected == actual:
        return

    diff_lines = difflib.unified_diff(
        expected.splitlines(),
        actual.splitlines(),
        fromfile="eval (base_config.py)",
        tofile="cli (optimize_kernelbench.py template)",
        lineterm="",
        n=2,
    )
    diff = "\n".join(diff_lines)
    raise AssertionError(
        f"System prompts differ between eval and CLI template:\n\n{diff}\n\n"
        "Both should define the same task instructions. "
        "Edit one to match the other."
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def test_cli_instructions_identical() -> None:
    """Both paths should generate the same, non-empty CLI instructions.

    The eval side is built from the eval config's own ``BASH_ALLOWLIST``, not
    from the shared constant. Otherwise the eval path would never be
    exercised, and the comparison would be tautological whenever the template
    already aliases the shared list.
    """
    from optimize_kernelbench_eval.base_config import BASH_ALLOWLIST as EVAL_ALLOWLIST

    from wafer.cli_instructions import build_cli_instructions
    from wafer.templates.optimize_kernelbench import template

    eval_instructions = build_cli_instructions(EVAL_ALLOWLIST)
    cli_instructions = build_cli_instructions(template.bash_allowlist)

    # Check non-emptiness first: two empty results would otherwise "match".
    assert len(eval_instructions) > 0, "CLI instructions should not be empty"
    assert eval_instructions == cli_instructions, (
        "CLI instructions differ — this means the bash allowlists diverged."
    )
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _dump_full_prompts() -> None:
    """Standalone: print both fully composed system prompts and their diff.

    Each side is composed as task prompt + blank line + CLI instructions. The
    eval side's instructions are built from the eval config's own
    ``BASH_ALLOWLIST``, so the dump reflects what each path actually
    references. The diff is computed after ``_normalize_prompt``, which masks
    pool/target names.
    """
    from optimize_kernelbench_eval.base_config import BASH_ALLOWLIST as EVAL_ALLOWLIST
    from optimize_kernelbench_eval.base_config import SYSTEM_PROMPT as EVAL_PROMPT

    from wafer.cli_instructions import build_cli_instructions
    from wafer.templates.optimize_kernelbench import template

    banner = "=" * 60

    # Eval
    eval_sys = EVAL_PROMPT.format(
        backend="HIP",
        backend_lower="hip",
        target_flag="--pool mi300x-pool",
        reference_path="<reference_file>",
    )
    eval_sys += "\n\n" + build_cli_instructions(EVAL_ALLOWLIST)

    # CLI
    params = dict(template.defaults)
    cli_sys = string.Template(template.system_prompt).safe_substitute(**params)
    cli_sys += "\n\n" + build_cli_instructions(template.bash_allowlist)

    for title, text in (("EVAL SYSTEM PROMPT", eval_sys), ("CLI SYSTEM PROMPT", cli_sys)):
        print(banner)
        print(title)
        print(banner)
        print(text)
        print()

    # Diff
    diff = list(
        difflib.unified_diff(
            _normalize_prompt(eval_sys).splitlines(),
            _normalize_prompt(cli_sys).splitlines(),
            fromfile="eval",
            tofile="cli",
            lineterm="",
            n=1,
        )
    )
    if not diff:
        print("IDENTICAL (after normalizing pool names)")
        return

    print(banner)
    print("DIFFERENCES (after normalizing pool names)")
    print(banner)
    for line in diff:
        print(line)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
if __name__ == "__main__":
    # `--dump` prints both composed prompts side by side; otherwise run every
    # parity check in order and report each one as it passes.
    if "--dump" not in sys.argv:
        for label, check in (
            ("bash_allowlist_parity", test_bash_allowlist_parity),
            ("enabled_tools_parity", test_enabled_tools_parity),
            ("system_prompt_parity", test_system_prompt_parity),
            ("cli_instructions_identical", test_cli_instructions_identical),
        ):
            check()
            print(f"PASS: {label}")
        print("\nAll parity checks passed.")
    else:
        _dump_full_prompts()
|
wafer/trace_compare.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""CLI wrapper for trace comparison commands.
|
|
2
|
+
|
|
3
|
+
This module provides the CLI interface for the `wafer compare` commands.
|
|
4
|
+
All core logic is in wafer_core.lib.trace_compare.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
import typer
|
|
11
|
+
|
|
12
|
+
from wafer_core.lib.trace_compare import (
|
|
13
|
+
analyze_fusion_differences,
|
|
14
|
+
analyze_traces,
|
|
15
|
+
format_csv,
|
|
16
|
+
format_fusion_csv,
|
|
17
|
+
format_fusion_json,
|
|
18
|
+
format_fusion_text,
|
|
19
|
+
format_json,
|
|
20
|
+
format_text,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def compare_traces(
    trace1: Path,
    trace2: Path,
    output: Path | None = None,
    output_format: str = "text",
    phase: str = "all",
    show_layers: bool = False,
    show_all: bool = False,
    show_stack_traces: bool = False,
) -> None:
    """Compare two GPU traces and generate a performance report.

    Status/progress messages are written to stderr, and suppressed entirely for
    JSON. A report printed to stdout can therefore be piped or redirected
    without stray status lines.

    Args:
        trace1: Path to first trace file (AMD or NVIDIA)
        trace2: Path to second trace file (AMD or NVIDIA)
        output: Optional output file path (default: stdout)
        output_format: Output format ('text', 'text-layers', 'csv', 'csv-layers', or 'json')
        phase: Filter by phase ('all', 'prefill', or 'decode')
        show_layers: Show layer-wise performance breakdown (text format only)
        show_all: Show all items without truncation (applies to layers, operations, kernels)
        show_stack_traces: Show Python stack traces for operations

    Raises:
        typer.Exit: With code 1 if a trace file is missing, the output format
            is unknown, or trace analysis fails.
    """
    # Validate files exist
    for trace in (trace1, trace2):
        if not trace.exists():
            typer.secho(f"❌ File not found: {trace}", fg=typer.colors.RED, err=True)
            raise typer.Exit(1)

    # Reject unknown formats up front rather than after a potentially slow analysis.
    if output_format not in ("text", "text-layers", "csv", "csv-layers", "json"):
        typer.secho(f"❌ Unknown format: {output_format}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1)

    # Progress goes to stderr so that stdout carries only the report; previously
    # redirected CSV/text output was polluted with these lines. JSON stays quiet.
    show_progress = output_format != "json"
    if show_progress:
        typer.echo("📊 Loading traces...", err=True)

    # Collect every stack only when stack traces are shown without truncation;
    # otherwise cap at 3. (0 presumably means "no limit" in analyze_traces — confirm.)
    max_stacks = 0 if show_stack_traces and show_all else 3

    try:
        results = analyze_traces(
            trace1,
            trace2,
            phase_filter=phase,
            max_stacks=max_stacks,
        )
    except ValueError as e:
        typer.secho(f"❌ {e}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1)
    except Exception as e:
        typer.secho(f"❌ Error analyzing traces: {e}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1)

    # Show loading confirmation
    if show_progress:
        meta = results["metadata"]
        # Determine which trace is AMD and which is NVIDIA
        if meta["trace1_platform"] == "AMD":
            amd_gpu, nvidia_gpu = meta["trace1_gpu"], meta["trace2_gpu"]
        else:
            amd_gpu, nvidia_gpu = meta["trace2_gpu"], meta["trace1_gpu"]
        typer.echo(f"✅ Loaded: AMD ({amd_gpu}) vs NVIDIA ({nvidia_gpu})", err=True)
        typer.echo(err=True)

    # Generate output based on format (validated above)
    if output_format == "text":
        output_str = format_text(
            results, show_layers=show_layers, show_all=show_all, show_stack_traces=show_stack_traces
        )
    elif output_format == "text-layers":
        output_str = format_text(
            results, show_layers=True, show_all=show_all, show_stack_traces=show_stack_traces
        )
    elif output_format == "csv":
        output_str = format_csv(results, report_type="operations")
    elif output_format == "csv-layers":
        output_str = format_csv(results, report_type="layers")
    else:  # "json"
        output_str = format_json(results)

    # Write output
    if output:
        # Explicit UTF-8 so report text does not depend on the platform default encoding.
        output.write_text(output_str, encoding="utf-8")
        typer.secho(f"✅ Report saved to {output}", fg=typer.colors.GREEN)
    else:
        typer.echo(output_str)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def compare_fusion(
    trace1: Path,
    trace2: Path,
    output: Path | None = None,
    format_type: str = "text",
    min_group_size: int = 50,
) -> None:
    """Analyze kernel fusion differences between AMD and NVIDIA traces.

    Status/progress messages are written to stderr, and suppressed entirely for
    JSON. A report printed to stdout can therefore be piped or redirected
    without stray status lines.

    Args:
        trace1: Path to first trace file (AMD or NVIDIA)
        trace2: Path to second trace file (AMD or NVIDIA)
        output: Optional output file path (default: stdout)
        format_type: Output format ('text', 'csv', or 'json')
        min_group_size: Minimum correlation group size to analyze

    Raises:
        typer.Exit: With code 1 if a trace file is missing, the format is
            unknown, or fusion analysis fails.
    """
    # Validate files exist
    for trace in (trace1, trace2):
        if not trace.exists():
            typer.secho(f"❌ File not found: {trace}", fg=typer.colors.RED, err=True)
            raise typer.Exit(1)

    # Reject unknown formats up front rather than after a potentially slow analysis.
    formatters = {
        "text": format_fusion_text,
        "csv": format_fusion_csv,
        "json": format_fusion_json,
    }
    if format_type not in formatters:
        typer.secho(f"❌ Unknown format: {format_type}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1)

    # Progress goes to stderr so that stdout carries only the report; previously
    # redirected CSV/text output was polluted with these lines. JSON stays quiet.
    show_progress = format_type != "json"
    if show_progress:
        typer.echo("📊 Loading traces...", err=True)

    try:
        results = analyze_fusion_differences(
            trace1,
            trace2,
            min_group_size=min_group_size,
        )
    except Exception as e:
        typer.secho(f"❌ Error analyzing traces: {e}", fg=typer.colors.RED, err=True)
        import traceback

        traceback.print_exc()
        raise typer.Exit(1)

    # Show loading confirmation
    if show_progress:
        meta = results["metadata"]
        # Note: fusion analyzer always uses trace1=AMD, trace2=NVIDIA
        typer.echo(f"✅ Loaded: {meta['trace1_gpu']} vs {meta['trace2_gpu']}", err=True)
        typer.echo(
            f"Found {meta['trace1_correlation_groups']} trace1 groups and "
            f"{meta['trace2_correlation_groups']} trace2 groups with ≥{min_group_size} kernels",
            err=True,
        )
        typer.echo(f"✅ Matched {meta['matched_groups']} correlation groups", err=True)
        typer.echo(err=True)

    # Generate output (format validated above)
    output_str = formatters[format_type](results)

    # Write output
    if output:
        # Explicit UTF-8 so report text does not depend on the platform default encoding.
        output.write_text(output_str, encoding="utf-8")
        typer.secho(f"✅ Report saved to {output}", fg=typer.colors.GREEN)
    else:
        typer.echo(output_str)
|
wafer/wevin_cli.py
CHANGED
|
@@ -15,6 +15,8 @@ from pathlib import Path
|
|
|
15
15
|
from typing import TYPE_CHECKING
|
|
16
16
|
|
|
17
17
|
if TYPE_CHECKING:
|
|
18
|
+
from collections.abc import Awaitable, Callable
|
|
19
|
+
|
|
18
20
|
from wafer_core.rollouts import Endpoint, Environment
|
|
19
21
|
from wafer_core.rollouts.dtypes import StreamEvent, ToolCall
|
|
20
22
|
from wafer_core.rollouts.templates import TemplateConfig
|
|
@@ -145,21 +147,60 @@ class StreamingChunkFrontend:
|
|
|
145
147
|
pass
|
|
146
148
|
|
|
147
149
|
|
|
148
|
-
def
|
|
150
|
+
def _make_wafer_token_refresh() -> Callable[[], Awaitable[str | None]]:
    """Create an async callback that refreshes the wafer proxy token via Supabase.

    When awaited, the returned callback does the following:

    - reads the stored credentials;
    - exchanges the refresh token for a new access/refresh pair;
    - persists the new pair;
    - returns the new access token.

    It returns ``None`` when no refresh token is stored or the refresh fails,
    so refreshing is best-effort.
    """
    import asyncio
    import logging

    from .auth import load_credentials, refresh_access_token, save_credentials

    logger = logging.getLogger(__name__)

    async def _refresh() -> str | None:
        creds = load_credentials()
        if not creds or not creds.refresh_token:
            return None
        try:
            # refresh_access_token is synchronous (it returns the token pair
            # directly) and presumably performs a network round-trip; run it in
            # a worker thread so it does not block the event loop.
            new_access, new_refresh = await asyncio.to_thread(
                refresh_access_token, creds.refresh_token
            )
            save_credentials(new_access, new_refresh, creds.email)
        except Exception:
            # Best-effort: a failed refresh yields None, but record why instead
            # of swallowing the error silently.
            logger.warning("Wafer token refresh failed", exc_info=True)
            return None
        return new_access

    return _refresh
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _get_wafer_auth(
|
|
169
|
+
*, no_proxy: bool = False
|
|
170
|
+
) -> tuple[str | None, str | None, Callable[[], Awaitable[str | None]] | None]:
|
|
149
171
|
"""Get wafer auth credentials with fallback chain.
|
|
150
172
|
|
|
151
173
|
Returns:
|
|
152
|
-
(api_base, api_key) or (None, None) if no auth found
|
|
174
|
+
(api_base, api_key, api_key_refresh) or (None, None, None) if no auth found.
|
|
175
|
+
api_key_refresh is an async callback for mid-session token refresh (only set
|
|
176
|
+
when using wafer proxy via credentials file).
|
|
153
177
|
"""
|
|
154
178
|
from .auth import get_valid_token, load_credentials
|
|
155
179
|
from .global_config import get_api_url
|
|
156
180
|
|
|
181
|
+
if no_proxy:
|
|
182
|
+
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
|
183
|
+
if not api_key:
|
|
184
|
+
# Try auth.json stored key
|
|
185
|
+
from wafer_core.auth import get_api_key
|
|
186
|
+
|
|
187
|
+
api_key = get_api_key("anthropic") or ""
|
|
188
|
+
if api_key:
|
|
189
|
+
print("🔑 Using ANTHROPIC_API_KEY (--no-proxy)\n", file=sys.stderr)
|
|
190
|
+
return "https://api.anthropic.com", api_key, None
|
|
191
|
+
print(
|
|
192
|
+
"❌ --no-proxy requires ANTHROPIC_API_KEY env var or `wafer auth login anthropic`\n",
|
|
193
|
+
file=sys.stderr,
|
|
194
|
+
)
|
|
195
|
+
return None, None, None
|
|
196
|
+
|
|
157
197
|
# Check WAFER_AUTH_TOKEN env var first
|
|
158
198
|
wafer_token = os.environ.get("WAFER_AUTH_TOKEN", "")
|
|
159
199
|
token_source = "WAFER_AUTH_TOKEN" if wafer_token else None
|
|
160
200
|
|
|
161
201
|
# Try credentials file with automatic refresh
|
|
162
202
|
had_credentials = False
|
|
203
|
+
uses_credentials_file = False
|
|
163
204
|
if not wafer_token:
|
|
164
205
|
try:
|
|
165
206
|
creds = load_credentials()
|
|
@@ -169,12 +210,16 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
|
|
|
169
210
|
wafer_token = get_valid_token()
|
|
170
211
|
if wafer_token:
|
|
171
212
|
token_source = "~/.wafer/credentials.json"
|
|
213
|
+
uses_credentials_file = True
|
|
172
214
|
|
|
173
215
|
# If we have a valid wafer token, use it
|
|
174
216
|
if wafer_token:
|
|
175
217
|
api_url = get_api_url()
|
|
176
218
|
print(f"🔑 Using wafer proxy ({token_source})\n", file=sys.stderr)
|
|
177
|
-
|
|
219
|
+
# Only provide refresh callback when token came from credentials file
|
|
220
|
+
# (env var tokens are managed externally)
|
|
221
|
+
refresh = _make_wafer_token_refresh() if uses_credentials_file else None
|
|
222
|
+
return f"{api_url}/v1/anthropic", wafer_token, refresh
|
|
178
223
|
|
|
179
224
|
# Fall back to direct anthropic
|
|
180
225
|
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
|
@@ -186,9 +231,9 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
|
|
|
186
231
|
)
|
|
187
232
|
else:
|
|
188
233
|
print("🔑 Using ANTHROPIC_API_KEY\n", file=sys.stderr)
|
|
189
|
-
return "https://api.anthropic.com", api_key
|
|
234
|
+
return "https://api.anthropic.com", api_key, None
|
|
190
235
|
|
|
191
|
-
return None, None
|
|
236
|
+
return None, None, None
|
|
192
237
|
|
|
193
238
|
|
|
194
239
|
def _get_session_preview(session: object) -> str:
|
|
@@ -205,10 +250,22 @@ def _get_session_preview(session: object) -> str:
|
|
|
205
250
|
return ""
|
|
206
251
|
|
|
207
252
|
|
|
253
|
+
def _get_log_file_path() -> Path:
|
|
254
|
+
"""Get user-specific log file path, creating directory if needed.
|
|
255
|
+
|
|
256
|
+
Uses ~/.wafer/logs/ to avoid permission issues with shared /tmp.
|
|
257
|
+
"""
|
|
258
|
+
log_dir = Path.home() / ".wafer" / "logs"
|
|
259
|
+
log_dir.mkdir(parents=True, exist_ok=True)
|
|
260
|
+
return log_dir / "wevin_debug.log"
|
|
261
|
+
|
|
262
|
+
|
|
208
263
|
def _setup_logging() -> None:
|
|
209
264
|
"""Configure logging to file only (no console spam)."""
|
|
210
265
|
import logging.config
|
|
211
266
|
|
|
267
|
+
log_file = _get_log_file_path()
|
|
268
|
+
|
|
212
269
|
logging.config.dictConfig({
|
|
213
270
|
"version": 1,
|
|
214
271
|
"disable_existing_loggers": False,
|
|
@@ -220,7 +277,7 @@ def _setup_logging() -> None:
|
|
|
220
277
|
"handlers": {
|
|
221
278
|
"file": {
|
|
222
279
|
"class": "logging.handlers.RotatingFileHandler",
|
|
223
|
-
"filename":
|
|
280
|
+
"filename": str(log_file),
|
|
224
281
|
"maxBytes": 10_000_000,
|
|
225
282
|
"backupCount": 3,
|
|
226
283
|
"formatter": "json",
|
|
@@ -243,6 +300,7 @@ def _build_endpoint(
|
|
|
243
300
|
model_override: str | None,
|
|
244
301
|
api_base: str,
|
|
245
302
|
api_key: str,
|
|
303
|
+
api_key_refresh: Callable[[], Awaitable[str | None]] | None = None,
|
|
246
304
|
) -> Endpoint:
|
|
247
305
|
"""Build an Endpoint from template config and auth."""
|
|
248
306
|
from wafer_core.rollouts import Endpoint
|
|
@@ -257,6 +315,7 @@ def _build_endpoint(
|
|
|
257
315
|
model=model_id,
|
|
258
316
|
api_base=api_base,
|
|
259
317
|
api_key=api_key,
|
|
318
|
+
api_key_refresh=api_key_refresh,
|
|
260
319
|
thinking=thinking_config,
|
|
261
320
|
max_tokens=tpl.max_tokens,
|
|
262
321
|
)
|
|
@@ -372,6 +431,7 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
372
431
|
get_session: str | None = None,
|
|
373
432
|
json_output: bool = False,
|
|
374
433
|
no_sandbox: bool = False,
|
|
434
|
+
no_proxy: bool = False,
|
|
375
435
|
) -> None:
|
|
376
436
|
"""Run wevin agent in-process via rollouts."""
|
|
377
437
|
from dataclasses import asdict
|
|
@@ -487,10 +547,10 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
487
547
|
_setup_logging()
|
|
488
548
|
|
|
489
549
|
# Auth
|
|
490
|
-
api_base, api_key = _get_wafer_auth()
|
|
550
|
+
api_base, api_key, api_key_refresh = _get_wafer_auth(no_proxy=no_proxy)
|
|
491
551
|
if not api_base or not api_key:
|
|
492
552
|
print("Error: No API credentials found", file=sys.stderr)
|
|
493
|
-
print(" Run 'wafer login' or set ANTHROPIC_API_KEY", file=sys.stderr)
|
|
553
|
+
print(" Run 'wafer auth login' or set ANTHROPIC_API_KEY", file=sys.stderr)
|
|
494
554
|
sys.exit(1)
|
|
495
555
|
|
|
496
556
|
assert api_base is not None
|
|
@@ -513,6 +573,17 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
513
573
|
tpl = _get_default_template()
|
|
514
574
|
base_system_prompt = tpl.system_prompt
|
|
515
575
|
|
|
576
|
+
# Compose CLI instructions from --help text for allowed wafer commands
|
|
577
|
+
# TODO: The eval path doesn't have the skills layer below. If include_skills
|
|
578
|
+
# is ever enabled for optimize-kernelbench, the eval would need it too for parity.
|
|
579
|
+
# See test_eval_cli_parity.py for coverage notes.
|
|
580
|
+
if tpl.bash_allowlist:
|
|
581
|
+
from wafer.cli_instructions import build_cli_instructions
|
|
582
|
+
|
|
583
|
+
cli_instructions = build_cli_instructions(tpl.bash_allowlist)
|
|
584
|
+
if cli_instructions:
|
|
585
|
+
base_system_prompt = base_system_prompt + "\n\n" + cli_instructions
|
|
586
|
+
|
|
516
587
|
# Append skill metadata if skills are enabled
|
|
517
588
|
if tpl.include_skills:
|
|
518
589
|
from wafer_core.rollouts.skills import discover_skills, format_skill_metadata_for_prompt
|
|
@@ -530,7 +601,7 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
530
601
|
resolved_single_turn = single_turn if single_turn is not None else tpl.single_turn
|
|
531
602
|
|
|
532
603
|
# Build endpoint and environment
|
|
533
|
-
endpoint = _build_endpoint(tpl, model, api_base, api_key)
|
|
604
|
+
endpoint = _build_endpoint(tpl, model, api_base, api_key, api_key_refresh)
|
|
534
605
|
environment = _build_environment(tpl, tools, corpus_path, no_sandbox)
|
|
535
606
|
|
|
536
607
|
# Session store
|