wafer-cli 0.2.58__tar.gz → 0.2.60__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/PKG-INFO +1 -1
  2. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/pyproject.toml +1 -1
  3. wafer_cli-0.2.60/tests/test_direct_streaming.py +352 -0
  4. wafer_cli-0.2.60/wafer/templates/ask_docs.py +22 -0
  5. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/wevin_cli.py +160 -0
  6. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer_cli.egg-info/PKG-INFO +1 -1
  7. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer_cli.egg-info/SOURCES.txt +1 -0
  8. wafer_cli-0.2.58/wafer/templates/ask_docs.py +0 -32
  9. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/README.md +0 -0
  10. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/setup.cfg +0 -0
  11. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_agent_template_discovery.py +0 -0
  12. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_analytics.py +0 -0
  13. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_auth.py +0 -0
  14. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_billing.py +0 -0
  15. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_cli_coverage.py +0 -0
  16. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_cli_parity_integration.py +0 -0
  17. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_config_show.py +0 -0
  18. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_corpus_lockdown.py +0 -0
  19. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_deps.py +0 -0
  20. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_distributed_traces_cli.py +0 -0
  21. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_docker_progress.py +0 -0
  22. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_evaluate_ux.py +0 -0
  23. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_file_operations_integration.py +0 -0
  24. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_first_run.py +0 -0
  25. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_inference.py +0 -0
  26. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_json_output.py +0 -0
  27. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_kernel_scope_cli.py +0 -0
  28. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_ncu_run.py +0 -0
  29. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_ncu_run_e2e.py +0 -0
  30. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_ncu_run_local_e2e.py +0 -0
  31. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_nsys_analyze.py +0 -0
  32. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_nsys_profile.py +0 -0
  33. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_output.py +0 -0
  34. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_rocprof_compute_integration.py +0 -0
  35. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_skill_commands.py +0 -0
  36. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_ssh_integration.py +0 -0
  37. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_status.py +0 -0
  38. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_targets_ops.py +0 -0
  39. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_token_waste.py +0 -0
  40. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_ux_improvements.py +0 -0
  41. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/tests/test_wevin_cli.py +0 -0
  42. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/GUIDE.md +0 -0
  43. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/__init__.py +0 -0
  44. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/agent_defaults.py +0 -0
  45. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/analytics.py +0 -0
  46. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/api_client.py +0 -0
  47. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/auth.py +0 -0
  48. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/autotuner.py +0 -0
  49. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/baseline.py +0 -0
  50. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/billing.py +0 -0
  51. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/cli.py +0 -0
  52. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/cli_instructions.py +0 -0
  53. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/deps.py +0 -0
  54. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/distributed_traces.py +0 -0
  55. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/evaluate.py +0 -0
  56. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/global_config.py +0 -0
  57. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/gpu_run.py +0 -0
  58. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/inference.py +0 -0
  59. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/kernel_scope.py +0 -0
  60. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/ncu_analyze.py +0 -0
  61. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/ncu_run.py +0 -0
  62. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/nsys_analyze.py +0 -0
  63. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/nsys_profile.py +0 -0
  64. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/output.py +0 -0
  65. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/problems.py +0 -0
  66. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/rocprof_compute.py +0 -0
  67. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/rocprof_sdk.py +0 -0
  68. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/rocprof_systems.py +0 -0
  69. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/skills/packed-ops-guide/SKILL.md +0 -0
  70. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/SKILL.md +0 -0
  71. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/commands.md +0 -0
  72. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/evaluate.md +0 -0
  73. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/pitfalls.md +0 -0
  74. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/profiling.md +0 -0
  75. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/skills/wafer-guide/workspaces.md +0 -0
  76. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/specs_cli.py +0 -0
  77. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/ssh_keys.py +0 -0
  78. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/target_lock.py +0 -0
  79. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/targets.py +0 -0
  80. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/targets_cli.py +0 -0
  81. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/targets_ops.py +0 -0
  82. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/templates/__init__.py +0 -0
  83. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/templates/aiter_optimize.py +0 -0
  84. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/templates/audit.py +0 -0
  85. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/templates/optimize_flashinfer.py +0 -0
  86. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/templates/optimize_kernel.py +0 -0
  87. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/templates/optimize_kernelbench.py +0 -0
  88. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/templates/optimize_vllm.py +0 -0
  89. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/templates/trace_analyze.py +0 -0
  90. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/tests/test_eval_cli_parity.py +0 -0
  91. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/trace_compare.py +0 -0
  92. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/tracelens.py +0 -0
  93. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer/workspaces.py +0 -0
  94. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer_cli.egg-info/dependency_links.txt +0 -0
  95. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer_cli.egg-info/entry_points.txt +0 -0
  96. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer_cli.egg-info/requires.txt +0 -0
  97. {wafer_cli-0.2.58 → wafer_cli-0.2.60}/wafer_cli.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wafer-cli
