wafer-cli 0.2.59__tar.gz → 0.2.60__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/PKG-INFO +1 -1
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/pyproject.toml +1 -1
- wafer_cli-0.2.60/tests/test_direct_streaming.py +352 -0
- wafer_cli-0.2.60/wafer/templates/ask_docs.py +22 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/wevin_cli.py +160 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer_cli.egg-info/PKG-INFO +1 -1
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer_cli.egg-info/SOURCES.txt +1 -0
- wafer_cli-0.2.59/wafer/templates/ask_docs.py +0 -32
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/README.md +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/setup.cfg +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_agent_template_discovery.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_analytics.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_auth.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_billing.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_cli_coverage.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_cli_parity_integration.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_config_show.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_corpus_lockdown.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_deps.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_distributed_traces_cli.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_docker_progress.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_evaluate_ux.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_file_operations_integration.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_first_run.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_inference.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_json_output.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_kernel_scope_cli.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_ncu_run.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_ncu_run_e2e.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_ncu_run_local_e2e.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_nsys_analyze.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_nsys_profile.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_output.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_rocprof_compute_integration.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_skill_commands.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_ssh_integration.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_status.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_targets_ops.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_token_waste.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_ux_improvements.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/tests/test_wevin_cli.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/GUIDE.md +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/__init__.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/agent_defaults.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/analytics.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/api_client.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/auth.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/autotuner.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/baseline.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/billing.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/cli.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/cli_instructions.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/deps.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/distributed_traces.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/evaluate.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/global_config.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/gpu_run.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/inference.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/kernel_scope.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/ncu_analyze.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/ncu_run.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/nsys_analyze.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/nsys_profile.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/output.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/problems.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/rocprof_compute.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/rocprof_sdk.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/rocprof_systems.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/skills/packed-ops-guide/SKILL.md +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/SKILL.md +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/commands.md +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/evaluate.md +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/pitfalls.md +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/profiling.md +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/workspaces.md +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/specs_cli.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/ssh_keys.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/target_lock.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/targets.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/targets_cli.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/targets_ops.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/templates/__init__.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/templates/aiter_optimize.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/templates/audit.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/templates/optimize_flashinfer.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/templates/optimize_kernel.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/templates/optimize_kernelbench.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/templates/optimize_vllm.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/templates/trace_analyze.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/tests/test_eval_cli_parity.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/trace_compare.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/tracelens.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer/workspaces.py +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer_cli.egg-info/dependency_links.txt +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer_cli.egg-info/entry_points.txt +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer_cli.egg-info/requires.txt +0 -0
- {wafer_cli-0.2.59 → wafer_cli-0.2.60}/wafer_cli.egg-info/top_level.txt +0 -0
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""Tests for direct endpoint streaming (_stream_direct_endpoint).
|
|
2
|
+
|
|
3
|
+
Tests SSE parsing, output formatting for tool_call/tool_result/text/error
|
|
4
|
+
events, and JSON mode re-emission.
|
|
5
|
+
|
|
6
|
+
Run with:
|
|
7
|
+
PYTHONPATH=apps/wafer-cli uv run pytest apps/wafer-cli/tests/test_direct_streaming.py -v
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
13
|
+
|
|
14
|
+
import pytest
|
|
15
|
+
import trio
|
|
16
|
+
|
|
17
|
+
from wafer.wevin_cli import (
|
|
18
|
+
_format_tool_call_summary,
|
|
19
|
+
_format_tool_result_summary,
|
|
20
|
+
_stream_direct_endpoint,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ---------------------------------------------------------------------------
|
|
25
|
+
# Unit tests for formatting helpers
|
|
26
|
+
# ---------------------------------------------------------------------------
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TestFormatToolCallSummary:
    """Status-line rendering for server-side tool calls."""

    def test_grep(self) -> None:
        expected = 'searching: grep("shared memory")...'
        assert _format_tool_call_summary("grep", {"pattern": "shared memory"}) == expected

    def test_read_file(self) -> None:
        summary = _format_tool_call_summary("read_file", {"path": "./guide/memory.md"})
        assert summary == "reading: ./guide/memory.md..."

    def test_list_files(self) -> None:
        summary = _format_tool_call_summary("list_files", {"pattern": "*.md"})
        assert summary == 'listing: find("*.md")...'

    def test_list_files_default(self) -> None:
        # No pattern supplied -> the helper falls back to "*".
        assert _format_tool_call_summary("list_files", {}) == 'listing: find("*")...'

    def test_unknown_tool(self) -> None:
        summary = _format_tool_call_summary("some_tool", {"x": 1})
        assert "some_tool" in summary
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class TestFormatToolResultSummary:
    """Status-line rendering for server-side tool results."""

    def test_no_matches(self) -> None:
        assert _format_tool_result_summary("grep", "No matches found.") == "no results"

    def test_no_files(self) -> None:
        summary = _format_tool_result_summary("list_files", "No files found matching pattern.")
        assert summary == "no results"

    def test_error(self) -> None:
        summary = _format_tool_result_summary("read_file", "Error: file not found")
        assert summary.startswith("Error:")

    def test_grep_matches(self) -> None:
        grep_output = "\n".join(["line1", "line2", "line3"]) + "\n"
        assert "3 matches" in _format_tool_result_summary("grep", grep_output)

    def test_read_file_lines(self) -> None:
        file_body = "".join(f"{ch}\n" for ch in "abcde")
        assert "5 lines" in _format_tool_result_summary("read_file", file_body)

    def test_list_files_count(self) -> None:
        listing = "./a.md\n./b.md\n"
        assert "2 files" in _format_tool_result_summary("list_files", listing)

    def test_unknown_tool_char_count(self) -> None:
        # Tools without a dedicated phrasing report raw character counts.
        assert "6 chars" in _format_tool_result_summary("unknown", "abcdef")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# ---------------------------------------------------------------------------
|
|
85
|
+
# SSE streaming integration tests (mocked httpx)
|
|
86
|
+
# ---------------------------------------------------------------------------
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _make_sse_lines(events: list[dict]) -> list[str]:
|
|
90
|
+
"""Build SSE lines from a list of event dicts."""
|
|
91
|
+
lines = []
|
|
92
|
+
for ev in events:
|
|
93
|
+
lines.append(f"data: {json.dumps(ev)}")
|
|
94
|
+
lines.append("data: [DONE]")
|
|
95
|
+
return lines
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class _FakeResponse:
|
|
99
|
+
"""Fake httpx streaming response."""
|
|
100
|
+
|
|
101
|
+
def __init__(self, lines: list[str], status_code: int = 200) -> None:
|
|
102
|
+
self.status_code = status_code
|
|
103
|
+
self._lines = lines
|
|
104
|
+
self._raw_body = b""
|
|
105
|
+
|
|
106
|
+
async def aiter_lines(self):
|
|
107
|
+
for line in self._lines:
|
|
108
|
+
yield line
|
|
109
|
+
|
|
110
|
+
async def aread(self) -> bytes:
|
|
111
|
+
return self._raw_body
|
|
112
|
+
|
|
113
|
+
async def __aenter__(self):
|
|
114
|
+
return self
|
|
115
|
+
|
|
116
|
+
async def __aexit__(self, *args):
|
|
117
|
+
pass
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class _FakeClient:
|
|
121
|
+
"""Fake httpx.AsyncClient that returns a _FakeResponse from .stream()."""
|
|
122
|
+
|
|
123
|
+
def __init__(self, response: _FakeResponse) -> None:
|
|
124
|
+
self._response = response
|
|
125
|
+
self.last_url: str | None = None
|
|
126
|
+
self.last_json: dict | None = None
|
|
127
|
+
self.last_headers: dict | None = None
|
|
128
|
+
|
|
129
|
+
def stream(self, method: str, url: str, *, json: dict | None = None, headers: dict | None = None):
|
|
130
|
+
self.last_url = url
|
|
131
|
+
self.last_json = json
|
|
132
|
+
self.last_headers = headers
|
|
133
|
+
return self._response
|
|
134
|
+
|
|
135
|
+
async def __aenter__(self):
|
|
136
|
+
return self
|
|
137
|
+
|
|
138
|
+
async def __aexit__(self, *args):
|
|
139
|
+
pass
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class TestStreamDirectEndpoint:
    """Tests for _stream_direct_endpoint with mocked HTTP."""

    @staticmethod
    def _invoke(
        client: _FakeClient,
        stdout_capture: list[str],
        stderr_capture: list[str],
        *,
        json_output: bool = False,
        **kwargs,
    ) -> None:
        """Run _stream_direct_endpoint against *client*, capturing print() output.

        Text printed with ``file=sys.stderr`` is appended to *stderr_capture*;
        everything else goes to *stdout_capture*. The lists are filled in place
        so callers can still inspect them when the call raises (e.g. the
        SystemExit raised on HTTP error statuses).

        Recognised kwargs (with defaults): api_url, auth_token, endpoint_path,
        query, template_args, defaults.
        """
        import sys

        def mock_print(*args, **kw):
            text = " ".join(str(a) for a in args)
            if kw.get("file") is sys.stderr:
                stderr_capture.append(text)
            else:
                stdout_capture.append(text)

        async def _run_inner():
            with patch("httpx.AsyncClient", return_value=client):
                with patch("builtins.print", side_effect=mock_print):
                    await _stream_direct_endpoint(
                        api_url=kwargs.get("api_url", "https://api.wafer.ai"),
                        auth_token=kwargs.get("auth_token", "test-token"),
                        endpoint_path=kwargs.get("endpoint_path", "/v1/docs/query"),
                        query=kwargs.get("query", "How do bank conflicts work?"),
                        template_args=kwargs.get("template_args", None),
                        defaults=kwargs.get("defaults", {"corpus": "cuda"}),
                        json_output=json_output,
                    )

        trio.run(_run_inner)

    def _run(
        self, events: list[dict], *, json_output: bool = False, status_code: int = 200, **kwargs
    ) -> tuple[_FakeClient, list[str], list[str]]:
        """Run _stream_direct_endpoint with mocked httpx, return (client, stdout_lines, stderr_lines)."""
        client = _FakeClient(_FakeResponse(_make_sse_lines(events), status_code=status_code))
        stdout_capture: list[str] = []
        stderr_capture: list[str] = []
        self._invoke(client, stdout_capture, stderr_capture, json_output=json_output, **kwargs)
        return client, stdout_capture, stderr_capture

    def _run_expecting_exit(self, status_code: int, **kwargs) -> list[str]:
        """Run against an HTTP error *status_code*; assert exit code 1 and return stderr lines."""
        client = _FakeClient(_FakeResponse([], status_code=status_code))
        stdout_capture: list[str] = []
        stderr_capture: list[str] = []
        with pytest.raises(SystemExit) as exc_info:
            self._invoke(client, stdout_capture, stderr_capture, **kwargs)
        assert exc_info.value.code == 1
        return stderr_capture

    def test_text_events_stream_to_stdout(self) -> None:
        events = [
            {"type": "text", "content": "Bank conflicts occur "},
            {"type": "text", "content": "when threads access "},
            {"type": "text", "content": "the same bank."},
            {"type": "done"},
        ]
        _, stdout, stderr = self._run(events)
        text_output = "".join(stdout)
        assert "Bank conflicts occur " in text_output
        assert "when threads access " in text_output
        assert "the same bank." in text_output

    def test_tool_call_events_render_to_stderr(self) -> None:
        events = [
            {"type": "tool_call", "name": "grep", "input": {"pattern": "bank conflict"}},
            {"type": "tool_result", "name": "grep", "content": "line1\nline2\nline3\n"},
            {"type": "text", "content": "Answer here."},
            {"type": "done"},
        ]
        _, stdout, stderr = self._run(events)
        # tool_call and tool_result go to stderr (dim status)
        stderr_text = " ".join(stderr)
        assert "grep" in stderr_text
        # Text goes to stdout
        assert any("Answer here." in s for s in stdout)

    def test_tool_result_no_matches(self) -> None:
        events = [
            {"type": "tool_call", "name": "grep", "input": {"pattern": "nonexistent"}},
            {"type": "tool_result", "name": "grep", "content": "No matches found."},
            {"type": "text", "content": "No results."},
            {"type": "done"},
        ]
        _, _, stderr = self._run(events)
        stderr_text = " ".join(stderr)
        assert "no results" in stderr_text

    def test_error_event_to_stderr(self) -> None:
        events = [
            {"type": "error", "content": "Anthropic API error 500"},
        ]
        _, _, stderr = self._run(events)
        stderr_text = " ".join(stderr)
        assert "Anthropic API error 500" in stderr_text

    def test_request_body_merges_defaults_and_args(self) -> None:
        events = [{"type": "text", "content": "ok"}, {"type": "done"}]
        client, _, _ = self._run(
            events,
            defaults={"corpus": "cuda"},
            template_args={"corpus": "hip"},
            query="test question",
        )
        assert client.last_json is not None
        assert client.last_json["corpus"] == "hip"  # template_args override defaults
        assert client.last_json["query"] == "test question"

    def test_request_url_construction(self) -> None:
        events = [{"type": "text", "content": "ok"}, {"type": "done"}]
        client, _, _ = self._run(
            events,
            api_url="https://api.wafer.ai",
            endpoint_path="/v1/docs/query",
        )
        assert client.last_url == "https://api.wafer.ai/v1/docs/query"

    def test_request_url_strips_trailing_slash(self) -> None:
        events = [{"type": "text", "content": "ok"}, {"type": "done"}]
        client, _, _ = self._run(
            events,
            api_url="https://api.wafer.ai/",
            endpoint_path="/v1/docs/query",
        )
        assert client.last_url == "https://api.wafer.ai/v1/docs/query"

    def test_auth_header_sent(self) -> None:
        events = [{"type": "text", "content": "ok"}, {"type": "done"}]
        client, _, _ = self._run(events, auth_token="my-secret-token")
        assert client.last_headers is not None
        assert client.last_headers["Authorization"] == "Bearer my-secret-token"

    def test_json_mode_text_events(self) -> None:
        events = [
            {"type": "text", "content": "Hello "},
            {"type": "text", "content": "world"},
            {"type": "done"},
        ]
        _, stdout, _ = self._run(events, json_output=True)
        # In JSON mode, output is NDJSON lines
        json_events = [json.loads(line) for line in stdout if line.strip()]
        text_deltas = [e for e in json_events if e.get("type") == "text_delta"]
        assert len(text_deltas) == 2
        assert text_deltas[0]["delta"] == "Hello "
        assert text_deltas[1]["delta"] == "world"

    def test_json_mode_tool_events(self) -> None:
        events = [
            {"type": "tool_call", "name": "grep", "input": {"pattern": "warp"}},
            {"type": "tool_result", "name": "grep", "content": "line1\n"},
            {"type": "text", "content": "Answer"},
            {"type": "done"},
        ]
        _, stdout, _ = self._run(events, json_output=True)
        json_events = [json.loads(line) for line in stdout if line.strip()]
        types = [e["type"] for e in json_events]
        assert "tool_call_start" in types
        assert "tool_call_end" in types
        assert "tool_result" in types
        assert "text_delta" in types
        assert "session_end" in types

    def test_json_mode_error_event(self) -> None:
        events = [
            {"type": "error", "content": "Something went wrong"},
        ]
        _, stdout, _ = self._run(events, json_output=True)
        json_events = [json.loads(line) for line in stdout if line.strip()]
        error_events = [e for e in json_events if e.get("type") == "error"]
        assert len(error_events) == 1
        assert error_events[0]["error"] == "Something went wrong"

    def test_http_401_exits(self) -> None:
        stderr = self._run_expecting_exit(401, auth_token="bad-token", query="test")
        assert "Authentication failed" in " ".join(stderr)

    def test_http_402_exits(self) -> None:
        stderr = self._run_expecting_exit(402, auth_token="token", query="test")
        assert "Insufficient credits" in " ".join(stderr)

    def test_http_500_exits(self) -> None:
        # Generic non-200 path: body is read and the status is reported.
        stderr = self._run_expecting_exit(500, query="test")
        assert "API returned 500" in " ".join(stderr)

    def test_multiple_tool_turns(self) -> None:
        events = [
            {"type": "tool_call", "name": "grep", "input": {"pattern": "bank conflict"}},
            {"type": "tool_result", "name": "grep", "content": "file.md:10:bank conflict\n"},
            {"type": "tool_call", "name": "read_file", "input": {"path": "./file.md"}},
            {"type": "tool_result", "name": "read_file", "content": "Full content here\n"},
            {"type": "text", "content": "Final answer."},
            {"type": "done"},
        ]
        _, stdout, stderr = self._run(events)
        stderr_text = " ".join(stderr)
        assert "grep" in stderr_text
        assert "reading" in stderr_text
        assert any("Final answer." in s for s in stdout)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Template for querying GPU documentation.
|
|
2
|
+
|
|
3
|
+
Streams directly from the server-side docs agent — no local agent loop.
|
|
4
|
+
The server runs a multi-turn Sonnet agent with grep/read_file/list_files
|
|
5
|
+
tools against the corpus volume in a Modal sandbox.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
wafer agent -t ask-docs "How do bank conflicts occur?"
|
|
9
|
+
wafer agent -t ask-docs --args corpus=hip "Explain HIP streams"
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
from wafer_core.rollouts.templates import TemplateConfig
|
|
14
|
+
except ImportError:
|
|
15
|
+
from rollouts.templates import TemplateConfig
|
|
16
|
+
|
|
17
|
+
# Direct-endpoint template: no system prompt, tools or model are configured
# here. When `direct_endpoint` is set, the CLI POSTs the query to that path and
# renders the server's event stream instead of running a local agent loop.
template = TemplateConfig(
    name="ask-docs",
    description="Query GPU documentation to answer technical questions",
    # Sent in the request body; user --args (e.g. corpus=hip) override it.
    defaults={"corpus": "cuda"},
    direct_endpoint="/v1/docs/query",
)
|
|
@@ -429,6 +429,140 @@ def _load_template(
|
|
|
429
429
|
return template, None
|
|
430
430
|
except Exception as e:
|
|
431
431
|
return None, str(e)
|
|
432
|
+
async def _stream_direct_endpoint(
|
|
433
|
+
api_url: str,
|
|
434
|
+
auth_token: str,
|
|
435
|
+
endpoint_path: str,
|
|
436
|
+
query: str,
|
|
437
|
+
template_args: dict[str, str] | None,
|
|
438
|
+
defaults: dict[str, str] | None,
|
|
439
|
+
json_output: bool,
|
|
440
|
+
) -> None:
|
|
441
|
+
"""Stream SSE from a server-side agent endpoint directly to the terminal.
|
|
442
|
+
|
|
443
|
+
Bypasses the local agent loop entirely. Used when a template has
|
|
444
|
+
`direct_endpoint` set — the server runs the full agent loop and we
|
|
445
|
+
just render the events.
|
|
446
|
+
"""
|
|
447
|
+
import httpx
|
|
448
|
+
|
|
449
|
+
url = f"{api_url.rstrip('/')}{endpoint_path}"
|
|
450
|
+
body: dict[str, str] = {}
|
|
451
|
+
if defaults:
|
|
452
|
+
body.update(defaults)
|
|
453
|
+
if template_args:
|
|
454
|
+
body.update(template_args)
|
|
455
|
+
body["query"] = query
|
|
456
|
+
|
|
457
|
+
headers = {
|
|
458
|
+
"Authorization": f"Bearer {auth_token}",
|
|
459
|
+
"Content-Type": "application/json",
|
|
460
|
+
"Accept": "text/event-stream",
|
|
461
|
+
}
|
|
462
|
+
|
|
463
|
+
frontend: StreamingChunkFrontend | None = None
|
|
464
|
+
if json_output:
|
|
465
|
+
frontend = StreamingChunkFrontend()
|
|
466
|
+
|
|
467
|
+
async with httpx.AsyncClient(timeout=180.0) as client:
|
|
468
|
+
async with client.stream("POST", url, json=body, headers=headers) as response:
|
|
469
|
+
if response.status_code == 401:
|
|
470
|
+
print("Error: Authentication failed. Run 'wafer settings login'.", file=sys.stderr)
|
|
471
|
+
sys.exit(1)
|
|
472
|
+
if response.status_code == 402:
|
|
473
|
+
print("Error: Insufficient credits. Check 'wafer settings billing'.", file=sys.stderr)
|
|
474
|
+
sys.exit(1)
|
|
475
|
+
if response.status_code != 200:
|
|
476
|
+
raw = await response.aread()
|
|
477
|
+
print(f"Error: API returned {response.status_code}: {raw.decode(errors='replace')[:500]}", file=sys.stderr)
|
|
478
|
+
sys.exit(1)
|
|
479
|
+
|
|
480
|
+
async for line in response.aiter_lines():
|
|
481
|
+
if not line.startswith("data: "):
|
|
482
|
+
continue
|
|
483
|
+
data_str = line[len("data: "):]
|
|
484
|
+
if data_str == "[DONE]":
|
|
485
|
+
break
|
|
486
|
+
event = json.loads(data_str)
|
|
487
|
+
event_type = event.get("type", "")
|
|
488
|
+
|
|
489
|
+
if event_type == "tool_call":
|
|
490
|
+
tool_name = event.get("name", "")
|
|
491
|
+
tool_input = event.get("input", {})
|
|
492
|
+
summary = _format_tool_call_summary(tool_name, tool_input)
|
|
493
|
+
if json_output:
|
|
494
|
+
assert frontend is not None
|
|
495
|
+
frontend._emit({"type": "tool_call_start", "tool_name": tool_name})
|
|
496
|
+
frontend._emit({"type": "tool_call_end", "tool_name": tool_name, "args": tool_input})
|
|
497
|
+
else:
|
|
498
|
+
print(f"\033[2m {summary}\033[0m", file=sys.stderr)
|
|
499
|
+
|
|
500
|
+
elif event_type == "tool_result":
|
|
501
|
+
tool_name = event.get("name", "")
|
|
502
|
+
content = event.get("content", "")
|
|
503
|
+
summary = _format_tool_result_summary(tool_name, content)
|
|
504
|
+
if json_output:
|
|
505
|
+
assert frontend is not None
|
|
506
|
+
frontend._emit({"type": "tool_result", "is_error": False})
|
|
507
|
+
else:
|
|
508
|
+
print(f"\033[2m {summary}\033[0m", file=sys.stderr)
|
|
509
|
+
|
|
510
|
+
elif event_type == "text":
|
|
511
|
+
text = event.get("content", "")
|
|
512
|
+
if json_output:
|
|
513
|
+
assert frontend is not None
|
|
514
|
+
frontend._emit({"type": "text_delta", "delta": text})
|
|
515
|
+
else:
|
|
516
|
+
print(text, end="", flush=True)
|
|
517
|
+
|
|
518
|
+
elif event_type == "error":
|
|
519
|
+
error_msg = event.get("content", "Unknown error")
|
|
520
|
+
if json_output:
|
|
521
|
+
assert frontend is not None
|
|
522
|
+
frontend._emit({"type": "error", "error": error_msg})
|
|
523
|
+
else:
|
|
524
|
+
print(f"\nError: {error_msg}", file=sys.stderr)
|
|
525
|
+
|
|
526
|
+
elif event_type == "done":
|
|
527
|
+
break
|
|
528
|
+
|
|
529
|
+
if json_output:
|
|
530
|
+
assert frontend is not None
|
|
531
|
+
frontend._emit({"type": "session_end"})
|
|
532
|
+
else:
|
|
533
|
+
print() # trailing newline after streamed text
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def _format_tool_call_summary(tool_name: str, tool_input: dict) -> str:
|
|
537
|
+
"""Format a tool call into a concise status line."""
|
|
538
|
+
if tool_name == "grep":
|
|
539
|
+
pattern = tool_input.get("pattern", "")
|
|
540
|
+
return f'searching: grep("{pattern}")...'
|
|
541
|
+
if tool_name == "read_file":
|
|
542
|
+
path = tool_input.get("path", "")
|
|
543
|
+
return f"reading: {path}..."
|
|
544
|
+
if tool_name == "list_files":
|
|
545
|
+
pattern = tool_input.get("pattern", "*")
|
|
546
|
+
return f'listing: find("{pattern}")...'
|
|
547
|
+
return f"{tool_name}({json.dumps(tool_input)})..."
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def _format_tool_result_summary(tool_name: str, content: str) -> str:
|
|
551
|
+
"""Format a tool result into a concise status line."""
|
|
552
|
+
if "No matches found" in content or "No files found" in content:
|
|
553
|
+
return "no results"
|
|
554
|
+
if "Error:" in content:
|
|
555
|
+
return content[:80]
|
|
556
|
+
line_count = content.count("\n")
|
|
557
|
+
if tool_name == "grep":
|
|
558
|
+
return f"found {line_count} matches"
|
|
559
|
+
if tool_name == "read_file":
|
|
560
|
+
return f"read {line_count} lines"
|
|
561
|
+
if tool_name == "list_files":
|
|
562
|
+
return f"found {line_count} files"
|
|
563
|
+
return f"got {len(content)} chars"
|
|
564
|
+
|
|
565
|
+
|
|
432
566
|
def main( # noqa: PLR0913, PLR0915
|
|
433
567
|
prompt: str | None = None,
|
|
434
568
|
interactive: bool = False,
|
|
@@ -587,6 +721,32 @@ def main( # noqa: PLR0913, PLR0915
|
|
|
587
721
|
print(f"Template: {tpl.name}", file=sys.stderr)
|
|
588
722
|
print(f" {tpl.description}", file=sys.stderr)
|
|
589
723
|
print(file=sys.stderr)
|
|
724
|
+
# Direct endpoint: bypass local agent loop, stream from server
|
|
725
|
+
if tpl.direct_endpoint is not None:
|
|
726
|
+
assert prompt, (
|
|
727
|
+
f"Template '{tpl.name}' uses direct streaming and requires a prompt. "
|
|
728
|
+
f"Usage: wafer agent -t {tpl.name} \"your question\""
|
|
729
|
+
)
|
|
730
|
+
wafer_api_url = os.environ.get("WAFER_API_URL", get_api_url())
|
|
731
|
+
wafer_auth_token = os.environ.get("WAFER_AUTH_TOKEN", "")
|
|
732
|
+
assert wafer_auth_token, "WAFER_AUTH_TOKEN not set. Run 'wafer settings login' first."
|
|
733
|
+
_direct_tpl = tpl
|
|
734
|
+
_direct_prompt = prompt
|
|
735
|
+
|
|
736
|
+
async def _run_direct() -> None:
|
|
737
|
+
await _stream_direct_endpoint(
|
|
738
|
+
api_url=wafer_api_url,
|
|
739
|
+
auth_token=wafer_auth_token,
|
|
740
|
+
endpoint_path=_direct_tpl.direct_endpoint,
|
|
741
|
+
query=_direct_prompt,
|
|
742
|
+
template_args=template_args,
|
|
743
|
+
defaults=_direct_tpl.defaults if _direct_tpl.defaults else None,
|
|
744
|
+
json_output=json_output,
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
import trio
|
|
748
|
+
trio.run(_run_direct)
|
|
749
|
+
return
|
|
590
750
|
else:
|
|
591
751
|
tpl = _get_default_template()
|
|
592
752
|
base_system_prompt = tpl.system_prompt
|
|
@@ -1,32 +0,0 @@
|
|
|
1
|
-
"""Template for querying GPU documentation.
|
|
2
|
-
|
|
3
|
-
Usage:
|
|
4
|
-
wafer agent -t ask-docs "How do bank conflicts occur?"
|
|
5
|
-
wafer agent -t ask-docs "Explain warp divergence in CUDA"
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
try:
|
|
9
|
-
from wafer_core.rollouts.templates import TemplateConfig
|
|
10
|
-
except ImportError:
|
|
11
|
-
from rollouts.templates import TemplateConfig
|
|
12
|
-
|
|
13
|
-
# Local-agent template: the model calls the `ask_docs` tool (possibly several
# times, since single_turn=False) and synthesizes the final answer itself.
template = TemplateConfig(
    name="ask-docs",
    description="Query GPU documentation to answer technical questions",
    model="anthropic/claude-opus-4-5-20251101",
    tools=["ask_docs"],
    single_turn=False,
    max_tokens=8192,
    # Extended thinking is off; the budget only applies if it is re-enabled.
    thinking=False,
    thinking_budget=10000,
    system_prompt="""You are a GPU programming expert. Use the ask_docs tool to search documentation and answer questions.

Available corpora: cuda, cutlass, hip, amd, cdna3, hopper, rdna35, llvm-amdgpu, gcnasm.

Strategy:
1. Call ask_docs with the user's question and the appropriate corpus
2. If the answer is incomplete, call ask_docs again with a refined query or different corpus
3. Synthesize a clear, accurate answer

Be concise but thorough. Include code examples when relevant.""",
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|