wafer-cli 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +42 -0
- wafer/billing.py +6 -6
- wafer/cli.py +454 -86
- wafer/cli_instructions.py +143 -0
- wafer/corpus.py +7 -1
- wafer/evaluate.py +13 -6
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +22 -6
- wafer/ssh_keys.py +6 -6
- wafer/templates/ask_docs.py +1 -1
- wafer/templates/optimize_kernel.py +1 -1
- wafer/templates/optimize_kernelbench.py +17 -62
- wafer/templates/trace_analyze.py +1 -1
- wafer/tests/test_eval_cli_parity.py +199 -0
- wafer/trace_compare.py +183 -0
- wafer/wevin_cli.py +68 -9
- wafer/workspaces.py +8 -8
- wafer_cli-0.2.25.dist-info/METADATA +107 -0
- wafer_cli-0.2.25.dist-info/RECORD +45 -0
- wafer_cli-0.2.24.dist-info/METADATA +0 -16
- wafer_cli-0.2.24.dist-info/RECORD +0 -41
- {wafer_cli-0.2.24.dist-info → wafer_cli-0.2.25.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.24.dist-info → wafer_cli-0.2.25.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.24.dist-info → wafer_cli-0.2.25.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
"""Test that eval and CLI agent configs are in sync.
|
|
2
|
+
|
|
3
|
+
Run: python -m pytest apps/wafer-cli/wafer/tests/test_eval_cli_parity.py -v
|
|
4
|
+
or: python apps/wafer-cli/wafer/tests/test_eval_cli_parity.py (standalone)
|
|
5
|
+
or: python apps/wafer-cli/wafer/tests/test_eval_cli_parity.py --dump (side-by-side)
|
|
6
|
+
|
|
7
|
+
Checks:
|
|
8
|
+
- Same bash allowlist (identity check — must be the same object)
|
|
9
|
+
- Same enabled tools
|
|
10
|
+
- Same system prompt text (modulo runtime variables like pool name)
|
|
11
|
+
- CLI instructions generated from same allowlist
|
|
12
|
+
|
|
13
|
+
Coverage notes:
|
|
14
|
+
These tests verify config-level parity. After composition, both paths
|
|
15
|
+
flow through the same codepath (Actor → run_agent_step → rollout →
|
|
16
|
+
anthropic.py provider) with no further system prompt modifications.
|
|
17
|
+
|
|
18
|
+
Known infra-level differences NOT tested here:
|
|
19
|
+
- Claude Code identity prefix: anthropic.py prepends "You are Claude Code..."
|
|
20
|
+
when using OAuth/Claude Code API keys. Eval typically uses raw API keys,
|
|
21
|
+
so eval agents may not get this prefix. This is an auth concern, not a
|
|
22
|
+
prompt content concern.
|
|
23
|
+
- Skills layer: CLI path (wevin_cli.py) can append skill metadata if
|
|
24
|
+
template.include_skills is True. Currently False for optimize-kernelbench,
|
|
25
|
+
so this is a no-op. Eval path doesn't have this layer at all.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
import difflib
|
|
31
|
+
import string
|
|
32
|
+
import sys
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _normalize_prompt(prompt: str) -> str:
|
|
36
|
+
"""Normalize runtime variables so eval and CLI prompts are comparable."""
|
|
37
|
+
# Replace pool/target specifics with placeholders
|
|
38
|
+
prompt = prompt.replace("--pool mi300x-pool", "--pool POOL")
|
|
39
|
+
prompt = prompt.replace("--pool kernelbench-pool", "--pool POOL")
|
|
40
|
+
prompt = prompt.replace("--target mi300x", "--target TARGET")
|
|
41
|
+
return prompt
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_bash_allowlist_parity() -> None:
    """Check that the CLI template and the eval config share one allowlist object.

    Identity (``is``) is asserted rather than equality, so a local copy with
    the same contents still fails the check.
    """
    from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST as shared_allowlist
    from wafer.templates.optimize_kernelbench import template

    # CLI side: the template must reference the shared object itself.
    template_msg = (
        "Template bash_allowlist is not the shared KERNELBENCH_BASH_ALLOWLIST. "
        "Import it from wafer.agent_defaults instead of defining a local copy."
    )
    assert template.bash_allowlist is shared_allowlist, template_msg

    # Eval side: imported only after the template check has passed, so a
    # template failure is reported before any eval import problem.
    from optimize_kernelbench_eval.base_config import BASH_ALLOWLIST as eval_allowlist

    eval_msg = (
        "Eval BASH_ALLOWLIST is not the shared KERNELBENCH_BASH_ALLOWLIST. "
        "Import it from wafer.agent_defaults instead of defining a local copy."
    )
    assert eval_allowlist is shared_allowlist, eval_msg
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_enabled_tools_parity() -> None:
    """Check that the CLI template enables exactly the shared ``ENABLED_TOOLS``."""
    from wafer.agent_defaults import ENABLED_TOOLS
    from wafer.templates.optimize_kernelbench import template

    cli_tools = template.tools
    # Equality (not identity) is what is compared here.
    assert cli_tools == ENABLED_TOOLS, (
        f"Template tools {cli_tools} != shared ENABLED_TOOLS {ENABLED_TOOLS}"
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def test_system_prompt_parity() -> None:
    """Check that eval and CLI render the same task-specific system prompt.

    Both prompts are rendered with the same target flag and passed through
    ``_normalize_prompt`` before comparison.

    Raises:
        AssertionError: If the normalized prompts differ; the message carries
            a unified diff (eval -> cli).
    """
    from optimize_kernelbench_eval.base_config import SYSTEM_PROMPT as EVAL_PROMPT

    from wafer.templates.optimize_kernelbench import template

    placeholder_flag = "--pool POOL"

    # Eval prompt uses str.format fields; render it with HIP defaults
    # (most common).
    eval_fields = {
        "backend": "HIP",
        "backend_lower": "hip",
        "target_flag": placeholder_flag,
        "reference_path": "<reference_file>",
    }
    eval_rendered = EVAL_PROMPT.format(**eval_fields)

    # CLI template uses string.Template placeholders; start from the
    # template's own defaults and override only the target flag.
    cli_fields = dict(template.defaults)
    cli_fields.update(target_flag=placeholder_flag)
    cli_rendered = string.Template(template.system_prompt).safe_substitute(**cli_fields)

    eval_text = _normalize_prompt(eval_rendered)
    cli_text = _normalize_prompt(cli_rendered)

    if eval_text == cli_text:
        return

    diff_lines = difflib.unified_diff(
        eval_text.splitlines(),
        cli_text.splitlines(),
        fromfile="eval (base_config.py)",
        tofile="cli (optimize_kernelbench.py template)",
        lineterm="",
        n=2,
    )
    diff = "\n".join(diff_lines)
    raise AssertionError(
        f"System prompts differ between eval and CLI template:\n\n{diff}\n\n"
        "Both should define the same task instructions. "
        "Edit one to match the other."
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def test_cli_instructions_identical() -> None:
    """Check that both allowlists yield identical, non-empty CLI instructions.

    Instructions are built once from the shared allowlist and once from the
    template's allowlist; any difference means the allowlists diverged.
    """
    from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST
    from wafer.cli_instructions import build_cli_instructions
    from wafer.templates.optimize_kernelbench import template

    from_shared = build_cli_instructions(KERNELBENCH_BASH_ALLOWLIST)
    from_template = build_cli_instructions(template.bash_allowlist)

    divergence_msg = "CLI instructions differ — this means the bash allowlists diverged."
    assert from_shared == from_template, divergence_msg

    empty_msg = "CLI instructions should not be empty"
    assert len(from_shared) > 0, empty_msg
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _dump_full_prompts() -> None:
    """Print the composed eval and CLI system prompts, then their diff.

    Each prompt is the rendered task prompt followed by a blank line and the
    generated CLI instructions. The diff is computed after
    ``_normalize_prompt`` has been applied to both prompts.
    """
    from optimize_kernelbench_eval.base_config import SYSTEM_PROMPT as EVAL_PROMPT

    from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST
    from wafer.cli_instructions import build_cli_instructions
    from wafer.templates.optimize_kernelbench import template

    rule = "=" * 60

    def _banner(title: str) -> None:
        # Section heading framed by two rule lines.
        print(rule)
        print(title)
        print(rule)

    shared_instructions = build_cli_instructions(KERNELBENCH_BASH_ALLOWLIST)

    # Eval: str.format rendering plus instructions from the shared allowlist.
    eval_body = EVAL_PROMPT.format(
        backend="HIP",
        backend_lower="hip",
        target_flag="--pool mi300x-pool",
        reference_path="<reference_file>",
    )
    eval_sys = eval_body + "\n\n" + shared_instructions

    # CLI: string.Template rendering with the template's own defaults, plus
    # instructions from the template's allowlist.
    cli_fields = dict(template.defaults)
    cli_body = string.Template(template.system_prompt).safe_substitute(**cli_fields)
    cli_sys = cli_body + "\n\n" + build_cli_instructions(template.bash_allowlist)

    _banner("EVAL SYSTEM PROMPT")
    print(eval_sys)
    print()
    _banner("CLI SYSTEM PROMPT")
    print(cli_sys)
    print()

    # Diff the normalized prompts with one line of context.
    diff_lines = list(
        difflib.unified_diff(
            _normalize_prompt(eval_sys).splitlines(),
            _normalize_prompt(cli_sys).splitlines(),
            fromfile="eval",
            tofile="cli",
            lineterm="",
            n=1,
        )
    )
    if not diff_lines:
        print("IDENTICAL (after normalizing pool names)")
        return

    _banner("DIFFERENCES (after normalizing pool names)")
    for entry in diff_lines:
        print(entry)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
if __name__ == "__main__":
    # Standalone entry point: "--dump" prints both prompts side by side;
    # otherwise each parity check runs in order and reports PASS.
    if "--dump" not in sys.argv:
        _standalone_checks = (
            ("bash_allowlist_parity", test_bash_allowlist_parity),
            ("enabled_tools_parity", test_enabled_tools_parity),
            ("system_prompt_parity", test_system_prompt_parity),
            ("cli_instructions_identical", test_cli_instructions_identical),
        )
        for _label, _check in _standalone_checks:
            # A failing check raises and stops the run before its PASS line.
            _check()
            print(f"PASS: {_label}")
        print("\nAll parity checks passed.")
    else:
        _dump_full_prompts()
|
wafer/trace_compare.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""CLI wrapper for trace comparison commands.
|
|
2
|
+
|
|
3
|
+
This module provides the CLI interface for the `wafer compare` commands.
|
|
4
|
+
All core logic is in wafer_core.lib.trace_compare.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
import typer
|
|
11
|
+
|
|
12
|
+
from wafer_core.lib.trace_compare import (
|
|
13
|
+
analyze_fusion_differences,
|
|
14
|
+
analyze_traces,
|
|
15
|
+
format_csv,
|
|
16
|
+
format_fusion_csv,
|
|
17
|
+
format_fusion_json,
|
|
18
|
+
format_fusion_text,
|
|
19
|
+
format_json,
|
|
20
|
+
format_text,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def compare_traces(
    trace1: Path,
    trace2: Path,
    output: Path | None = None,
    output_format: str = "text",
    phase: str = "all",
    show_layers: bool = False,
    show_all: bool = False,
    show_stack_traces: bool = False,
) -> None:
    """Compare two GPU traces and generate performance report.

    Args:
        trace1: Path to first trace file (AMD or NVIDIA)
        trace2: Path to second trace file (AMD or NVIDIA)
        output: Optional output file path (default: stdout)
        output_format: Output format ('text', 'text-layers', 'csv', 'csv-layers', or 'json')
        phase: Filter by phase ('all', 'prefill', or 'decode')
        show_layers: Show layer-wise performance breakdown (text format only)
        show_all: Show all items without truncation (applies to layers, operations, kernels)
        show_stack_traces: Show Python stack traces for operations

    Raises:
        typer.Exit: With exit code 1 when a trace file does not exist, the
            output format is unknown, or trace analysis fails.
    """
    # Validate files exist (trace1 is checked first)
    for trace in (trace1, trace2):
        if not trace.exists():
            typer.secho(f"❌ File not found: {trace}", fg=typer.colors.RED, err=True)
            raise typer.Exit(1)

    # Reject an unknown format before running the analysis, so a typo in the
    # format name does not cost a full trace load first.
    known_formats = ("text", "text-layers", "csv", "csv-layers", "json")
    if output_format not in known_formats:
        typer.secho(f"❌ Unknown format: {output_format}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1)

    # Only show progress messages for non-JSON formats (JSON needs clean stdout)
    if output_format != 'json':
        typer.echo("📊 Loading traces...")

    # Determine how many stack traces to collect: 0 when both show_stack_traces
    # and show_all are set, otherwise 3.
    # NOTE(review): presumably 0 means "no limit" — confirm against analyze_traces.
    max_stacks = 0 if (show_stack_traces and show_all) else 3

    try:
        results = analyze_traces(
            trace1,
            trace2,
            phase_filter=phase,
            max_stacks=max_stacks,
        )
    except ValueError as e:
        # ValueError messages are shown as-is, without the generic prefix.
        typer.secho(f"❌ {e}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1) from e
    except Exception as e:
        typer.secho(f"❌ Error analyzing traces: {e}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1) from e

    # Show loading confirmation
    if output_format != 'json':
        meta = results["metadata"]
        # Determine which trace is AMD and which is NVIDIA: if trace1 is not
        # reported as AMD, trace2 is treated as the AMD side.
        if meta['trace1_platform'] == 'AMD':
            amd_gpu, nvidia_gpu = meta['trace1_gpu'], meta['trace2_gpu']
        else:
            amd_gpu, nvidia_gpu = meta['trace2_gpu'], meta['trace1_gpu']
        typer.echo(f"✅ Loaded: AMD ({amd_gpu}) vs NVIDIA ({nvidia_gpu})")
        typer.echo()

    # Generate output based on format (already validated above)
    if output_format == "text":
        output_str = format_text(
            results,
            show_layers=show_layers,
            show_all=show_all,
            show_stack_traces=show_stack_traces,
        )
    elif output_format == "text-layers":
        # Same as "text" but with the layer breakdown forced on.
        output_str = format_text(
            results,
            show_layers=True,
            show_all=show_all,
            show_stack_traces=show_stack_traces,
        )
    elif output_format == "csv":
        output_str = format_csv(results, report_type="operations")
    elif output_format == "csv-layers":
        output_str = format_csv(results, report_type="layers")
    else:  # "json"
        output_str = format_json(results)

    # Write output
    if output:
        # Explicit UTF-8 so writing does not depend on the platform's locale
        # encoding.
        output.write_text(output_str, encoding="utf-8")
        typer.secho(f"✅ Report saved to {output}", fg=typer.colors.GREEN)
    else:
        typer.echo(output_str)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def compare_fusion(
    trace1: Path,
    trace2: Path,
    output: Path | None = None,
    format_type: str = "text",
    min_group_size: int = 50,
) -> None:
    """Analyze kernel fusion differences between AMD and NVIDIA traces.

    Args:
        trace1: Path to first trace file (AMD or NVIDIA)
        trace2: Path to second trace file (AMD or NVIDIA)
        output: Optional output file path (default: stdout)
        format_type: Output format ('text', 'csv', or 'json')
        min_group_size: Minimum correlation group size to analyze

    Raises:
        typer.Exit: With exit code 1 when a trace file does not exist, the
            output format is unknown, or fusion analysis fails.
    """
    # Validate files exist (trace1 is checked first)
    for trace in (trace1, trace2):
        if not trace.exists():
            typer.secho(f"❌ File not found: {trace}", fg=typer.colors.RED, err=True)
            raise typer.Exit(1)

    # Reject an unknown format before running the analysis, so a typo in the
    # format name does not cost a full trace load first.
    if format_type not in ("text", "csv", "json"):
        typer.secho(f"❌ Unknown format: {format_type}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1)

    # Analyze fusion
    # Only show progress messages for non-JSON formats (JSON needs clean stdout)
    if format_type != 'json':
        typer.echo("📊 Loading traces...")
    try:
        results = analyze_fusion_differences(
            trace1,
            trace2,
            min_group_size=min_group_size,
        )
    except Exception as e:
        typer.secho(
            f"❌ Error analyzing traces: {e}", fg=typer.colors.RED, err=True
        )
        # Print the full traceback as well, to aid debugging.
        import traceback

        traceback.print_exc()
        raise typer.Exit(1) from e

    # Show loading confirmation
    if format_type != 'json':
        meta = results["metadata"]
        # Note: fusion analyzer always uses trace1=AMD, trace2=NVIDIA
        typer.echo(f"✅ Loaded: {meta['trace1_gpu']} vs {meta['trace2_gpu']}")
        typer.echo(
            f"Found {meta['trace1_correlation_groups']} trace1 groups and "
            f"{meta['trace2_correlation_groups']} trace2 groups with ≥{min_group_size} kernels"
        )
        typer.echo(f"✅ Matched {meta['matched_groups']} correlation groups")
        typer.echo()

    # Generate output (format already validated above)
    if format_type == "text":
        output_str = format_fusion_text(results)
    elif format_type == "csv":
        output_str = format_fusion_csv(results)
    else:  # "json"
        output_str = format_fusion_json(results)

    # Write output
    if output:
        # Explicit UTF-8 so writing does not depend on the platform's locale
        # encoding.
        output.write_text(output_str, encoding="utf-8")
        typer.secho(f"✅ Report saved to {output}", fg=typer.colors.GREEN)
    else:
        typer.echo(output_str)
|
wafer/wevin_cli.py
CHANGED
|
@@ -15,6 +15,8 @@ from pathlib import Path
|
|
|
15
15
|
from typing import TYPE_CHECKING
|
|
16
16
|
|
|
17
17
|
if TYPE_CHECKING:
|
|
18
|
+
from collections.abc import Awaitable, Callable
|
|
19
|
+
|
|
18
20
|
from wafer_core.rollouts import Endpoint, Environment
|
|
19
21
|
from wafer_core.rollouts.dtypes import StreamEvent, ToolCall
|
|
20
22
|
from wafer_core.rollouts.templates import TemplateConfig
|
|
@@ -145,21 +147,60 @@ class StreamingChunkFrontend:
|
|
|
145
147
|
pass
|
|
146
148
|
|
|
147
149
|
|
|
148
|
-
def
|
|
150
|
+
def _make_wafer_token_refresh() -> Callable[[], Awaitable[str | None]]:
    """Create an async callback that refreshes the wafer proxy token via Supabase.

    Returns:
        A zero-argument coroutine function. Awaiting it returns the new access
        token, or None when no stored refresh token is available or the
        refresh attempt fails.
    """
    import logging

    from .auth import load_credentials, refresh_access_token, save_credentials

    logger = logging.getLogger(__name__)

    async def _refresh() -> str | None:
        creds = load_credentials()
        if not creds or not creds.refresh_token:
            return None
        try:
            # NOTE(review): refresh_access_token is called synchronously here;
            # if it performs network I/O it blocks the event loop — confirm.
            new_access, new_refresh = refresh_access_token(creds.refresh_token)
            # Persist the new token pair under the same account email.
            save_credentials(new_access, new_refresh, creds.email)
        except Exception:
            # Best-effort: the caller only sees None, but the reason is kept
            # in the debug log instead of being discarded.
            logger.debug("Wafer proxy token refresh failed", exc_info=True)
            return None
        return new_access

    return _refresh
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _get_wafer_auth(
|
|
169
|
+
*, no_proxy: bool = False
|
|
170
|
+
) -> tuple[str | None, str | None, Callable[[], Awaitable[str | None]] | None]:
|
|
149
171
|
"""Get wafer auth credentials with fallback chain.
|
|
150
172
|
|
|
151
173
|
Returns:
|
|
152
|
-
(api_base, api_key) or (None, None) if no auth found
|
|
174
|
+
(api_base, api_key, api_key_refresh) or (None, None, None) if no auth found.
|
|
175
|
+
api_key_refresh is an async callback for mid-session token refresh (only set
|
|
176
|
+
when using wafer proxy via credentials file).
|
|
153
177
|
"""
|
|
154
178
|
from .auth import get_valid_token, load_credentials
|
|
155
179
|
from .global_config import get_api_url
|
|
156
180
|
|
|
181
|
+
if no_proxy:
|
|
182
|
+
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
|
183
|
+
if not api_key:
|
|
184
|
+
# Try auth.json stored key
|
|
185
|
+
from wafer_core.auth import get_api_key
|
|
186
|
+
|
|
187
|
+
api_key = get_api_key("anthropic") or ""
|
|
188
|
+
if api_key:
|
|
189
|
+
print("🔑 Using ANTHROPIC_API_KEY (--no-proxy)\n", file=sys.stderr)
|
|
190
|
+
return "https://api.anthropic.com", api_key, None
|
|
191
|
+
print(
|
|
192
|
+
"❌ --no-proxy requires ANTHROPIC_API_KEY env var or `wafer auth login anthropic`\n",
|
|
193
|
+
file=sys.stderr,
|
|
194
|
+
)
|
|
195
|
+
return None, None, None
|
|
196
|
+
|
|
157
197
|
# Check WAFER_AUTH_TOKEN env var first
|
|
158
198
|
wafer_token = os.environ.get("WAFER_AUTH_TOKEN", "")
|
|
159
199
|
token_source = "WAFER_AUTH_TOKEN" if wafer_token else None
|
|
160
200
|
|
|
161
201
|
# Try credentials file with automatic refresh
|
|
162
202
|
had_credentials = False
|
|
203
|
+
uses_credentials_file = False
|
|
163
204
|
if not wafer_token:
|
|
164
205
|
try:
|
|
165
206
|
creds = load_credentials()
|
|
@@ -169,12 +210,16 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
|
|
|
169
210
|
wafer_token = get_valid_token()
|
|
170
211
|
if wafer_token:
|
|
171
212
|
token_source = "~/.wafer/credentials.json"
|
|
213
|
+
uses_credentials_file = True
|
|
172
214
|
|
|
173
215
|
# If we have a valid wafer token, use it
|
|
174
216
|
if wafer_token:
|
|
175
217
|
api_url = get_api_url()
|
|
176
218
|
print(f"🔑 Using wafer proxy ({token_source})\n", file=sys.stderr)
|
|
177
|
-
|
|
219
|
+
# Only provide refresh callback when token came from credentials file
|
|
220
|
+
# (env var tokens are managed externally)
|
|
221
|
+
refresh = _make_wafer_token_refresh() if uses_credentials_file else None
|
|
222
|
+
return f"{api_url}/v1/anthropic", wafer_token, refresh
|
|
178
223
|
|
|
179
224
|
# Fall back to direct anthropic
|
|
180
225
|
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
|
@@ -186,9 +231,9 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
|
|
|
186
231
|
)
|
|
187
232
|
else:
|
|
188
233
|
print("🔑 Using ANTHROPIC_API_KEY\n", file=sys.stderr)
|
|
189
|
-
return "https://api.anthropic.com", api_key
|
|
234
|
+
return "https://api.anthropic.com", api_key, None
|
|
190
235
|
|
|
191
|
-
return None, None
|
|
236
|
+
return None, None, None
|
|
192
237
|
|
|
193
238
|
|
|
194
239
|
def _get_session_preview(session: object) -> str:
|
|
@@ -207,7 +252,7 @@ def _get_session_preview(session: object) -> str:
|
|
|
207
252
|
|
|
208
253
|
def _get_log_file_path() -> Path:
|
|
209
254
|
"""Get user-specific log file path, creating directory if needed.
|
|
210
|
-
|
|
255
|
+
|
|
211
256
|
Uses ~/.wafer/logs/ to avoid permission issues with shared /tmp.
|
|
212
257
|
"""
|
|
213
258
|
log_dir = Path.home() / ".wafer" / "logs"
|
|
@@ -255,6 +300,7 @@ def _build_endpoint(
|
|
|
255
300
|
model_override: str | None,
|
|
256
301
|
api_base: str,
|
|
257
302
|
api_key: str,
|
|
303
|
+
api_key_refresh: Callable[[], Awaitable[str | None]] | None = None,
|
|
258
304
|
) -> Endpoint:
|
|
259
305
|
"""Build an Endpoint from template config and auth."""
|
|
260
306
|
from wafer_core.rollouts import Endpoint
|
|
@@ -269,6 +315,7 @@ def _build_endpoint(
|
|
|
269
315
|
model=model_id,
|
|
270
316
|
api_base=api_base,
|
|
271
317
|
api_key=api_key,
|
|
318
|
+
api_key_refresh=api_key_refresh,
|
|
272
319
|
thinking=thinking_config,
|
|
273
320
|
max_tokens=tpl.max_tokens,
|
|
274
321
|
)
|
|
@@ -384,6 +431,7 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
384
431
|
get_session: str | None = None,
|
|
385
432
|
json_output: bool = False,
|
|
386
433
|
no_sandbox: bool = False,
|
|
434
|
+
no_proxy: bool = False,
|
|
387
435
|
) -> None:
|
|
388
436
|
"""Run wevin agent in-process via rollouts."""
|
|
389
437
|
from dataclasses import asdict
|
|
@@ -499,10 +547,10 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
499
547
|
_setup_logging()
|
|
500
548
|
|
|
501
549
|
# Auth
|
|
502
|
-
api_base, api_key = _get_wafer_auth()
|
|
550
|
+
api_base, api_key, api_key_refresh = _get_wafer_auth(no_proxy=no_proxy)
|
|
503
551
|
if not api_base or not api_key:
|
|
504
552
|
print("Error: No API credentials found", file=sys.stderr)
|
|
505
|
-
print(" Run 'wafer login' or set ANTHROPIC_API_KEY", file=sys.stderr)
|
|
553
|
+
print(" Run 'wafer auth login' or set ANTHROPIC_API_KEY", file=sys.stderr)
|
|
506
554
|
sys.exit(1)
|
|
507
555
|
|
|
508
556
|
assert api_base is not None
|
|
@@ -525,6 +573,17 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
525
573
|
tpl = _get_default_template()
|
|
526
574
|
base_system_prompt = tpl.system_prompt
|
|
527
575
|
|
|
576
|
+
# Compose CLI instructions from --help text for allowed wafer commands
|
|
577
|
+
# TODO: The eval path doesn't have the skills layer below. If include_skills
|
|
578
|
+
# is ever enabled for optimize-kernelbench, the eval would need it too for parity.
|
|
579
|
+
# See test_eval_cli_parity.py for coverage notes.
|
|
580
|
+
if tpl.bash_allowlist:
|
|
581
|
+
from wafer.cli_instructions import build_cli_instructions
|
|
582
|
+
|
|
583
|
+
cli_instructions = build_cli_instructions(tpl.bash_allowlist)
|
|
584
|
+
if cli_instructions:
|
|
585
|
+
base_system_prompt = base_system_prompt + "\n\n" + cli_instructions
|
|
586
|
+
|
|
528
587
|
# Append skill metadata if skills are enabled
|
|
529
588
|
if tpl.include_skills:
|
|
530
589
|
from wafer_core.rollouts.skills import discover_skills, format_skill_metadata_for_prompt
|
|
@@ -542,7 +601,7 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
542
601
|
resolved_single_turn = single_turn if single_turn is not None else tpl.single_turn
|
|
543
602
|
|
|
544
603
|
# Build endpoint and environment
|
|
545
|
-
endpoint = _build_endpoint(tpl, model, api_base, api_key)
|
|
604
|
+
endpoint = _build_endpoint(tpl, model, api_base, api_key, api_key_refresh)
|
|
546
605
|
environment = _build_environment(tpl, tools, corpus_path, no_sandbox)
|
|
547
606
|
|
|
548
607
|
# Session store
|
wafer/workspaces.py
CHANGED
|
@@ -39,13 +39,13 @@ def _friendly_error(status_code: int, response_text: str, workspace_id: str) ->
|
|
|
39
39
|
User-friendly error message with suggested next steps
|
|
40
40
|
"""
|
|
41
41
|
if status_code == 401:
|
|
42
|
-
return "Not authenticated. Run: wafer login"
|
|
42
|
+
return "Not authenticated. Run: wafer auth login"
|
|
43
43
|
|
|
44
44
|
if status_code == 402:
|
|
45
45
|
return (
|
|
46
46
|
"Insufficient credits.\n"
|
|
47
|
-
" Check usage: wafer billing\n"
|
|
48
|
-
" Add credits: wafer billing topup"
|
|
47
|
+
" Check usage: wafer config billing\n"
|
|
48
|
+
" Add credits: wafer config billing topup"
|
|
49
49
|
)
|
|
50
50
|
|
|
51
51
|
if status_code == 404:
|
|
@@ -107,7 +107,7 @@ def _list_workspaces_raw() -> list[dict]:
|
|
|
107
107
|
workspaces = response.json()
|
|
108
108
|
except httpx.HTTPStatusError as e:
|
|
109
109
|
if e.response.status_code == 401:
|
|
110
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
110
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
111
111
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|
|
112
112
|
except httpx.RequestError as e:
|
|
113
113
|
raise RuntimeError(f"Could not reach API: {e}") from e
|
|
@@ -188,7 +188,7 @@ def list_workspaces(json_output: bool = False) -> str:
|
|
|
188
188
|
workspaces = response.json()
|
|
189
189
|
except httpx.HTTPStatusError as e:
|
|
190
190
|
if e.response.status_code == 401:
|
|
191
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
191
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
192
192
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|
|
193
193
|
except httpx.RequestError as e:
|
|
194
194
|
raise RuntimeError(f"Could not reach API: {e}") from e
|
|
@@ -307,7 +307,7 @@ def create_workspace(
|
|
|
307
307
|
workspace = response.json()
|
|
308
308
|
except httpx.HTTPStatusError as e:
|
|
309
309
|
if e.response.status_code == 401:
|
|
310
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
310
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
311
311
|
if e.response.status_code == 400:
|
|
312
312
|
raise RuntimeError(f"Bad request: {e.response.text}") from e
|
|
313
313
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|
|
@@ -413,7 +413,7 @@ def delete_workspace(workspace_id: str, json_output: bool = False) -> str:
|
|
|
413
413
|
result = response.json()
|
|
414
414
|
except httpx.HTTPStatusError as e:
|
|
415
415
|
if e.response.status_code == 401:
|
|
416
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
416
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
417
417
|
if e.response.status_code == 404:
|
|
418
418
|
raise RuntimeError(f"Workspace not found: {workspace_id}") from e
|
|
419
419
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|
|
@@ -691,7 +691,7 @@ def get_workspace_raw(workspace_id: str) -> dict:
|
|
|
691
691
|
workspace = response.json()
|
|
692
692
|
except httpx.HTTPStatusError as e:
|
|
693
693
|
if e.response.status_code == 401:
|
|
694
|
-
raise RuntimeError("Not authenticated. Run: wafer login") from e
|
|
694
|
+
raise RuntimeError("Not authenticated. Run: wafer auth login") from e
|
|
695
695
|
if e.response.status_code == 404:
|
|
696
696
|
raise RuntimeError(f"Workspace not found: {workspace_id}") from e
|
|
697
697
|
raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
|