3
- Version: 0.2.58
3
+ Version: 0.2.60
4
4
  Summary: CLI for running GPU workloads, managing remote workspaces, and evaluating/optimizing kernels
5
5
  Requires-Python: >=3.11
6
6
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "wafer-cli"
3
- version = "0.2.58"
3
+ version = "0.2.60"
4
4
  description = "CLI for running GPU workloads, managing remote workspaces, and evaluating/optimizing kernels"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
@@ -0,0 +1,352 @@
1
+ """Tests for direct endpoint streaming (_stream_direct_endpoint).
2
+
3
+ Tests SSE parsing, output formatting for tool_call/tool_result/text/error
4
+ events, and JSON mode re-emission.
5
+
6
+ Run with:
7
+ PYTHONPATH=apps/wafer-cli uv run pytest apps/wafer-cli/tests/test_direct_streaming.py -v
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ from unittest.mock import AsyncMock, MagicMock, patch
13
+
14
+ import pytest
15
+ import trio
16
+
17
+ from wafer.wevin_cli import (
18
+ _format_tool_call_summary,
19
+ _format_tool_result_summary,
20
+ _stream_direct_endpoint,
21
+ )
22
+
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Unit tests for formatting helpers
26
+ # ---------------------------------------------------------------------------
27
+
28
+
class TestFormatToolCallSummary:
    """Checks the one-line status text produced for each kind of tool call."""

    def test_grep(self) -> None:
        # grep calls show the search pattern in quotes.
        summary = _format_tool_call_summary("grep", {"pattern": "shared memory"})
        assert 'searching: grep("shared memory")...' == summary

    def test_read_file(self) -> None:
        # read_file calls show the path verbatim.
        summary = _format_tool_call_summary("read_file", {"path": "./guide/memory.md"})
        assert "reading: ./guide/memory.md..." == summary

    def test_list_files(self) -> None:
        # list_files calls show the glob that was requested.
        summary = _format_tool_call_summary("list_files", {"pattern": "*.md"})
        assert 'listing: find("*.md")...' == summary

    def test_list_files_default(self) -> None:
        # With no pattern supplied the summary falls back to "*".
        summary = _format_tool_call_summary("list_files", {})
        assert 'listing: find("*")...' == summary

    def test_unknown_tool(self) -> None:
        # Unrecognised tools must still mention the tool name.
        summary = _format_tool_call_summary("some_tool", {"x": 1})
        assert "some_tool" in summary
49
+
50
+
class TestFormatToolResultSummary:
    """Checks the one-line status text produced for each kind of tool result."""

    def test_no_matches(self) -> None:
        # The grep "nothing found" sentinel collapses to "no results".
        assert _format_tool_result_summary("grep", "No matches found.") == "no results"

    def test_no_files(self) -> None:
        # The list_files "nothing found" sentinel collapses the same way.
        summary = _format_tool_result_summary("list_files", "No files found matching pattern.")
        assert summary == "no results"

    def test_error(self) -> None:
        # Error payloads are passed through so the user sees the reason.
        summary = _format_tool_result_summary("read_file", "Error: file not found")
        assert summary.startswith("Error:")

    def test_grep_matches(self) -> None:
        three_lines = "line1\nline2\nline3\n"
        assert "3 matches" in _format_tool_result_summary("grep", three_lines)

    def test_read_file_lines(self) -> None:
        five_lines = "a\nb\nc\nd\ne\n"
        assert "5 lines" in _format_tool_result_summary("read_file", five_lines)

    def test_list_files_count(self) -> None:
        two_paths = "./a.md\n./b.md\n"
        assert "2 files" in _format_tool_result_summary("list_files", two_paths)

    def test_unknown_tool_char_count(self) -> None:
        # Unknown tools report the raw payload size instead of a line count.
        assert "6 chars" in _format_tool_result_summary("unknown", "abcdef")
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # SSE streaming integration tests (mocked httpx)
86
+ # ---------------------------------------------------------------------------
87
+
88
+
89
+ def _make_sse_lines(events: list[dict]) -> list[str]:
90
+ """Build SSE lines from a list of event dicts."""
91
+ lines = []
92
+ for ev in events:
93
+ lines.append(f"data: {json.dumps(ev)}")
94
+ lines.append("data: [DONE]")
95
+ return lines
96
+
97
+
98
+ class _FakeResponse:
99
+ """Fake httpx streaming response."""
100
+
101
+ def __init__(self, lines: list[str], status_code: int = 200) -> None:
102
+ self.status_code = status_code
103
+ self._lines = lines
104
+ self._raw_body = b""
105
+
106
+ async def aiter_lines(self):
107
+ for line in self._lines:
108
+ yield line
109
+
110
+ async def aread(self) -> bytes:
111
+ return self._raw_body
112
+
113
+ async def __aenter__(self):
114
+ return self
115
+
116
+ async def __aexit__(self, *args):
117
+ pass
118
+
119
+
120
+ class _FakeClient:
121
+ """Fake httpx.AsyncClient that returns a _FakeResponse from .stream()."""
122
+
123
+ def __init__(self, response: _FakeResponse) -> None:
124
+ self._response = response
125
+ self.last_url: str | None = None
126
+ self.last_json: dict | None = None
127
+ self.last_headers: dict | None = None
128
+
129
+ def stream(self, method: str, url: str, *, json: dict | None = None, headers: dict | None = None):
130
+ self.last_url = url
131
+ self.last_json = json
132
+ self.last_headers = headers
133
+ return self._response
134
+
135
+ async def __aenter__(self):
136
+ return self
137
+
138
+ async def __aexit__(self, *args):
139
+ pass
140
+
141
+
class TestStreamDirectEndpoint:
    """Tests for _stream_direct_endpoint with mocked HTTP."""

    def _run(self, events: list[dict], *, json_output: bool = False, status_code: int = 200, **kwargs) -> tuple[_FakeClient, list[str], list[str]]:
        """Run _stream_direct_endpoint with mocked httpx, return (client, stdout_lines, stderr_lines).

        ``kwargs`` override the default call arguments (api_url, auth_token,
        endpoint_path, query, template_args, defaults). Unknown names are
        forwarded as-is, so a typo fails loudly instead of being ignored.
        Any exception raised by the call (including SystemExit) propagates.
        """
        import sys

        client = _FakeClient(_FakeResponse(_make_sse_lines(events), status_code=status_code))
        stdout_capture: list[str] = []
        stderr_capture: list[str] = []

        def mock_print(*args, **kw):
            # Route on the ``file=`` keyword exactly like the real print would.
            sink = stderr_capture if kw.get("file") is sys.stderr else stdout_capture
            sink.append(" ".join(str(a) for a in args))

        call_args = {
            "api_url": "https://api.wafer.ai",
            "auth_token": "test-token",
            "endpoint_path": "/v1/docs/query",
            "query": "How do bank conflicts work?",
            "template_args": None,
            "defaults": {"corpus": "cuda"},
            **kwargs,
        }

        async def _run_inner():
            with patch("httpx.AsyncClient", return_value=client), patch("builtins.print", side_effect=mock_print):
                await _stream_direct_endpoint(**call_args, json_output=json_output)

        trio.run(_run_inner)
        return client, stdout_capture, stderr_capture

    def _assert_exits_on_status(self, status_code: int, *, auth_token: str = "token") -> None:
        """Assert that an HTTP error status terminates the call with exit code 1."""
        with pytest.raises(SystemExit) as exc_info:
            self._run([], status_code=status_code, auth_token=auth_token, query="test")
        assert exc_info.value.code == 1

    def test_text_events_stream_to_stdout(self) -> None:
        events = [
            {"type": "text", "content": "Bank conflicts occur "},
            {"type": "text", "content": "when threads access "},
            {"type": "text", "content": "the same bank."},
            {"type": "done"},
        ]
        _, stdout, stderr = self._run(events)
        text_output = "".join(stdout)
        assert "Bank conflicts occur " in text_output
        assert "when threads access " in text_output
        assert "the same bank." in text_output

    def test_tool_call_events_render_to_stderr(self) -> None:
        events = [
            {"type": "tool_call", "name": "grep", "input": {"pattern": "bank conflict"}},
            {"type": "tool_result", "name": "grep", "content": "line1\nline2\nline3\n"},
            {"type": "text", "content": "Answer here."},
            {"type": "done"},
        ]
        _, stdout, stderr = self._run(events)
        # tool_call and tool_result go to stderr (dim status)
        stderr_text = " ".join(stderr)
        assert "grep" in stderr_text
        # Text goes to stdout
        assert any("Answer here." in s for s in stdout)

    def test_tool_result_no_matches(self) -> None:
        events = [
            {"type": "tool_call", "name": "grep", "input": {"pattern": "nonexistent"}},
            {"type": "tool_result", "name": "grep", "content": "No matches found."},
            {"type": "text", "content": "No results."},
            {"type": "done"},
        ]
        _, _, stderr = self._run(events)
        stderr_text = " ".join(stderr)
        assert "no results" in stderr_text

    def test_error_event_to_stderr(self) -> None:
        events = [
            {"type": "error", "content": "Anthropic API error 500"},
        ]
        _, _, stderr = self._run(events)
        stderr_text = " ".join(stderr)
        assert "Anthropic API error 500" in stderr_text

    def test_request_body_merges_defaults_and_args(self) -> None:
        events = [{"type": "text", "content": "ok"}, {"type": "done"}]
        client, _, _ = self._run(
            events,
            defaults={"corpus": "cuda"},
            template_args={"corpus": "hip"},
            query="test question",
        )
        assert client.last_json is not None
        assert client.last_json["corpus"] == "hip"  # template_args override defaults
        assert client.last_json["query"] == "test question"

    def test_request_url_construction(self) -> None:
        events = [{"type": "text", "content": "ok"}, {"type": "done"}]
        client, _, _ = self._run(
            events,
            api_url="https://api.wafer.ai",
            endpoint_path="/v1/docs/query",
        )
        assert client.last_url == "https://api.wafer.ai/v1/docs/query"

    def test_request_url_strips_trailing_slash(self) -> None:
        events = [{"type": "text", "content": "ok"}, {"type": "done"}]
        client, _, _ = self._run(
            events,
            api_url="https://api.wafer.ai/",
            endpoint_path="/v1/docs/query",
        )
        assert client.last_url == "https://api.wafer.ai/v1/docs/query"

    def test_auth_header_sent(self) -> None:
        events = [{"type": "text", "content": "ok"}, {"type": "done"}]
        client, _, _ = self._run(events, auth_token="my-secret-token")
        assert client.last_headers is not None
        assert client.last_headers["Authorization"] == "Bearer my-secret-token"

    def test_json_mode_text_events(self) -> None:
        events = [
            {"type": "text", "content": "Hello "},
            {"type": "text", "content": "world"},
            {"type": "done"},
        ]
        _, stdout, _ = self._run(events, json_output=True)
        # In JSON mode, output is NDJSON lines
        json_events = [json.loads(line) for line in stdout if line.strip()]
        text_deltas = [e for e in json_events if e.get("type") == "text_delta"]
        assert len(text_deltas) == 2
        assert text_deltas[0]["delta"] == "Hello "
        assert text_deltas[1]["delta"] == "world"

    def test_json_mode_tool_events(self) -> None:
        events = [
            {"type": "tool_call", "name": "grep", "input": {"pattern": "warp"}},
            {"type": "tool_result", "name": "grep", "content": "line1\n"},
            {"type": "text", "content": "Answer"},
            {"type": "done"},
        ]
        _, stdout, _ = self._run(events, json_output=True)
        json_events = [json.loads(line) for line in stdout if line.strip()]
        types = [e["type"] for e in json_events]
        assert "tool_call_start" in types
        assert "tool_call_end" in types
        assert "tool_result" in types
        assert "text_delta" in types
        assert "session_end" in types

    def test_json_mode_error_event(self) -> None:
        events = [
            {"type": "error", "content": "Something went wrong"},
        ]
        _, stdout, _ = self._run(events, json_output=True)
        json_events = [json.loads(line) for line in stdout if line.strip()]
        error_events = [e for e in json_events if e.get("type") == "error"]
        assert len(error_events) == 1
        assert error_events[0]["error"] == "Something went wrong"

    def test_http_401_exits(self) -> None:
        self._assert_exits_on_status(401, auth_token="bad-token")

    def test_http_402_exits(self) -> None:
        self._assert_exits_on_status(402)

    def test_http_500_exits(self) -> None:
        # Any other non-200 status takes the generic error branch and also exits 1.
        self._assert_exits_on_status(500)

    def test_multiple_tool_turns(self) -> None:
        events = [
            {"type": "tool_call", "name": "grep", "input": {"pattern": "bank conflict"}},
            {"type": "tool_result", "name": "grep", "content": "file.md:10:bank conflict\n"},
            {"type": "tool_call", "name": "read_file", "input": {"path": "./file.md"}},
            {"type": "tool_result", "name": "read_file", "content": "Full content here\n"},
            {"type": "text", "content": "Final answer."},
            {"type": "done"},
        ]
        _, stdout, stderr = self._run(events)
        stderr_text = " ".join(stderr)
        assert "grep" in stderr_text
        assert "reading" in stderr_text
        assert any("Final answer." in s for s in stdout)
@@ -0,0 +1,22 @@
1
+ """Template for querying GPU documentation.
2
+
3
+ Streams directly from the server-side docs agent — no local agent loop.
4
+ The server runs a multi-turn Sonnet agent with grep/read_file/list_files
5
+ tools against the corpus volume in a Modal sandbox.
6
+
7
+ Usage:
8
+ wafer agent -t ask-docs "How do bank conflicts occur?"
9
+ wafer agent -t ask-docs --args corpus=hip "Explain HIP streams"
10
+ """
11
+
12
+ try:
13
+ from wafer_core.rollouts.templates import TemplateConfig
14
+ except ImportError:
15
+ from rollouts.templates import TemplateConfig
16
+
17
+ template = TemplateConfig(
18
+ name="ask-docs",
19
+ description="Query GPU documentation to answer technical questions",
20
+ direct_endpoint="/v1/docs/query",
21
+ defaults={"corpus": "cuda"},
22
+ )
@@ -429,6 +429,140 @@ def _load_template(
429
429
  return template, None
430
430
  except Exception as e:
431
431
  return None, str(e)
def _sse_data_payload(line: str) -> str | None:
    """Return the payload of an SSE ``data:`` line, or None for any other line."""
    if not line.startswith("data:"):
        return None
    # The SSE wire format allows (but does not require) one space after the colon.
    return line.removeprefix("data:").removeprefix(" ")


def _render_direct_event(event: dict, frontend: "StreamingChunkFrontend | None") -> bool:
    """Render one decoded server event.

    Events are re-emitted through ``frontend`` when one is given (JSON mode);
    otherwise answer text goes to stdout and status/error lines go to stderr.
    Event types other than tool_call, tool_result, text, error and done are
    ignored.

    Returns:
        False when the event is ``done`` and the stream should stop, else True.
    """
    event_type = event.get("type", "")

    if event_type == "done":
        return False

    if event_type == "tool_call":
        tool_name = event.get("name", "")
        # A null "input" would otherwise break the summary formatter.
        tool_input = event.get("input") or {}
        if frontend is not None:
            frontend._emit({"type": "tool_call_start", "tool_name": tool_name})
            frontend._emit({"type": "tool_call_end", "tool_name": tool_name, "args": tool_input})
        else:
            summary = _format_tool_call_summary(tool_name, tool_input)
            print(f"\033[2m {summary}\033[0m", file=sys.stderr)

    elif event_type == "tool_result":
        if frontend is not None:
            frontend._emit({"type": "tool_result", "is_error": False})
        else:
            summary = _format_tool_result_summary(event.get("name", ""), event.get("content", ""))
            print(f"\033[2m {summary}\033[0m", file=sys.stderr)

    elif event_type == "text":
        text = event.get("content", "")
        if frontend is not None:
            frontend._emit({"type": "text_delta", "delta": text})
        else:
            print(text, end="", flush=True)

    elif event_type == "error":
        error_msg = event.get("content", "Unknown error")
        if frontend is not None:
            frontend._emit({"type": "error", "error": error_msg})
        else:
            print(f"\nError: {error_msg}", file=sys.stderr)

    return True


async def _stream_direct_endpoint(
    api_url: str,
    auth_token: str,
    endpoint_path: str,
    query: str,
    template_args: dict[str, str] | None,
    defaults: dict[str, str] | None,
    json_output: bool,
) -> None:
    """Stream SSE from a server-side agent endpoint directly to the terminal.

    Bypasses the local agent loop entirely. Used when a template has
    `direct_endpoint` set — the server runs the full agent loop and we
    just render the events.

    Args:
        api_url: Base URL of the API; a trailing slash is stripped.
        auth_token: Sent as a Bearer token in the Authorization header.
        endpoint_path: Path appended to ``api_url``.
        query: The user's question; always sent as the ``query`` field.
        template_args: Per-invocation fields; these override ``defaults``.
        defaults: Template-level default fields for the request body.
        json_output: When true, events are re-emitted through a
            StreamingChunkFrontend instead of being printed as text.

    Raises:
        SystemExit: With code 1 on any non-200 response (401 and 402 get
            dedicated messages).
    """
    import httpx

    url = f"{api_url.rstrip('/')}{endpoint_path}"
    # Precedence: template defaults < user-supplied args < the query itself.
    body: dict[str, str] = {**(defaults or {}), **(template_args or {}), "query": query}

    headers = {
        "Authorization": f"Bearer {auth_token}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }

    frontend: StreamingChunkFrontend | None = StreamingChunkFrontend() if json_output else None

    async with httpx.AsyncClient(timeout=180.0) as client:
        async with client.stream("POST", url, json=body, headers=headers) as response:
            if response.status_code == 401:
                print("Error: Authentication failed. Run 'wafer settings login'.", file=sys.stderr)
                sys.exit(1)
            if response.status_code == 402:
                print("Error: Insufficient credits. Check 'wafer settings billing'.", file=sys.stderr)
                sys.exit(1)
            if response.status_code != 200:
                raw = await response.aread()
                print(f"Error: API returned {response.status_code}: {raw.decode(errors='replace')[:500]}", file=sys.stderr)
                sys.exit(1)

            async for line in response.aiter_lines():
                payload = _sse_data_payload(line)
                if payload is None or not payload.strip():
                    # Not a data line, or an empty one.
                    continue
                if payload.strip() == "[DONE]":
                    break
                try:
                    event = json.loads(payload)
                except json.JSONDecodeError:
                    # One bad line should not abort an otherwise healthy stream.
                    print(f"Warning: skipping malformed stream event: {payload[:200]}", file=sys.stderr)
                    continue
                if not isinstance(event, dict):
                    continue
                if not _render_direct_event(event, frontend):
                    break

    if frontend is not None:
        frontend._emit({"type": "session_end"})
    else:
        print()  # trailing newline after streamed text
534
+
535
+
536
+ def _format_tool_call_summary(tool_name: str, tool_input: dict) -> str:
537
+ """Format a tool call into a concise status line."""
538
+ if tool_name == "grep":
539
+ pattern = tool_input.get("pattern", "")
540
+ return f'searching: grep("{pattern}")...'
541
+ if tool_name == "read_file":
542
+ path = tool_input.get("path", "")
543
+ return f"reading: {path}..."
544
+ if tool_name == "list_files":
545
+ pattern = tool_input.get("pattern", "*")
546
+ return f'listing: find("{pattern}")...'
547
+ return f"{tool_name}({json.dumps(tool_input)})..."
548
+
549
+
550
+ def _format_tool_result_summary(tool_name: str, content: str) -> str:
551
+ """Format a tool result into a concise status line."""
552
+ if "No matches found" in content or "No files found" in content:
553
+ return "no results"
554
+ if "Error:" in content:
555
+ return content[:80]
556
+ line_count = content.count("\n")
557
+ if tool_name == "grep":
558
+ return f"found {line_count} matches"
559
+ if tool_name == "read_file":
560
+ return f"read {line_count} lines"
561
+ if tool_name == "list_files":
562
+ return f"found {line_count} files"
563
+ return f"got {len(content)} chars"
564
+
565
+
432
566
  def main( # noqa: PLR0913, PLR0915
433
567
  prompt: str | None = None,
434
568
  interactive: bool = False,
@@ -587,6 +721,32 @@ def main( # noqa: PLR0913, PLR0915
587
721
  print(f"Template: {tpl.name}", file=sys.stderr)
588
722
  print(f" {tpl.description}", file=sys.stderr)
589
723
  print(file=sys.stderr)
724
+ # Direct endpoint: bypass local agent loop, stream from server
725
+ if tpl.direct_endpoint is not None:
726
+ assert prompt, (
727
+ f"Template '{tpl.name}' uses direct streaming and requires a prompt. "
728
+ f"Usage: wafer agent -t {tpl.name} \"your question\""
729
+ )
730
+ wafer_api_url = os.environ.get("WAFER_API_URL", get_api_url())
731
+ wafer_auth_token = os.environ.get("WAFER_AUTH_TOKEN", "")
732
+ assert wafer_auth_token, "WAFER_AUTH_TOKEN not set. Run 'wafer settings login' first."
733
+ _direct_tpl = tpl
734
+ _direct_prompt = prompt
735
+
736
+ async def _run_direct() -> None:
737
+ await _stream_direct_endpoint(
738
+ api_url=wafer_api_url,
739
+ auth_token=wafer_auth_token,
740
+ endpoint_path=_direct_tpl.direct_endpoint,
741
+ query=_direct_prompt,
742
+ template_args=template_args,
743
+ defaults=_direct_tpl.defaults if _direct_tpl.defaults else None,
744
+ json_output=json_output,
745
+ )
746
+
747
+ import trio
748
+ trio.run(_run_direct)
749
+ return
590
750
  else:
591
751
  tpl = _get_default_template()
592
752
  base_system_prompt = tpl.system_prompt
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wafer-cli
3
- Version: 0.2.58
3
+ Version: 0.2.60
4
4
  Summary: CLI for running GPU workloads, managing remote workspaces, and evaluating/optimizing kernels
5
5
  Requires-Python: >=3.11
6
6
  Description-Content-Type: text/markdown
@@ -9,6 +9,7 @@ tests/test_cli_parity_integration.py
9
9
  tests/test_config_show.py
10
10
  tests/test_corpus_lockdown.py
11
11
  tests/test_deps.py
12
+ tests/test_direct_streaming.py
12
13
  tests/test_distributed_traces_cli.py
13
14
  tests/test_docker_progress.py
14
15
  tests/test_evaluate_ux.py
@@ -1,32 +0,0 @@
1
- """Template for querying GPU documentation.
2
-
3
- Usage:
4
- wafer agent -t ask-docs "How do bank conflicts occur?"
5
- wafer agent -t ask-docs "Explain warp divergence in CUDA"
6
- """
7
-
8
- try:
9
- from wafer_core.rollouts.templates import TemplateConfig
10
- except ImportError:
11
- from rollouts.templates import TemplateConfig
12
-
13
- template = TemplateConfig(
14
- name="ask-docs",
15
- description="Query GPU documentation to answer technical questions",
16
- system_prompt="""You are a GPU programming expert. Use the ask_docs tool to search documentation and answer questions.
17
-
18
- Available corpora: cuda, cutlass, hip, amd, cdna3, hopper, rdna35, llvm-amdgpu, gcnasm.
19
-
20
- Strategy:
21
- 1. Call ask_docs with the user's question and the appropriate corpus
22
- 2. If the answer is incomplete, call ask_docs again with a refined query or different corpus
23
- 3. Synthesize a clear, accurate answer
24
-
25
- Be concise but thorough. Include code examples when relevant.""",
26
- tools=["ask_docs"],
27
- model="anthropic/claude-opus-4-5-20251101",
28
- max_tokens=8192,
29
- thinking=False,
30
- thinking_budget=10000,
31
- single_turn=False,
32
- )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes