wafer-cli 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +42 -0
- wafer/billing.py +6 -6
- wafer/cli.py +513 -86
- wafer/cli_instructions.py +143 -0
- wafer/corpus.py +72 -6
- wafer/evaluate.py +13 -6
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +22 -6
- wafer/ssh_keys.py +6 -6
- wafer/templates/ask_docs.py +1 -1
- wafer/templates/optimize_kernel.py +1 -1
- wafer/templates/optimize_kernelbench.py +17 -62
- wafer/templates/trace_analyze.py +1 -1
- wafer/tests/test_eval_cli_parity.py +199 -0
- wafer/trace_compare.py +274 -0
- wafer/wevin_cli.py +68 -9
- wafer/workspaces.py +8 -8
- wafer_cli-0.2.26.dist-info/METADATA +107 -0
- wafer_cli-0.2.26.dist-info/RECORD +45 -0
- wafer_cli-0.2.24.dist-info/METADATA +0 -16
- wafer_cli-0.2.24.dist-info/RECORD +0 -41
- {wafer_cli-0.2.24.dist-info → wafer_cli-0.2.26.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.24.dist-info → wafer_cli-0.2.26.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.24.dist-info → wafer_cli-0.2.26.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
"""Test that eval and CLI agent configs are in sync.
|
|
2
|
+
|
|
3
|
+
Run: python -m pytest apps/wafer-cli/wafer/tests/test_eval_cli_parity.py -v
|
|
4
|
+
or: python apps/wafer-cli/wafer/tests/test_eval_cli_parity.py (standalone)
|
|
5
|
+
or: python apps/wafer-cli/wafer/tests/test_eval_cli_parity.py --dump (side-by-side)
|
|
6
|
+
|
|
7
|
+
Checks:
|
|
8
|
+
- Same bash allowlist (identity check — must be the same object)
|
|
9
|
+
- Same enabled tools
|
|
10
|
+
- Same system prompt text (modulo runtime variables like pool name)
|
|
11
|
+
- CLI instructions generated from same allowlist
|
|
12
|
+
|
|
13
|
+
Coverage notes:
|
|
14
|
+
These tests verify config-level parity. After composition, both paths
|
|
15
|
+
flow through the same codepath (Actor → run_agent_step → rollout →
|
|
16
|
+
anthropic.py provider) with no further system prompt modifications.
|
|
17
|
+
|
|
18
|
+
Known infra-level differences NOT tested here:
|
|
19
|
+
- Claude Code identity prefix: anthropic.py prepends "You are Claude Code..."
|
|
20
|
+
when using OAuth/Claude Code API keys. Eval typically uses raw API keys,
|
|
21
|
+
so eval agents may not get this prefix. This is an auth concern, not a
|
|
22
|
+
prompt content concern.
|
|
23
|
+
- Skills layer: CLI path (wevin_cli.py) can append skill metadata if
|
|
24
|
+
template.include_skills is True. Currently False for optimize-kernelbench,
|
|
25
|
+
so this is a no-op. Eval path doesn't have this layer at all.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
import difflib
|
|
31
|
+
import string
|
|
32
|
+
import sys
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _normalize_prompt(prompt: str) -> str:
|
|
36
|
+
"""Normalize runtime variables so eval and CLI prompts are comparable."""
|
|
37
|
+
# Replace pool/target specifics with placeholders
|
|
38
|
+
prompt = prompt.replace("--pool mi300x-pool", "--pool POOL")
|
|
39
|
+
prompt = prompt.replace("--pool kernelbench-pool", "--pool POOL")
|
|
40
|
+
prompt = prompt.replace("--target mi300x", "--target TARGET")
|
|
41
|
+
return prompt
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_bash_allowlist_parity() -> None:
    """Both the CLI template and the eval config must reuse the shared allowlist.

    The check is by identity (``is``), not equality: a local copy with the
    same contents still fails.
    """
    from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST as shared_allowlist
    from wafer.templates.optimize_kernelbench import template

    remedy = "Import it from wafer.agent_defaults instead of defining a local copy."

    # CLI side: the template must point at the very same object.
    assert template.bash_allowlist is shared_allowlist, (
        "Template bash_allowlist is not the shared KERNELBENCH_BASH_ALLOWLIST. " + remedy
    )

    # Eval side: imported only now, so the template check above runs first.
    from optimize_kernelbench_eval.base_config import BASH_ALLOWLIST as eval_allowlist

    assert eval_allowlist is shared_allowlist, (
        "Eval BASH_ALLOWLIST is not the shared KERNELBENCH_BASH_ALLOWLIST. " + remedy
    )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_enabled_tools_parity() -> None:
    """The CLI template's tool list must equal the shared ``ENABLED_TOOLS``."""
    from wafer.agent_defaults import ENABLED_TOOLS as shared_tools
    from wafer.templates.optimize_kernelbench import template

    template_tools = template.tools
    assert template_tools == shared_tools, (
        f"Template tools {template_tools} != shared ENABLED_TOOLS {shared_tools}"
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def test_system_prompt_parity() -> None:
    """Eval and CLI system prompts must match once runtime variables are normalized.

    The eval prompt is rendered with ``str.format`` and the CLI template with
    ``string.Template.safe_substitute``; both get the same placeholder
    ``target_flag`` before being compared.
    """
    from optimize_kernelbench_eval.base_config import SYSTEM_PROMPT as EVAL_PROMPT

    from wafer.templates.optimize_kernelbench import template

    placeholder_flag = "--pool POOL"

    # Eval side, rendered with HIP defaults (most common).
    eval_rendered = EVAL_PROMPT.format(
        backend="HIP",
        backend_lower="hip",
        target_flag=placeholder_flag,
        reference_path="<reference_file>",
    )

    # CLI side, rendered from the template defaults with the same target flag.
    cli_params = dict(template.defaults)
    cli_params.update(target_flag=placeholder_flag)
    cli_rendered = string.Template(template.system_prompt).safe_substitute(**cli_params)

    eval_text = _normalize_prompt(eval_rendered)
    cli_text = _normalize_prompt(cli_rendered)

    if eval_text == cli_text:
        return

    # Prompts diverged: fail with a unified diff showing where.
    diff_lines = difflib.unified_diff(
        eval_text.splitlines(),
        cli_text.splitlines(),
        fromfile="eval (base_config.py)",
        tofile="cli (optimize_kernelbench.py template)",
        lineterm="",
        n=2,
    )
    diff = "\n".join(diff_lines)
    raise AssertionError(
        f"System prompts differ between eval and CLI template:\n\n{diff}\n\n"
        "Both should define the same task instructions. "
        "Edit one to match the other."
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def test_cli_instructions_identical() -> None:
    """CLI instructions built from the shared and the template allowlist must match.

    Also checks that the generated instructions are not empty.
    """
    from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST
    from wafer.cli_instructions import build_cli_instructions
    from wafer.templates.optimize_kernelbench import template

    from_shared = build_cli_instructions(KERNELBENCH_BASH_ALLOWLIST)
    from_template = build_cli_instructions(template.bash_allowlist)

    mismatch_msg = "CLI instructions differ — this means the bash allowlists diverged."
    assert from_shared == from_template, mismatch_msg
    assert len(from_shared) > 0, "CLI instructions should not be empty"
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _dump_full_prompts() -> None:
    """Print the composed eval and CLI system prompts, then their diff.

    Standalone helper for manual comparison (run with ``--dump``). Each prompt
    is the task prompt followed by the generated CLI instructions. The diff is
    computed after pool/target normalization.
    """
    from optimize_kernelbench_eval.base_config import SYSTEM_PROMPT as EVAL_PROMPT

    from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST
    from wafer.cli_instructions import build_cli_instructions
    from wafer.templates.optimize_kernelbench import template

    def _banner(title: str) -> None:
        # Section header: a 60-character rule above and below the title.
        rule = "=" * 60
        print(rule)
        print(title)
        print(rule)

    shared_instructions = build_cli_instructions(KERNELBENCH_BASH_ALLOWLIST)

    # Eval prompt: str.format rendering plus instructions from the shared allowlist.
    eval_sys = EVAL_PROMPT.format(
        backend="HIP",
        backend_lower="hip",
        target_flag="--pool mi300x-pool",
        reference_path="<reference_file>",
    )
    eval_sys = eval_sys + "\n\n" + shared_instructions

    # CLI prompt: template defaults plus instructions from the template's allowlist.
    cli_params = dict(template.defaults)
    cli_sys = string.Template(template.system_prompt).safe_substitute(**cli_params)
    cli_sys = cli_sys + "\n\n" + build_cli_instructions(template.bash_allowlist)

    for title, text in (("EVAL SYSTEM PROMPT", eval_sys), ("CLI SYSTEM PROMPT", cli_sys)):
        _banner(title)
        print(text)
        print()

    # Diff of the normalized prompts, one line of context.
    delta = list(
        difflib.unified_diff(
            _normalize_prompt(eval_sys).splitlines(),
            _normalize_prompt(cli_sys).splitlines(),
            fromfile="eval",
            tofile="cli",
            lineterm="",
            n=1,
        )
    )
    if not delta:
        print("IDENTICAL (after normalizing pool names)")
        return

    _banner("DIFFERENCES (after normalizing pool names)")
    for entry in delta:
        print(entry)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
if __name__ == "__main__":
    # Standalone entry point: "--dump" prints both prompts side by side;
    # otherwise every parity check runs in order, stopping at the first failure.
    if "--dump" in sys.argv:
        _dump_full_prompts()
    else:
        parity_checks = (
            test_bash_allowlist_parity,
            test_enabled_tools_parity,
            test_system_prompt_parity,
            test_cli_instructions_identical,
        )
        for check in parity_checks:
            check()
            # Report the check name without its "test_" prefix.
            print("PASS: " + check.__name__[len("test_"):])
        print("\nAll parity checks passed.")
|
wafer/trace_compare.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""CLI wrapper for trace comparison commands.
|
|
2
|
+
|
|
3
|
+
This module provides the CLI interface for the `wafer compare` commands.
|
|
4
|
+
All core logic is in wafer_core.lib.trace_compare.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import typer
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import sys
|
|
15
|
+
|
|
16
|
+
from wafer_core.lib.trace_compare import (
|
|
17
|
+
analyze_trace_pair,
|
|
18
|
+
format_csv,
|
|
19
|
+
format_json,
|
|
20
|
+
format_text,
|
|
21
|
+
ArchitectureType,
|
|
22
|
+
detect_architecture,
|
|
23
|
+
)
|
|
24
|
+
from wafer_core.lib.trace_compare.loader import StreamingMetadata
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def compare_traces(
    trace1: Path,
    trace2: Path,
    output: Path | None = None,
    output_format: str = "text",
    phase: str = "all",
    show_layers: bool = False,
    show_all: bool = False,
    show_stack_traces: bool = False,
    recommendations: bool = False,
) -> None:
    """Compare two GPU traces and generate a performance report.

    Status, progress and warning messages are written to stderr so that stdout
    carries only the report itself. For the ``json`` format, stdout carries the
    NDJSON progress/metadata events followed by the report.

    Args:
        trace1: Path to first trace file (AMD or NVIDIA)
        trace2: Path to second trace file (AMD or NVIDIA)
        output: Optional output file path (default: stdout)
        output_format: Output format ('text', 'text-layers', 'csv', 'csv-layers', or 'json')
        phase: Filter by phase ('all', 'prefill', or 'decode')
        show_layers: Show layer-wise performance breakdown (text format only)
        show_all: Show all items without truncation (applies to layers, operations, kernels)
        show_stack_traces: Show Python stack traces for operations
        recommendations: Accepted for interface compatibility; not used by this function.

    Raises:
        typer.Exit: With exit code 1 when a trace file does not exist, the
            output format is unknown, or the analysis fails.
    """
    # Validate files exist
    for trace in (trace1, trace2):
        if not trace.exists():
            typer.secho(f"❌ File not found: {trace}", fg=typer.colors.RED, err=True)
            raise typer.Exit(1)

    # Reject an unknown format before running the analysis, not after it.
    supported_formats = ("text", "text-layers", "csv", "csv-layers", "json")
    if output_format not in supported_formats:
        typer.secho(f"❌ Unknown format: {output_format}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1)

    json_mode = output_format == "json"

    # Progress callback: NDJSON on stdout for JSON format, human-readable on stderr otherwise
    def progress_callback(stage: str, fraction: float) -> None:
        if json_mode:
            progress_msg = json.dumps({"type": "progress", "stage": stage, "fraction": fraction})
            print(progress_msg, file=sys.stdout, flush=True)
        else:
            percent = int(fraction * 100)
            typer.echo(f"📊 {stage}: {percent}%", err=True)

    # Metadata callback: emits early GPU info (NDJSON on stdout for JSON format)
    def metadata_callback(meta1: StreamingMetadata, meta2: StreamingMetadata) -> None:
        if json_mode:
            metadata_msg = json.dumps({
                "type": "metadata",
                "trace1": {
                    "platform": meta1.platform,
                    "gpu": meta1.gpu_name,
                    "file_size_mb": round(meta1.file_size_mb, 1),
                },
                "trace2": {
                    "platform": meta2.platform,
                    "gpu": meta2.gpu_name,
                    "file_size_mb": round(meta2.file_size_mb, 1),
                },
            })
            print(metadata_msg, file=sys.stdout, flush=True)
        else:
            typer.echo(f"📊 Trace 1: {meta1.platform} - {meta1.gpu_name} ({meta1.file_size_mb:.1f}MB)", err=True)
            typer.echo(f"📊 Trace 2: {meta2.platform} - {meta2.gpu_name} ({meta2.file_size_mb:.1f}MB)", err=True)

    # Analyze traces using unified API
    if not json_mode:
        typer.echo("📊 Loading traces...", err=True)

    try:
        # Stacks are always requested; show_stack_traces only affects text rendering below.
        result_obj = analyze_trace_pair(
            trace1,
            trace2,
            phase=phase,
            include_stacks=True,
            on_progress=progress_callback,
            on_metadata=metadata_callback,
        )

        results: dict[str, Any] = {
            "metadata": result_obj.metadata,
            "operations": result_obj.operations,
            "layers": result_obj.layers,
            "warnings": [{"code": w.code, "severity": w.severity, "message": w.message, "suggestion": w.suggestion} for w in result_obj.warnings],
            "architecture": result_obj.architecture.value,
            "layer_alignments": result_obj.layer_alignments,
            "fusion_analysis": result_obj.fusion_analysis,
            "same_kernel_analysis": result_obj.same_kernel_analysis,
        }
    except ValueError as e:
        typer.secho(f"❌ {e}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1) from e
    except Exception as e:
        typer.secho(f"❌ Error analyzing traces: {e}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1) from e

    if not json_mode:
        meta = results["metadata"]
        # NOTE(review): the summary labels one trace AMD and the other NVIDIA
        # based only on trace1's platform — confirm same-platform pairs are not expected.
        if meta['trace1_platform'] == 'AMD':
            amd_gpu, nvidia_gpu = meta['trace1_gpu'], meta['trace2_gpu']
        else:
            amd_gpu, nvidia_gpu = meta['trace2_gpu'], meta['trace1_gpu']
        typer.echo(f"✅ Loaded: AMD ({amd_gpu}) vs NVIDIA ({nvidia_gpu})", err=True)

        # Display warnings (the JSON report carries them in its "warnings" field)
        warnings = results.get("warnings", [])
        if warnings:
            typer.echo(err=True)
            for warning in warnings:
                severity = warning["severity"]
                if severity == "error":
                    icon, color = "❌", typer.colors.RED
                elif severity == "warning":
                    icon, color = "⚠️", typer.colors.YELLOW
                else:
                    icon, color = "ℹ️", typer.colors.BLUE
                typer.secho(f"{icon} {warning['message']}", fg=color, err=True)
                if warning.get("suggestion"):
                    typer.secho(f"   Suggestion: {warning['suggestion']}", fg=typer.colors.BLUE, err=True)
            typer.echo(err=True)

    # Generate output based on format
    if output_format == "text":
        output_str = format_text(results, show_layers=show_layers, show_all=show_all, show_stack_traces=show_stack_traces)
    elif output_format == "text-layers":
        output_str = format_text(results, show_layers=True, show_all=show_all, show_stack_traces=show_stack_traces)
    elif output_format == "csv":
        output_str = format_csv(results, report_type="operations")
    elif output_format == "csv-layers":
        output_str = format_csv(results, report_type="layers")
    else:
        # "json" — the only supported format left after the validation above.
        output_str = format_json(results)

    # Write output
    if output is not None:
        # Explicit UTF-8 so the result does not depend on the platform's default encoding.
        output.write_text(output_str, encoding="utf-8")
        typer.secho(f"✅ Report saved to {output}", fg=typer.colors.GREEN, err=True)
    else:
        typer.echo(output_str)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def compare_align(
    trace1: Path,
    trace2: Path,
    output: Path | None = None,
    output_format: str = "json",
    phase: str = "all",
    layer: int | None = None,
) -> None:
    """Align kernels at layer level for exact kernel-to-kernel comparison.

    Progress and metadata events are emitted as NDJSON on stdout, followed by
    the JSON report unless ``output`` is given. Status messages go to stderr
    so they never mix into that stream.

    Args:
        trace1: Path to first trace file (AMD or NVIDIA)
        trace2: Path to second trace file (AMD or NVIDIA)
        output: Optional output file path (default: stdout)
        output_format: Output format ('json' only for now)
        phase: Filter by phase ('all', 'prefill', or 'decode')
        layer: Focus on specific layer number (optional)

    Raises:
        typer.Exit: With exit code 1 when a trace file does not exist, the
            output format is not supported, or the analysis fails.
    """
    # Validate files exist
    for trace in (trace1, trace2):
        if not trace.exists():
            typer.secho(f"❌ File not found: {trace}", fg=typer.colors.RED, err=True)
            raise typer.Exit(1)

    # Reject unsupported formats before running the analysis, not after it.
    # The non-JSON branches below only become reachable once more formats are added here.
    supported_formats = ("json",)
    if output_format not in supported_formats:
        typer.secho(f"❌ Format {output_format} not yet supported for align command. Use 'json'.", fg=typer.colors.RED, err=True)
        raise typer.Exit(1)

    json_mode = output_format == "json"

    # Progress callback: NDJSON on stdout for JSON format, human-readable on stderr otherwise
    def progress_callback(stage: str, fraction: float) -> None:
        if json_mode:
            progress_msg = json.dumps({"type": "progress", "stage": stage, "fraction": fraction})
            print(progress_msg, file=sys.stdout, flush=True)
        else:
            percent = int(fraction * 100)
            typer.echo(f"📊 {stage}: {percent}%", err=True)

    # Metadata callback: emits early GPU info (NDJSON on stdout for JSON format)
    def metadata_callback(meta1: StreamingMetadata, meta2: StreamingMetadata) -> None:
        if json_mode:
            metadata_msg = json.dumps({
                "type": "metadata",
                "trace1": {
                    "platform": meta1.platform,
                    "gpu": meta1.gpu_name,
                    "file_size_mb": round(meta1.file_size_mb, 1),
                },
                "trace2": {
                    "platform": meta2.platform,
                    "gpu": meta2.gpu_name,
                    "file_size_mb": round(meta2.file_size_mb, 1),
                },
            })
            print(metadata_msg, file=sys.stdout, flush=True)
        else:
            typer.echo(f"📊 Trace 1: {meta1.platform} - {meta1.gpu_name} ({meta1.file_size_mb:.1f}MB)", err=True)
            typer.echo(f"📊 Trace 2: {meta2.platform} - {meta2.gpu_name} ({meta2.file_size_mb:.1f}MB)", err=True)

    # Analyze traces using unified API
    if not json_mode:
        typer.echo("📊 Loading traces...", err=True)

    try:
        result_obj = analyze_trace_pair(
            trace1,
            trace2,
            phase=phase,
            include_stacks=True,
            on_progress=progress_callback,
            on_metadata=metadata_callback,
        )

        # Alignment-related fields fall back to empty containers when the analysis left them unset.
        results: dict[str, Any] = {
            "metadata": result_obj.metadata,
            "layer_alignments": result_obj.layer_alignments or [],
            "fusion_analysis": result_obj.fusion_analysis or {},
            "same_kernel_analysis": result_obj.same_kernel_analysis or {},
            "operations": result_obj.operations,
            "layers": result_obj.layers,
            "warnings": [{"code": w.code, "severity": w.severity, "message": w.message, "suggestion": w.suggestion} for w in result_obj.warnings],
            "architecture": result_obj.architecture.value,
        }

        # Keep only the alignments of the requested layer, if one was given.
        if layer is not None:
            results["layer_alignments"] = [
                la for la in results["layer_alignments"] if la.get("layer") == layer
            ]
    except ValueError as e:
        typer.secho(f"❌ {e}", fg=typer.colors.RED, err=True)
        raise typer.Exit(1) from e
    except Exception as e:
        typer.secho(f"❌ Error analyzing traces: {e}", fg=typer.colors.RED, err=True)
        import traceback
        traceback.print_exc()
        raise typer.Exit(1) from e

    if not json_mode:
        meta = results["metadata"]
        # NOTE(review): compare_traces reads 'trace1_gpu'/'trace2_gpu' from this
        # metadata; confirm 'amd_gpu'/'nvidia_gpu' exist, else this prints "Unknown".
        typer.echo(f"✅ Loaded: {meta.get('amd_gpu', 'Unknown')} vs {meta.get('nvidia_gpu', 'Unknown')}", err=True)
        typer.echo(f"✅ Found {len(results['layer_alignments'])} layers", err=True)
        typer.echo(err=True)

    # Only JSON is supported (validated above).
    output_str = format_json(results)

    if output is not None:
        # Explicit UTF-8 so the result does not depend on the platform's default encoding.
        output.write_text(output_str, encoding="utf-8")
        typer.secho(f"✅ Report saved to {output}", fg=typer.colors.GREEN, err=True)
    else:
        typer.echo(output_str)
|
wafer/wevin_cli.py
CHANGED
|
@@ -15,6 +15,8 @@ from pathlib import Path
|
|
|
15
15
|
from typing import TYPE_CHECKING
|
|
16
16
|
|
|
17
17
|
if TYPE_CHECKING:
|
|
18
|
+
from collections.abc import Awaitable, Callable
|
|
19
|
+
|
|
18
20
|
from wafer_core.rollouts import Endpoint, Environment
|
|
19
21
|
from wafer_core.rollouts.dtypes import StreamEvent, ToolCall
|
|
20
22
|
from wafer_core.rollouts.templates import TemplateConfig
|
|
@@ -145,21 +147,60 @@ class StreamingChunkFrontend:
|
|
|
145
147
|
pass
|
|
146
148
|
|
|
147
149
|
|
|
148
|
-
def
|
|
150
|
+
def _make_wafer_token_refresh() -> Callable[[], Awaitable[str | None]]:
    """Build the async callback used to refresh the wafer proxy token.

    The returned coroutine function loads the stored credentials, exchanges
    the refresh token for a new access/refresh pair, saves the pair, and
    returns the new access token. It returns ``None`` when there are no stored
    credentials, no refresh token, or the refresh/save step raises.
    """
    from .auth import load_credentials, refresh_access_token, save_credentials

    async def _refresh() -> str | None:
        stored = load_credentials()
        if stored and stored.refresh_token:
            try:
                access_token, rotated_refresh = refresh_access_token(stored.refresh_token)
                save_credentials(access_token, rotated_refresh, stored.email)
            except Exception:
                # Best effort: any failure is reported to the caller as "no new token".
                return None
            return access_token
        return None

    return _refresh
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _get_wafer_auth(
|
|
169
|
+
*, no_proxy: bool = False
|
|
170
|
+
) -> tuple[str | None, str | None, Callable[[], Awaitable[str | None]] | None]:
|
|
149
171
|
"""Get wafer auth credentials with fallback chain.
|
|
150
172
|
|
|
151
173
|
Returns:
|
|
152
|
-
(api_base, api_key) or (None, None) if no auth found
|
|
174
|
+
(api_base, api_key, api_key_refresh) or (None, None, None) if no auth found.
|
|
175
|
+
api_key_refresh is an async callback for mid-session token refresh (only set
|
|
176
|
+
when using wafer proxy via credentials file).
|
|
153
177
|
"""
|
|
154
178
|
from .auth import get_valid_token, load_credentials
|
|
155
179
|
from .global_config import get_api_url
|
|
156
180
|
|
|
181
|
+
if no_proxy:
|
|
182
|
+
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
|
183
|
+
if not api_key:
|
|
184
|
+
# Try auth.json stored key
|
|
185
|
+
from wafer_core.auth import get_api_key
|
|
186
|
+
|
|
187
|
+
api_key = get_api_key("anthropic") or ""
|
|
188
|
+
if api_key:
|
|
189
|
+
print("🔑 Using ANTHROPIC_API_KEY (--no-proxy)\n", file=sys.stderr)
|
|
190
|
+
return "https://api.anthropic.com", api_key, None
|
|
191
|
+
print(
|
|
192
|
+
"❌ --no-proxy requires ANTHROPIC_API_KEY env var or `wafer auth login anthropic`\n",
|
|
193
|
+
file=sys.stderr,
|
|
194
|
+
)
|
|
195
|
+
return None, None, None
|
|
196
|
+
|
|
157
197
|
# Check WAFER_AUTH_TOKEN env var first
|
|
158
198
|
wafer_token = os.environ.get("WAFER_AUTH_TOKEN", "")
|
|
159
199
|
token_source = "WAFER_AUTH_TOKEN" if wafer_token else None
|
|
160
200
|
|
|
161
201
|
# Try credentials file with automatic refresh
|
|
162
202
|
had_credentials = False
|
|
203
|
+
uses_credentials_file = False
|
|
163
204
|
if not wafer_token:
|
|
164
205
|
try:
|
|
165
206
|
creds = load_credentials()
|
|
@@ -169,12 +210,16 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
|
|
|
169
210
|
wafer_token = get_valid_token()
|
|
170
211
|
if wafer_token:
|
|
171
212
|
token_source = "~/.wafer/credentials.json"
|
|
213
|
+
uses_credentials_file = True
|
|
172
214
|
|
|
173
215
|
# If we have a valid wafer token, use it
|
|
174
216
|
if wafer_token:
|
|
175
217
|
api_url = get_api_url()
|
|
176
218
|
print(f"🔑 Using wafer proxy ({token_source})\n", file=sys.stderr)
|
|
177
|
-
|
|
219
|
+
# Only provide refresh callback when token came from credentials file
|
|
220
|
+
# (env var tokens are managed externally)
|
|
221
|
+
refresh = _make_wafer_token_refresh() if uses_credentials_file else None
|
|
222
|
+
return f"{api_url}/v1/anthropic", wafer_token, refresh
|
|
178
223
|
|
|
179
224
|
# Fall back to direct anthropic
|
|
180
225
|
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
|
@@ -186,9 +231,9 @@ def _get_wafer_auth() -> tuple[str | None, str | None]:
|
|
|
186
231
|
)
|
|
187
232
|
else:
|
|
188
233
|
print("🔑 Using ANTHROPIC_API_KEY\n", file=sys.stderr)
|
|
189
|
-
return "https://api.anthropic.com", api_key
|
|
234
|
+
return "https://api.anthropic.com", api_key, None
|
|
190
235
|
|
|
191
|
-
return None, None
|
|
236
|
+
return None, None, None
|
|
192
237
|
|
|
193
238
|
|
|
194
239
|
def _get_session_preview(session: object) -> str:
|
|
@@ -207,7 +252,7 @@ def _get_session_preview(session: object) -> str:
|
|
|
207
252
|
|
|
208
253
|
def _get_log_file_path() -> Path:
|
|
209
254
|
"""Get user-specific log file path, creating directory if needed.
|
|
210
|
-
|
|
255
|
+
|
|
211
256
|
Uses ~/.wafer/logs/ to avoid permission issues with shared /tmp.
|
|
212
257
|
"""
|
|
213
258
|
log_dir = Path.home() / ".wafer" / "logs"
|
|
@@ -255,6 +300,7 @@ def _build_endpoint(
|
|
|
255
300
|
model_override: str | None,
|
|
256
301
|
api_base: str,
|
|
257
302
|
api_key: str,
|
|
303
|
+
api_key_refresh: Callable[[], Awaitable[str | None]] | None = None,
|
|
258
304
|
) -> Endpoint:
|
|
259
305
|
"""Build an Endpoint from template config and auth."""
|
|
260
306
|
from wafer_core.rollouts import Endpoint
|
|
@@ -269,6 +315,7 @@ def _build_endpoint(
|
|
|
269
315
|
model=model_id,
|
|
270
316
|
api_base=api_base,
|
|
271
317
|
api_key=api_key,
|
|
318
|
+
api_key_refresh=api_key_refresh,
|
|
272
319
|
thinking=thinking_config,
|
|
273
320
|
max_tokens=tpl.max_tokens,
|
|
274
321
|
)
|
|
@@ -384,6 +431,7 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
384
431
|
get_session: str | None = None,
|
|
385
432
|
json_output: bool = False,
|
|
386
433
|
no_sandbox: bool = False,
|
|
434
|
+
no_proxy: bool = False,
|
|
387
435
|
) -> None:
|
|
388
436
|
"""Run wevin agent in-process via rollouts."""
|
|
389
437
|
from dataclasses import asdict
|
|
@@ -499,10 +547,10 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
499
547
|
_setup_logging()
|
|
500
548
|
|
|
501
549
|
# Auth
|
|
502
|
-
api_base, api_key = _get_wafer_auth()
|
|
550
|
+
api_base, api_key, api_key_refresh = _get_wafer_auth(no_proxy=no_proxy)
|
|
503
551
|
if not api_base or not api_key:
|
|
504
552
|
print("Error: No API credentials found", file=sys.stderr)
|
|
505
|
-
print(" Run 'wafer login' or set ANTHROPIC_API_KEY", file=sys.stderr)
|
|
553
|
+
print(" Run 'wafer auth login' or set ANTHROPIC_API_KEY", file=sys.stderr)
|
|
506
554
|
sys.exit(1)
|
|
507
555
|
|
|
508
556
|
assert api_base is not None
|
|
@@ -525,6 +573,17 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
525
573
|
tpl = _get_default_template()
|
|
526
574
|
base_system_prompt = tpl.system_prompt
|
|
527
575
|
|
|
576
|
+
# Compose CLI instructions from --help text for allowed wafer commands
|
|
577
|
+
# TODO: The eval path doesn't have the skills layer below. If include_skills
|
|
578
|
+
# is ever enabled for optimize-kernelbench, the eval would need it too for parity.
|
|
579
|
+
# See test_eval_cli_parity.py for coverage notes.
|
|
580
|
+
if tpl.bash_allowlist:
|
|
581
|
+
from wafer.cli_instructions import build_cli_instructions
|
|
582
|
+
|
|
583
|
+
cli_instructions = build_cli_instructions(tpl.bash_allowlist)
|
|
584
|
+
if cli_instructions:
|
|
585
|
+
base_system_prompt = base_system_prompt + "\n\n" + cli_instructions
|
|
586
|
+
|
|
528
587
|
# Append skill metadata if skills are enabled
|
|
529
588
|
if tpl.include_skills:
|
|
530
589
|
from wafer_core.rollouts.skills import discover_skills, format_skill_metadata_for_prompt
|
|
@@ -542,7 +601,7 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
542
601
|
resolved_single_turn = single_turn if single_turn is not None else tpl.single_turn
|
|
543
602
|
|
|
544
603
|
# Build endpoint and environment
|
|
545
|
-
endpoint = _build_endpoint(tpl, model, api_base, api_key)
|
|
604
|
+
endpoint = _build_endpoint(tpl, model, api_base, api_key, api_key_refresh)
|
|
546
605
|
environment = _build_environment(tpl, tools, corpus_path, no_sandbox)
|
|
547
606
|
|
|
548
607
|
# Session store
|