wafer-cli 0.2.7__tar.gz → 0.2.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/PKG-INFO +1 -1
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/pyproject.toml +1 -1
- wafer_cli-0.2.9/tests/test_kernel_scope_cli.py +467 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/auth.py +85 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/cli.py +1196 -160
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/evaluate.py +1171 -209
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/gpu_run.py +5 -1
- wafer_cli-0.2.9/wafer/kernel_scope.py +453 -0
- wafer_cli-0.2.9/wafer/problems.py +357 -0
- wafer_cli-0.2.9/wafer/target_lock.py +270 -0
- wafer_cli-0.2.9/wafer/targets.py +842 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/wevin_cli.py +2 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/workspaces.py +53 -1
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer_cli.egg-info/PKG-INFO +1 -1
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer_cli.egg-info/SOURCES.txt +4 -0
- wafer_cli-0.2.7/wafer/targets.py +0 -352
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/README.md +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/setup.cfg +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_analytics.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_billing.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_cli_coverage.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_cli_parity_integration.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_config_integration.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_file_operations_integration.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_isa_cli.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_rocprof_compute_integration.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_ssh_integration.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_wevin_cli.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/tests/test_workflow_integration.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/GUIDE.md +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/__init__.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/analytics.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/api_client.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/autotuner.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/billing.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/config.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/corpus.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/global_config.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/inference.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/ncu_analyze.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/nsys_analyze.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/rocprof_compute.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/rocprof_sdk.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/rocprof_systems.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/skills/wafer-guide/SKILL.md +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/templates/__init__.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/templates/ask_docs.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/templates/optimize_kernel.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/templates/trace_analyze.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer/tracelens.py +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer_cli.egg-info/dependency_links.txt +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer_cli.egg-info/entry_points.txt +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer_cli.egg-info/requires.txt +0 -0
- {wafer_cli-0.2.7 → wafer_cli-0.2.9}/wafer_cli.egg-info/top_level.txt +0 -0
|
@@ -0,0 +1,467 @@
|
|
|
1
|
+
"""Unit tests for Kernel Scope CLI commands.
|
|
2
|
+
|
|
3
|
+
Tests the wafer amd kernel-scope command using CliRunner.
|
|
4
|
+
|
|
5
|
+
Run with: PYTHONPATH=apps/wafer-cli uv run pytest apps/wafer-cli/tests/test_kernel_scope_cli.py -v
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import re
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from unittest.mock import patch, MagicMock
|
|
12
|
+
|
|
13
|
+
import pytest
|
|
14
|
+
from typer.testing import CliRunner
|
|
15
|
+
|
|
16
|
+
from wafer.cli import app
|
|
17
|
+
from wafer.kernel_scope import (
|
|
18
|
+
analyze_command,
|
|
19
|
+
metrics_command,
|
|
20
|
+
targets_command,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Shared Typer CliRunner used by the CLI integration tests in this module
# (each test calls runner.invoke(app, [...]) and inspects the Result).
runner = CliRunner()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def strip_ansi(text: str) -> str:
    """Return *text* with ANSI terminal escape sequences removed.

    Used to normalise Rich/Typer help output before substring assertions.
    """
    # Two alternatives: ESC followed by one byte in @-Z or \-_ (single-char
    # escapes), or a CSI sequence: ESC [ params* intermediates* final-byte.
    return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", text)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Sample ISA for testing.
# Minimal gfx90a kernel "test_kernel" with two v_mfma instructions, one LDS
# read (ds_read_b128), global loads/stores, and no scratch (spill) instructions.
# Kernel descriptor: 64 VGPRs, 32 SGPRs, 16384-byte group segment (LDS),
# 0-byte private segment. The JSON test expects vgpr_count == 64 and
# mfma_count == 2 for this input.
SAMPLE_ISA = '''
.amdgcn_target "amdgcn-amd-amdhsa--gfx90a"

.text
.globl test_kernel

test_kernel:
s_load_dwordx4 s[0:3], s[4:5], 0x0
s_waitcnt lgkmcnt(0)
global_load_dwordx4 v[0:3], v[4:5], off
ds_read_b128 v[8:11], v12
s_waitcnt vmcnt(0) lgkmcnt(0)
v_mfma_f32_32x32x8f16 a[0:15], v[0:1], v[2:3], a[0:15]
v_mfma_f32_32x32x8f16 a[16:31], v[4:5], v[6:7], a[16:31]
v_add_f32 v0, v1, v2
v_fma_f32 v3, v4, v5, v6
global_store_dwordx4 v[20:21], v[0:3], off
s_barrier
s_endpgm

.amdhsa_kernel test_kernel
.amdhsa_next_free_vgpr 64
.amdhsa_next_free_sgpr 32
.amdhsa_group_segment_fixed_size 16384
.amdhsa_private_segment_fixed_size 0
.end_amdhsa_kernel
'''
|
|
61
|
+
|
|
62
|
+
SAMPLE_ISA_WITH_SPILLS = '''
|
|
63
|
+
.amdgcn_target "amdgcn-amd-amdhsa--gfx942"
|
|
64
|
+
|
|
65
|
+
.text
|
|
66
|
+
.globl spilling_kernel
|
|
67
|
+
|
|
68
|
+
spilling_kernel:
|
|
69
|
+
s_load_dwordx4 s[0:3], s[4:5], 0x0
|
|
70
|
+
s_waitcnt lgkmcnt(0)
|
|
71
|
+
scratch_store_dwordx4 off, v[0:3], s0
|
|
72
|
+
scratch_store_dwordx4 off, v[4:7], s0
|
|
73
|
+
scratch_load_dwordx4 v[8:11], off, s0
|
|
74
|
+
v_add_f32 v0, v1, v2
|
|
75
|
+
s_waitcnt 0
|
|
76
|
+
s_endpgm
|
|
77
|
+
|
|
78
|
+
.amdhsa_kernel spilling_kernel
|
|
79
|
+
.amdhsa_next_free_vgpr 256
|
|
80
|
+
.amdhsa_next_free_sgpr 100
|
|
81
|
+
.amdhsa_group_segment_fixed_size 0
|
|
82
|
+
.amdhsa_private_segment_fixed_size 1024
|
|
83
|
+
.end_amdhsa_kernel
|
|
84
|
+
'''
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# ============================================================================
|
|
88
|
+
# Direct Function Tests
|
|
89
|
+
# ============================================================================
|
|
90
|
+
|
|
91
|
+
class TestAnalyzeCommandFunction:
    """Tests for analyze_command function directly."""

    def test_analyze_single_file(self, tmp_path: Path) -> None:
        """Should analyze a single ISA file."""
        isa_file = tmp_path / "kernel.s"
        isa_file.write_text(SAMPLE_ISA)

        output = analyze_command(str(isa_file))

        assert "test_kernel" in output
        assert "gfx90a" in output
        assert "VGPRs:" in output or "vgpr" in output.lower()

    def test_analyze_file_not_found(self, tmp_path: Path) -> None:
        """Should raise error for missing file."""
        with pytest.raises(FileNotFoundError):
            analyze_command(str(tmp_path / "missing.s"))

    def test_analyze_json_output(self, tmp_path: Path) -> None:
        """Should output valid JSON when json_output=True."""
        isa_file = tmp_path / "kernel.s"
        isa_file.write_text(SAMPLE_ISA)

        output = analyze_command(str(isa_file), json_output=True)
        data = json.loads(output)

        assert data["success"] is True
        assert data["isa_analysis"]["kernel_name"] == "test_kernel"
        assert data["isa_analysis"]["architecture"] == "gfx90a"
        assert data["isa_analysis"]["vgpr_count"] == 64
        assert data["isa_analysis"]["mfma_count"] == 2

    def test_analyze_csv_output(self, tmp_path: Path) -> None:
        """Should output CSV when csv_output=True."""
        isa_file = tmp_path / "kernel.s"
        isa_file.write_text(SAMPLE_ISA)

        output = analyze_command(str(isa_file), csv_output=True)

        # Should have header and data row
        lines = output.strip().split("\n")
        assert len(lines) == 2
        assert "kernel_name" in lines[0]
        assert "vgpr_count" in lines[0]
        assert "test_kernel" in lines[1]

    def test_analyze_directory(self, tmp_path: Path) -> None:
        """Should analyze all files in directory."""
        (tmp_path / "kernel1.s").write_text(SAMPLE_ISA)
        (tmp_path / "kernel2.s").write_text(SAMPLE_ISA_WITH_SPILLS)

        output = analyze_command(str(tmp_path))

        assert "Analyzed 2 files" in output

    def test_analyze_with_filter(self, tmp_path: Path) -> None:
        """Should filter results based on expression."""
        (tmp_path / "kernel1.s").write_text(SAMPLE_ISA)
        (tmp_path / "kernel2.s").write_text(SAMPLE_ISA_WITH_SPILLS)

        output = analyze_command(str(tmp_path), filter_expr="spills > 0")

        # Only the spilling kernel should survive the filter. Asserting that the
        # spilling kernel's name appears would also pass with no filtering at
        # all (both files are reported then), so check the analyzed-file count,
        # matching TestFiltering.test_filter_spills_greater_than_zero.
        assert "1 files" in output or "Analyzed 1" in output

    def test_analyze_output_to_file(self, tmp_path: Path) -> None:
        """Should write output to file."""
        isa_file = tmp_path / "kernel.s"
        isa_file.write_text(SAMPLE_ISA)
        output_file = tmp_path / "output.json"

        analyze_command(
            str(isa_file),
            json_output=True,
            output_file=str(output_file),
        )

        assert output_file.exists()
        data = json.loads(output_file.read_text())
        assert data["success"] is True
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class TestMetricsCommandFunction:
    """Checks the reference text returned by ``metrics_command``."""

    def test_lists_metrics(self) -> None:
        """Every core metric name appears in the listing."""
        listing = metrics_command()
        expected_metrics = (
            "vgpr_count",
            "sgpr_count",
            "spill_count",
            "mfma_count",
            "mfma_density_pct",
            "theoretical_occupancy",
        )
        for metric in expected_metrics:
            assert metric in listing, metric

    def test_includes_instruction_categories(self) -> None:
        """Instruction category labels are described in the listing."""
        listing = metrics_command()
        for category in ("VALU", "SALU", "VMEM", "MFMA", "LDS", "SPILL"):
            assert category in listing, category
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class TestTargetsCommandFunction:
    """Checks the supported-target listing returned by ``targets_command``."""

    def test_lists_targets(self) -> None:
        """Each supported gfx architecture is listed."""
        listing = targets_command()
        for arch in ("gfx90a", "gfx942", "gfx908"):
            assert arch in listing, arch

    def test_includes_specs(self) -> None:
        """Hardware specs (product family and VGPR budget) are shown."""
        listing = targets_command()
        assert any(family in listing for family in ("MI200", "MI300"))
        # "VGPRs" contains "VGPR", so one substring check covers both spellings.
        assert "VGPR" in listing
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# ============================================================================
|
|
220
|
+
# CLI Integration Tests
|
|
221
|
+
# ============================================================================
|
|
222
|
+
|
|
223
|
+
class TestKernelScopeCliCommands:
    """Drive ``wafer amd kernel-scope`` subcommands through the Typer app."""

    @staticmethod
    def _invoke(*args: str):
        """Run ``wafer amd kernel-scope <args...>`` with the shared CliRunner."""
        return runner.invoke(app, ["amd", "kernel-scope", *args])

    @staticmethod
    def _write_sample(tmp_path: Path) -> Path:
        """Write SAMPLE_ISA to ``kernel.s`` under *tmp_path* and return the path."""
        isa_path = tmp_path / "kernel.s"
        isa_path.write_text(SAMPLE_ISA)
        return isa_path

    def test_analyze_command_help(self) -> None:
        """Should display help for analyze command."""
        result = self._invoke("analyze", "--help")

        assert result.exit_code == 0
        help_text = strip_ansi(result.stdout)
        assert any(word in help_text for word in ("Analyze", "analyze"))

    def test_analyze_file_via_cli(self, tmp_path: Path) -> None:
        """Should analyze file via CLI."""
        result = self._invoke("analyze", str(self._write_sample(tmp_path)))

        assert result.exit_code == 0, f"Failed: {result.output}"
        assert "test_kernel" in result.stdout

    def test_analyze_json_via_cli(self, tmp_path: Path) -> None:
        """Should output JSON via CLI."""
        result = self._invoke("analyze", str(self._write_sample(tmp_path)), "--json")

        assert result.exit_code == 0, f"Failed: {result.output}"
        payload = json.loads(result.stdout)
        assert payload["success"] is True

    def test_analyze_csv_via_cli(self, tmp_path: Path) -> None:
        """Should output CSV via CLI."""
        result = self._invoke("analyze", str(self._write_sample(tmp_path)), "--csv")

        assert result.exit_code == 0, f"Failed: {result.output}"
        assert "kernel_name" in result.stdout

    def test_analyze_missing_file_via_cli(self, tmp_path: Path) -> None:
        """Should fail for missing file."""
        result = self._invoke("analyze", str(tmp_path / "missing.s"))

        assert result.exit_code != 0

    def test_metrics_via_cli(self) -> None:
        """Should list metrics via CLI."""
        result = self._invoke("metrics")

        assert result.exit_code == 0
        assert "vgpr_count" in result.stdout

    def test_targets_via_cli(self) -> None:
        """Should list targets via CLI."""
        result = self._invoke("targets")

        assert result.exit_code == 0
        assert any(arch in result.stdout for arch in ("gfx90a", "gfx942"))
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class TestKernelScopeCliHelp:
    """Tests for kernel-scope command help text."""

    def test_kernel_scope_help(self) -> None:
        """Should display help for kernel-scope command group."""
        result = runner.invoke(app, ["amd", "kernel-scope", "--help"])

        assert result.exit_code == 0
        # Lower-casing once covers both "analyze" and "Analyze".
        output = strip_ansi(result.stdout).lower()
        assert "analyze" in output
        assert "metrics" in output
        assert "targets" in output

    def test_amd_help_includes_kernel_scope(self) -> None:
        """AMD help should mention kernel-scope."""
        result = runner.invoke(app, ["amd", "--help"])

        assert result.exit_code == 0
        # Strip ANSI styling before matching, as the other help-text tests do;
        # escape codes inside the command name would break a plain substring check.
        assert "kernel-scope" in strip_ansi(result.stdout)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# ============================================================================
|
|
316
|
+
# Output Format Tests
|
|
317
|
+
# ============================================================================
|
|
318
|
+
|
|
319
|
+
class TestOutputFormats:
    """Checks the sections and fields present in each analyze output format."""

    @staticmethod
    def _analyze(tmp_path: Path, isa_text: str, **options) -> str:
        """Write *isa_text* to ``kernel.s`` in *tmp_path* and return the analysis."""
        isa_file = tmp_path / "kernel.s"
        isa_file.write_text(isa_text)
        return analyze_command(str(isa_file), **options)

    def test_text_output_spills_warning(self, tmp_path: Path) -> None:
        """Text output should show spills warning."""
        report = self._analyze(tmp_path, SAMPLE_ISA_WITH_SPILLS)
        # Lower-casing covers both the "SPILLS" and "spills" spellings.
        assert "spills" in report.lower()

    def test_text_output_registers_section(self, tmp_path: Path) -> None:
        """Text output should have registers section."""
        report = self._analyze(tmp_path, SAMPLE_ISA)
        assert "Registers" in report or "VGPRs" in report

    def test_text_output_memory_section(self, tmp_path: Path) -> None:
        """Text output should have memory section."""
        report = self._analyze(tmp_path, SAMPLE_ISA)
        assert "Memory" in report or "LDS" in report

    def test_text_output_instructions_section(self, tmp_path: Path) -> None:
        """Text output should have instructions section."""
        report = self._analyze(tmp_path, SAMPLE_ISA)
        assert "Instructions" in report or "MFMA" in report

    def test_text_output_occupancy_section(self, tmp_path: Path) -> None:
        """Text output should have occupancy section."""
        report = self._analyze(tmp_path, SAMPLE_ISA)
        assert "Occupancy" in report or "waves" in report.lower()

    def test_json_output_has_all_fields(self, tmp_path: Path) -> None:
        """JSON output should include all analysis fields."""
        payload = json.loads(self._analyze(tmp_path, SAMPLE_ISA, json_output=True))
        analysis = payload["isa_analysis"]

        required_fields = (
            "kernel_name",
            "architecture",
            "vgpr_count",
            "sgpr_count",
            "spill_count",
            "mfma_count",
            "mfma_density_pct",
            "instruction_mix",
            "theoretical_occupancy",
            "warnings",
        )
        for field_name in required_fields:
            assert field_name in analysis, field_name
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
# ============================================================================
|
|
390
|
+
# Filter Tests
|
|
391
|
+
# ============================================================================
|
|
392
|
+
|
|
393
|
+
class TestFiltering:
    """Tests for result filtering."""

    def test_filter_spills_greater_than_zero(self, tmp_path: Path) -> None:
        """Filter 'spills > 0' should only show files with spills."""
        (tmp_path / "no_spills.s").write_text(SAMPLE_ISA)
        (tmp_path / "has_spills.s").write_text(SAMPLE_ISA_WITH_SPILLS)

        output = analyze_command(str(tmp_path), filter_expr="spills > 0")

        # Should filter to only spilling kernel
        assert "1 files" in output or "Analyzed 1" in output

    def test_filter_vgpr_count(self, tmp_path: Path) -> None:
        """Filter on VGPR count should work."""
        (tmp_path / "small_vgpr.s").write_text(SAMPLE_ISA)  # 64 VGPRs
        (tmp_path / "large_vgpr.s").write_text(SAMPLE_ISA_WITH_SPILLS)  # 256 VGPRs

        output = analyze_command(str(tmp_path), filter_expr="vgpr_count > 128")

        # Should only show high-VGPR kernel
        assert "1 files" in output or "Analyzed 1" in output

    def test_filter_mfma_count(self, tmp_path: Path) -> None:
        """Filter on MFMA count should work."""
        (tmp_path / "has_mfma.s").write_text(SAMPLE_ISA)  # 2 MFMAs
        (tmp_path / "no_mfma.s").write_text(SAMPLE_ISA_WITH_SPILLS)  # 0 MFMAs

        output = analyze_command(str(tmp_path), filter_expr="mfma > 0")

        # Should show kernel with MFMA
        assert "1 files" in output or "has_mfma" in output

    def test_filter_invalid_expression(self, tmp_path: Path, capsys) -> None:
        """Invalid filter expression should warn."""
        (tmp_path / "kernel.s").write_text(SAMPLE_ISA)

        analyze_command(str(tmp_path), filter_expr="invalid filter")

        # The previous fallback `"2 files" not in output` was always true (the
        # directory holds a single file), so this test could never fail.
        # Require an actual warning. The function is expected to print it to
        # stderr; stdout is also checked in case it is echoed there.
        captured = capsys.readouterr()
        emitted = captured.err + captured.out
        assert "Invalid" in emitted or "Warning" in emitted
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
# ============================================================================
|
|
439
|
+
# Edge Cases
|
|
440
|
+
# ============================================================================
|
|
441
|
+
|
|
442
|
+
class TestEdgeCases:
    """Tests for edge cases and error handling."""

    def test_analyze_unsupported_file_type(self, tmp_path: Path) -> None:
        """Should handle unsupported file types gracefully."""
        txt_file = tmp_path / "file.xyz"
        txt_file.write_text("not ISA content")

        with pytest.raises(RuntimeError, match="Unsupported"):
            analyze_command(str(txt_file))

    def test_analyze_empty_directory(self, tmp_path: Path) -> None:
        """Should handle empty directories."""
        output = analyze_command(str(tmp_path))

        # The needle must be lower-case: the previous "No supported files"
        # could never occur in output.lower(), so that branch was dead.
        assert "0 files" in output or "no supported files" in output.lower()

    def test_analyze_directory_with_subdirs(self, tmp_path: Path) -> None:
        """Should scan subdirectories when recursive."""
        subdir = tmp_path / "subdir"
        subdir.mkdir()
        (subdir / "kernel.s").write_text(SAMPLE_ISA)

        output = analyze_command(str(tmp_path), recursive=True)

        assert "1 files" in output or "test_kernel" in output
|
|
@@ -345,3 +345,88 @@ def browser_login(timeout: int = 120, port: int | None = None) -> tuple[str, str
|
|
|
345
345
|
|
|
346
346
|
server.server_close()
|
|
347
347
|
raise TimeoutError(f"No response within {timeout} seconds")
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def device_code_login(timeout: int = 600) -> tuple[str, str | None]:
    """Authenticate using state-based flow (no browser/port forwarding needed).

    This is the SSH-friendly auth flow similar to GitHub CLI:
    1. Request a state token from the API
    2. Display the auth URL with state parameter
    3. User visits URL on any device and signs in normally
    4. Poll API until user completes authentication

    Args:
        timeout: Seconds to wait for authentication (default 600 = 10 minutes).
            The effective wait is the shorter of this value and the
            ``expires_in`` lifetime the server returns for the state token.

    Returns:
        Tuple of (access_token, refresh_token). refresh_token may be None.

    Raises:
        TimeoutError: If user doesn't authenticate within the effective wait.
        RuntimeError: If the token endpoint answers with a status other than
            200 (authenticated) or 428 (still pending).
        httpx.HTTPStatusError: If the initial ``cli-auth/start`` request fails.
    """
    api_url = get_api_url()

    # Request state and auth URL
    with httpx.Client(timeout=10.0) as client:
        response = client.post(f"{api_url}/v1/auth/cli-auth/start", json={})
        response.raise_for_status()
        data = response.json()

    state = data["state"]
    auth_url = data["auth_url"]
    expires_in = data["expires_in"]

    # Stop at whichever comes first: the caller's timeout or the state expiry.
    # This same value is reported in the TimeoutError below.
    wait_limit = min(timeout, expires_in)

    # Display instructions to user
    print("\n" + "=" * 60)
    print(" WAFER CLI - Authentication")
    print("=" * 60)
    print(f"\n Visit: {auth_url}")
    print("\n Sign in with GitHub to complete authentication")
    print("\n" + "=" * 60 + "\n")

    # Poll for authentication. Use a monotonic clock so wall-clock adjustments
    # cannot shorten or extend the wait.
    poll_interval = 5  # Poll every 5 seconds
    start = time.monotonic()
    last_poll = start - poll_interval  # Ensures the first poll happens immediately

    print("Waiting for authentication", end="", flush=True)

    # One client for the whole polling loop so connections are reused.
    with httpx.Client(timeout=10.0) as client:
        while time.monotonic() - start < wait_limit:
            if time.monotonic() - last_poll < poll_interval:
                time.sleep(0.5)  # Small sleep to avoid busy loop
                continue

            # Show progress dots
            print(".", end="", flush=True)

            try:
                response = client.post(
                    f"{api_url}/v1/auth/cli-auth/token", json={"state": state}
                )
            except httpx.RequestError:
                # Network error, retry on the next poll interval
                print("!", end="", flush=True)
                last_poll = time.monotonic()
                time.sleep(1)
                continue

            if response.status_code == 200:
                # Success!
                token_data = response.json()
                print(f" {CHECK}\n")
                return token_data["access_token"], token_data.get("refresh_token")

            if response.status_code == 428:
                # Still waiting for the user to finish signing in
                last_poll = time.monotonic()
                time.sleep(1)
                continue

            # Some other error
            print(f" {CROSS}\n")
            raise RuntimeError(f"CLI auth flow failed: {response.status_code} {response.text}")

    print(f" {CROSS}\n")
    raise TimeoutError(f"Authentication not completed within {wait_limit} seconds")
|