wafer-cli 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_cli-0.1.0/PKG-INFO +9 -0
- wafer_cli-0.1.0/README.md +221 -0
- wafer_cli-0.1.0/pyproject.toml +73 -0
- wafer_cli-0.1.0/setup.cfg +4 -0
- wafer_cli-0.1.0/tests/test_compiler_analyze_integration.py +320 -0
- wafer_cli-0.1.0/tests/test_config_integration.py +50 -0
- wafer_cli-0.1.0/tests/test_file_operations_integration.py +193 -0
- wafer_cli-0.1.0/tests/test_ssh_integration.py +134 -0
- wafer_cli-0.1.0/tests/test_workflow_integration.py +147 -0
- wafer_cli-0.1.0/wafer/__init__.py +3 -0
- wafer_cli-0.1.0/wafer/api_client.py +201 -0
- wafer_cli-0.1.0/wafer/auth.py +254 -0
- wafer_cli-0.1.0/wafer/cli.py +1536 -0
- wafer_cli-0.1.0/wafer/compiler_analyze.py +63 -0
- wafer_cli-0.1.0/wafer/config.py +105 -0
- wafer_cli-0.1.0/wafer/evaluate.py +911 -0
- wafer_cli-0.1.0/wafer/gpu_run.py +303 -0
- wafer_cli-0.1.0/wafer/inference.py +148 -0
- wafer_cli-0.1.0/wafer/ncu_analyze.py +571 -0
- wafer_cli-0.1.0/wafer/targets.py +296 -0
- wafer_cli-0.1.0/wafer/wevin_cli.py +897 -0
- wafer_cli-0.1.0/wafer_cli.egg-info/PKG-INFO +9 -0
- wafer_cli-0.1.0/wafer_cli.egg-info/SOURCES.txt +25 -0
- wafer_cli-0.1.0/wafer_cli.egg-info/dependency_links.txt +1 -0
- wafer_cli-0.1.0/wafer_cli.egg-info/entry_points.txt +2 -0
- wafer_cli-0.1.0/wafer_cli.egg-info/requires.txt +4 -0
- wafer_cli-0.1.0/wafer_cli.egg-info/top_level.txt +2 -0
wafer_cli-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: wafer-cli
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: CLI tool for running commands on remote GPUs and GPU kernel optimization agent
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Requires-Dist: typer>=0.12.0
|
|
7
|
+
Requires-Dist: trio>=0.24.0
|
|
8
|
+
Requires-Dist: trio-asyncio>=0.15.0
|
|
9
|
+
Requires-Dist: wafer-core>=0.1.0
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
# Wafer CLI
|
|
2
|
+
|
|
3
|
+
Run commands on remote GPUs in Docker containers.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
cd apps/wafer-cli
|
|
9
|
+
uv sync
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
## Quick Start
|
|
13
|
+
|
|
14
|
+
```bash
|
|
15
|
+
# Login with GitHub
|
|
16
|
+
wafer login
|
|
17
|
+
|
|
18
|
+
# Run a command on remote GPU
|
|
19
|
+
wafer remote-run -- nvidia-smi
|
|
20
|
+
|
|
21
|
+
# Run Python script with file upload
|
|
22
|
+
wafer remote-run --upload-dir ./my_project -- python3 train.py
|
|
23
|
+
|
|
24
|
+
# Check who you're logged in as
|
|
25
|
+
wafer whoami
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
## Commands
|
|
29
|
+
|
|
30
|
+
### `wafer login`
|
|
31
|
+
|
|
32
|
+
Authenticate with GitHub OAuth. Opens browser for login flow.
|
|
33
|
+
|
|
34
|
+
```bash
|
|
35
|
+
wafer login
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
Credentials are stored in `~/.wafer/credentials.json`.
|
|
39
|
+
|
|
40
|
+
### `wafer remote-run`
|
|
41
|
+
|
|
42
|
+
Run any command on a remote GPU inside a Docker container.
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
# Basic: run a command
|
|
46
|
+
wafer remote-run -- nvidia-smi
|
|
47
|
+
|
|
48
|
+
# Upload files first, then run
|
|
49
|
+
wafer remote-run --upload-dir ./my_project -- python3 main.py
|
|
50
|
+
|
|
51
|
+
# Custom Docker image
|
|
52
|
+
wafer remote-run --image pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel -- python3 -c "import torch; print(torch.cuda.get_device_name())"
|
|
53
|
+
|
|
54
|
+
# With custom entrypoint (for images with non-shell defaults)
|
|
55
|
+
wafer remote-run --image vllm/vllm-openai:latest --docker-entrypoint bash -- python3 script.py
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
Options:
|
|
59
|
+
- `--upload-dir <path>` - Upload local directory before running command
|
|
60
|
+
- `--image <image>` - Docker image to use (default: vllm/vllm-openai:latest)
|
|
61
|
+
- `--docker-entrypoint <cmd>` - Override container entrypoint
|
|
62
|
+
- `--pull-image` - Pull image even if it exists locally
|
|
63
|
+
- `--require-hwc` - Require hardware counters (for NCU profiling)
|
|
64
|
+
|
|
65
|
+
### `wafer push`
|
|
66
|
+
|
|
67
|
+
Upload files to a remote workspace (for multi-command workflows).
|
|
68
|
+
|
|
69
|
+
```bash
|
|
70
|
+
# Push files
|
|
71
|
+
wafer push ./my_project
|
|
72
|
+
|
|
73
|
+
# Returns workspace_id for use with remote-run
|
|
74
|
+
# Example: ws_abc123
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
### `wafer logout`
|
|
78
|
+
|
|
79
|
+
Remove stored credentials.
|
|
80
|
+
|
|
81
|
+
```bash
|
|
82
|
+
wafer logout
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
### `wafer whoami`
|
|
86
|
+
|
|
87
|
+
Show current authenticated user.
|
|
88
|
+
|
|
89
|
+
```bash
|
|
90
|
+
wafer whoami
|
|
91
|
+
# Output: Logged in as user@example.com
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
## Examples
|
|
95
|
+
|
|
96
|
+
### Run PyTorch training
|
|
97
|
+
|
|
98
|
+
```bash
|
|
99
|
+
wafer remote-run --upload-dir ./training -- python3 train.py --epochs 10
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
### Profile with NCU
|
|
103
|
+
|
|
104
|
+
```bash
|
|
105
|
+
wafer remote-run --require-hwc --upload-dir ./kernel -- ncu --set full python3 benchmark.py
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
### Interactive debugging
|
|
109
|
+
|
|
110
|
+
```bash
|
|
111
|
+
# Upload once
|
|
112
|
+
WORKSPACE=$(wafer push ./project)
|
|
113
|
+
|
|
114
|
+
# Run multiple commands against same workspace
|
|
115
|
+
wafer remote-run --workspace-id $WORKSPACE -- python3 test1.py
|
|
116
|
+
wafer remote-run --workspace-id $WORKSPACE -- python3 test2.py
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
### Custom CUDA image
|
|
120
|
+
|
|
121
|
+
```bash
|
|
122
|
+
wafer remote-run \
|
|
123
|
+
--image nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04 \
|
|
124
|
+
--upload-dir ./cuda_kernels \
|
|
125
|
+
-- bash -c "nvcc -o kernel kernel.cu && ./kernel"
|
|
126
|
+
```
|
|
127
|
+
|
|
128
|
+
---
|
|
129
|
+
|
|
130
|
+
## Architecture
|
|
131
|
+
|
|
132
|
+
```
|
|
133
|
+
CLI wafer-api GPU
|
|
134
|
+
| | |
|
|
135
|
+
|-- POST /v1/gpu/jobs ------------->| |
|
|
136
|
+
| { command, files[], image } | |
|
|
137
|
+
| |-- SSH upload + docker ------->|
|
|
138
|
+
|<-- SSE stream: stdout/stderr -----|<-- stream output -------------|
|
|
139
|
+
| | |
|
|
140
|
+
```
|
|
141
|
+
|
|
142
|
+
### Components
|
|
143
|
+
|
|
144
|
+
- **wafer-cli** (`apps/wafer-cli/`) - Thin client that calls wafer-api
|
|
145
|
+
- **wafer-api** (`services/wafer-api/`) - Backend that owns GPU targets and SSH credentials
|
|
146
|
+
- **wafer-core** (`packages/wafer-core/`) - Internal SSH client for file upload and command execution
|
|
147
|
+
|
|
148
|
+
### Why API-backed?
|
|
149
|
+
|
|
150
|
+
- **Security**: SSH credentials stay on server, not in client config
|
|
151
|
+
- **Routing**: Backend picks best available GPU based on requirements
|
|
152
|
+
- **Multi-tenant**: Multiple users share GPU pool without credential management
|
|
153
|
+
|
|
154
|
+
## Local Target Management (Advanced)
|
|
155
|
+
|
|
156
|
+
For direct SSH access (bypassing wafer-api), you can configure local targets:
|
|
157
|
+
|
|
158
|
+
```bash
|
|
159
|
+
wafer targets list
|
|
160
|
+
wafer targets add examples/targets/my-gpu.toml
|
|
161
|
+
wafer targets default my-gpu
|
|
162
|
+
```
|
|
163
|
+
|
|
164
|
+
See `examples/targets/` for TOML format.
|
|
165
|
+
|
|
166
|
+
## Testing
|
|
167
|
+
|
|
168
|
+
### Prerequisites
|
|
169
|
+
|
|
170
|
+
1. Start the wafer-api server:
|
|
171
|
+
```bash
|
|
172
|
+
cd services/wafer-api && uv run uvicorn src.main:app --port 8000
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
2. Authenticate (if not already logged in):
|
|
176
|
+
```bash
|
|
177
|
+
cd apps/wafer-cli && uv run wafer login
|
|
178
|
+
```
|
|
179
|
+
|
|
180
|
+
### Run Integration Tests
|
|
181
|
+
|
|
182
|
+
```bash
|
|
183
|
+
# GPU API tests (push, jobs, cancellation)
|
|
184
|
+
python scripts/test_gpu_api.py --no-server
|
|
185
|
+
|
|
186
|
+
# CLI remote-run tests (basic, upload-dir, nested)
|
|
187
|
+
python scripts/test_cli_remote_run.py --no-server
|
|
188
|
+
```
|
|
189
|
+
|
|
190
|
+
### Manual Command Tests
|
|
191
|
+
|
|
192
|
+
```bash
|
|
193
|
+
cd apps/wafer-cli
|
|
194
|
+
|
|
195
|
+
# wafer remote-run (via API)
|
|
196
|
+
uv run wafer remote-run -- nvidia-smi
|
|
197
|
+
|
|
198
|
+
# wafer remote-run with file upload
|
|
199
|
+
uv run wafer remote-run --upload-dir ./my_project -- python script.py
|
|
200
|
+
|
|
201
|
+
# wafer ncu-analyze (via API)
|
|
202
|
+
uv run wafer ncu-analyze path/to/profile.ncu-rep --remote
|
|
203
|
+
|
|
204
|
+
# wafer ask-docs (requires docs-tool running on port 8002)
|
|
205
|
+
WAFER_DOCS_URL=http://localhost:8002 uv run wafer ask-docs "What is a Triton kernel?"
|
|
206
|
+
|
|
207
|
+
# wafer wevin with --tools and --json flags
|
|
208
|
+
uv run wafer wevin --ref kernel.py --desc "Optimize" --test "n=128" --tools read,write,edit --json --no-tui --max-turns 1
|
|
209
|
+
```
|
|
210
|
+
|
|
211
|
+
### Expected Results
|
|
212
|
+
|
|
213
|
+
- All integration tests should pass
|
|
214
|
+
- `wafer remote-run` executes commands on remote GPU via API
|
|
215
|
+
- `wafer ncu-analyze --remote` uploads profile and returns analysis
|
|
216
|
+
- All SSH operations use internal `wafer_core.ssh.SSHClient`
|
|
217
|
+
|
|
218
|
+
## Requirements
|
|
219
|
+
|
|
220
|
+
- Python 3.11+
|
|
221
|
+
- GitHub account (for authentication)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "wafer-cli"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "CLI tool for running commands on remote GPUs and GPU kernel optimization agent"
|
|
5
|
+
requires-python = ">=3.11"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"typer>=0.12.0",
|
|
8
|
+
"trio>=0.24.0",
|
|
9
|
+
"trio-asyncio>=0.15.0", # Bridge asyncssh (asyncio) to trio for async SSH
|
|
10
|
+
# Wafer core for environments and utils (includes rollouts)
|
|
11
|
+
"wafer-core>=0.1.0",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[project.scripts]
|
|
15
|
+
wafer = "wafer.cli:main"
|
|
16
|
+
|
|
17
|
+
[tool.uv.sources]
|
|
18
|
+
wafer-core = { workspace = true }
|
|
19
|
+
|
|
20
|
+
[build-system]
|
|
21
|
+
requires = ["setuptools>=61.0"]
|
|
22
|
+
build-backend = "setuptools.build_meta"
|
|
23
|
+
|
|
24
|
+
[tool.setuptools.packages.find]
|
|
25
|
+
where = ["."]
|
|
26
|
+
include = ["wafer*"]
|
|
27
|
+
|
|
28
|
+
[tool.ruff]
|
|
29
|
+
line-length = 100
|
|
30
|
+
target-version = "py311"
|
|
31
|
+
preview = true # Required for PLR1702 (too-many-nested-blocks)
|
|
32
|
+
|
|
33
|
+
[tool.ruff.lint]
|
|
34
|
+
select = [
|
|
35
|
+
"E", # pycodestyle errors
|
|
36
|
+
"F", # pyflakes
|
|
37
|
+
"I", # isort (import sorting)
|
|
38
|
+
"ANN", # flake8-annotations (enforce type annotations)
|
|
39
|
+
"ASYNC", # flake8-async (trio/asyncio best practices)
|
|
40
|
+
"B", # flake8-bugbear (common bugs)
|
|
41
|
+
"UP", # pyupgrade (modern Python patterns)
|
|
42
|
+
"PLR0913", # too-many-arguments (Tiger Style: "hourglass shape: few parameters")
|
|
43
|
+
"PLR0915", # too-many-statements (Tiger Style: 70 line limit)
|
|
44
|
+
"PLR1702", # too-many-nested-blocks (Tiger Style: "centralize control flow")
|
|
45
|
+
"PLW2901", # redefined-loop-variable (Carmack SSA: single assignment)
|
|
46
|
+
"RET506", # superfluous-else-raise (explicit control flow)
|
|
47
|
+
"RET507", # superfluous-else-continue (explicit control flow)
|
|
48
|
+
"A", # flake8-builtins (shadowing builtins like list, str)
|
|
49
|
+
"RUF018", # assignment-in-assert (catches typos)
|
|
50
|
+
"TRY002", # raise-vanilla-exception (use specific exceptions)
|
|
51
|
+
"TRY003", # raise-vanilla-args (proper exception messages)
|
|
52
|
+
"TRY004", # type-check-without-type-error (use TypeError for type checks)
|
|
53
|
+
"TRY201", # verbose-raise (use bare raise)
|
|
54
|
+
"TRY300", # try-consider-else (clear control flow)
|
|
55
|
+
"TRY400", # error-instead-of-exception (use logging.exception)
|
|
56
|
+
]
|
|
57
|
+
ignore = [
|
|
58
|
+
"E501", # Line too long (handled by formatter)
|
|
59
|
+
"B008", # Typer uses function calls in defaults
|
|
60
|
+
"TRY003", # Avoid specifying long messages outside exception class (too opinionated - revisit later)
|
|
61
|
+
"TRY300", # Consider moving to else block (too opinionated about try/except style - revisit later)
|
|
62
|
+
"TRY400", # Use logging.exception instead of logging.error (we prefer explicit error logging - revisit later)
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
[tool.ruff.lint.per-file-ignores]
|
|
66
|
+
"tests/**/*.py" = ["ANN201"] # Don't require return type annotations in tests
|
|
67
|
+
"wafer/evaluate.py" = ["PLR0915"] # run_evaluate_docker/ssh have complex deployment flows (80/77 statements) - TODO: refactor
|
|
68
|
+
|
|
69
|
+
[tool.ruff.lint.pylint]
|
|
70
|
+
max-args = 7 # Max function arguments (Tiger Style: few parameters)
|
|
71
|
+
max-statements = 70 # Max statements per function (Tiger Style: 70 line limit)
|
|
72
|
+
max-branches = 12 # Max if/elif branches (not enforced: PLR0912 is not in lint.select)
|
|
73
|
+
max-nested-blocks = 5 # Max nesting depth (Linus rule: "if you need more than 3 levels, you're screwed")
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
"""Integration tests for wafer compiler-analyze CLI command.
|
|
2
|
+
|
|
3
|
+
Tests the compiler-analyze command end-to-end, verifying:
|
|
4
|
+
- CLI command execution
|
|
5
|
+
- JSON output format
|
|
6
|
+
- Error handling
|
|
7
|
+
- Various input combinations
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import subprocess
|
|
12
|
+
|
|
13
|
+
import pytest
|
|
14
|
+
|
|
15
|
+
# Constants

# How the CLI under test is invoked.
CLI_COMMAND = "wafer"
COMPILER_ANALYZE_SUBCOMMAND = "compiler-analyze"
CLI_TIMEOUT_SECONDS = 30  # timeout passed to every subprocess.run call below

# Command-line flags accepted by `wafer compiler-analyze`.
JSON_FLAG = "--json"
MLIR_FLAG = "--mlir-text"
PTX_FLAG = "--ptx-text"
SASS_FLAG = "--sass-text"
SOURCE_FLAG = "--source-text"
KERNEL_NAME_FLAG = "--kernel-name"

# Top-level keys of the JSON object the tests read from stdout.
SUCCESS_KEY = "success"
DATA_KEY = "data"
ERROR_KEY = "error"

# Keys the tests look for inside the "data" object: echoed inputs...
KERNEL_NAME_KEY = "kernel_name"
MLIR_TEXT_KEY = "mlir_text"
PTX_TEXT_KEY = "ptx_text"
SASS_TEXT_KEY = "sass_text"
SOURCE_CODE_KEY = "source_code"
# ...and analysis results.
PARSED_MLIR_KEY = "parsed_mlir"
PARSED_PTX_KEY = "parsed_ptx"
PARSED_SASS_KEY = "parsed_sass"
LAYOUTS_KEY = "layouts"
MEMORY_PATHS_KEY = "memory_paths"
PIPELINE_STAGES_KEY = "pipeline_stages"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_compiler_analyze_basic() -> None:
    """Test basic compiler-analyze with minimal inputs.

    Runs the CLI with MLIR, PTX and SASS text plus ``--json`` and checks that
    it exits 0 and prints a JSON object with ``success`` true and a ``data``
    object containing the kernel name and the three input texts.
    """
    mlir_text = "module { func @kernel() { } }"
    ptx_text = ".version 8.0\n.target sm_80\n.entry kernel() { ret; }"
    sass_text = "// SASS code"

    args = [
        CLI_COMMAND,
        COMPILER_ANALYZE_SUBCOMMAND,
        MLIR_FLAG, mlir_text,
        PTX_FLAG, ptx_text,
        SASS_FLAG, sass_text,
        JSON_FLAG,
    ]

    result = subprocess.run(
        args,
        capture_output=True,
        text=True,
        timeout=CLI_TIMEOUT_SECONDS,
    )

    # With capture_output=True and text=True stderr is always a str, so a
    # "stderr is not None" check proves nothing; attach stderr to the failure instead.
    assert result.returncode == 0, result.stderr

    output_json = json.loads(result.stdout)
    assert isinstance(output_json, dict)
    assert output_json.get(SUCCESS_KEY) is True
    assert DATA_KEY in output_json

    data = output_json[DATA_KEY]
    assert isinstance(data, dict)
    assert KERNEL_NAME_KEY in data
    assert MLIR_TEXT_KEY in data
    assert PTX_TEXT_KEY in data
    assert SASS_TEXT_KEY in data
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_compiler_analyze_with_source_code() -> None:
    """Test compiler-analyze with optional source code.

    Passes ``--source-text`` alongside MLIR/PTX/SASS and checks that the
    ``data.source_code`` field in the JSON output equals the supplied source.
    """
    mlir_text = "module { func @test_kernel() { } }"
    ptx_text = ".version 8.0\n.target sm_80"
    sass_text = "// SASS"
    source_code = "__global__ void test_kernel() { }"

    args = [
        CLI_COMMAND,
        COMPILER_ANALYZE_SUBCOMMAND,
        MLIR_FLAG, mlir_text,
        PTX_FLAG, ptx_text,
        SASS_FLAG, sass_text,
        SOURCE_FLAG, source_code,
        JSON_FLAG,
    ]

    result = subprocess.run(
        args,
        capture_output=True,
        text=True,
        timeout=CLI_TIMEOUT_SECONDS,
    )

    # stderr is always a str when captured as text; report it on failure
    # rather than asserting it is not None.
    assert result.returncode == 0, result.stderr

    output_json = json.loads(result.stdout)
    assert isinstance(output_json, dict)
    assert output_json.get(SUCCESS_KEY) is True
    assert DATA_KEY in output_json

    data = output_json[DATA_KEY]
    assert isinstance(data, dict)
    assert data.get(SOURCE_CODE_KEY) == source_code
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_compiler_analyze_with_kernel_name() -> None:
    """Test compiler-analyze with explicit kernel name.

    The MLIR declares ``my_kernel`` but ``--kernel-name`` passes a different
    name; the test checks that ``data.kernel_name`` equals the explicit value.
    """
    mlir_text = "module { func @my_kernel() { } }"
    ptx_text = ".version 8.0\n.target sm_80"
    sass_text = "// SASS"
    kernel_name = "my_custom_kernel"

    args = [
        CLI_COMMAND,
        COMPILER_ANALYZE_SUBCOMMAND,
        MLIR_FLAG, mlir_text,
        PTX_FLAG, ptx_text,
        SASS_FLAG, sass_text,
        KERNEL_NAME_FLAG, kernel_name,
        JSON_FLAG,
    ]

    result = subprocess.run(
        args,
        capture_output=True,
        text=True,
        timeout=CLI_TIMEOUT_SECONDS,
    )

    # stderr is always a str when captured as text; report it on failure
    # rather than asserting it is not None.
    assert result.returncode == 0, result.stderr

    output_json = json.loads(result.stdout)
    assert isinstance(output_json, dict)
    assert output_json.get(SUCCESS_KEY) is True
    assert DATA_KEY in output_json

    data = output_json[DATA_KEY]
    assert isinstance(data, dict)
    assert data.get(KERNEL_NAME_KEY) == kernel_name
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def test_compiler_analyze_missing_required_args() -> None:
    """Test compiler-analyze fails gracefully when required args are missing.

    Accepted outcomes: either stdout carries a JSON object whose ``success``
    is false or which has an ``error`` key, or stdout is blank and the
    failure is signalled by a non-zero exit code or text on stderr.
    """
    proc = subprocess.run(
        [CLI_COMMAND, COMPILER_ANALYZE_SUBCOMMAND, JSON_FLAG],
        capture_output=True,
        text=True,
        timeout=CLI_TIMEOUT_SECONDS,
    )

    if not proc.stdout.strip():
        # Nothing on stdout: the error must surface some other way.
        assert proc.returncode != 0 or len(proc.stderr) > 0
        return

    payload = json.loads(proc.stdout)
    assert isinstance(payload, dict)
    assert payload.get(SUCCESS_KEY) is False or ERROR_KEY in payload
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_compiler_analyze_partial_inputs() -> None:
    """Test compiler-analyze with only some inputs provided.

    Passes an empty ``--sass-text`` value and checks that the command exits 0
    and prints a JSON object containing a ``success`` or an ``error`` key.
    """
    mlir_text = "module { func @kernel() { } }"
    ptx_text = ".version 8.0\n.target sm_80"
    empty_sass = ""

    args = [
        CLI_COMMAND,
        COMPILER_ANALYZE_SUBCOMMAND,
        MLIR_FLAG, mlir_text,
        PTX_FLAG, ptx_text,
        SASS_FLAG, empty_sass,
        JSON_FLAG,
    ]

    result = subprocess.run(
        args,
        capture_output=True,
        text=True,
        timeout=CLI_TIMEOUT_SECONDS,
    )

    # stderr is always a str when captured as text; report it on failure
    # rather than asserting it is not None.
    assert result.returncode == 0, result.stderr

    output_json = json.loads(result.stdout)
    assert isinstance(output_json, dict)
    assert SUCCESS_KEY in output_json or ERROR_KEY in output_json
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def test_compiler_analyze_output_structure() -> None:
    """Test compiler-analyze output has expected structure.

    Checks that the ``data`` object in the JSON output contains every echoed
    input field and every analysis field (parsed IRs, layouts, memory paths,
    pipeline stages). All missing fields are reported together.
    """
    mlir_text = "module { func @kernel() { } }"
    ptx_text = ".version 8.0\n.target sm_80"
    sass_text = "// SASS"

    args = [
        CLI_COMMAND,
        COMPILER_ANALYZE_SUBCOMMAND,
        MLIR_FLAG, mlir_text,
        PTX_FLAG, ptx_text,
        SASS_FLAG, sass_text,
        JSON_FLAG,
    ]

    result = subprocess.run(
        args,
        capture_output=True,
        text=True,
        timeout=CLI_TIMEOUT_SECONDS,
    )

    # stdout/stderr are always str when captured as text, so "is not None"
    # checks are vacuous; surface stderr on failure instead.
    assert result.returncode == 0, result.stderr

    output_json = json.loads(result.stdout)
    assert isinstance(output_json, dict)

    data = output_json.get(DATA_KEY, {})
    assert isinstance(data, dict)

    expected_fields = [
        KERNEL_NAME_KEY,
        MLIR_TEXT_KEY,
        PTX_TEXT_KEY,
        SASS_TEXT_KEY,
        PARSED_MLIR_KEY,
        PARSED_PTX_KEY,
        PARSED_SASS_KEY,
        LAYOUTS_KEY,
        MEMORY_PATHS_KEY,
        PIPELINE_STAGES_KEY,
    ]

    # Collect every missing field so one failing run shows the whole gap.
    missing_fields = [field for field in expected_fields if field not in data]
    assert not missing_fields, f"missing fields in {DATA_KEY!r}: {missing_fields}"
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def test_compiler_analyze_error_handling() -> None:
    """Test compiler-analyze handles invalid inputs gracefully.

    Feeds malformed MLIR text and checks that the CLI still prints a JSON
    object on stdout containing a ``success`` or an ``error`` key. The exit
    code is deliberately not checked.
    """
    # Deliberately malformed MLIR: not a valid module.
    invalid_mlir = "{ invalid json }"
    ptx_text = ".version 8.0"
    sass_text = "// SASS"

    args = [
        CLI_COMMAND,
        COMPILER_ANALYZE_SUBCOMMAND,
        MLIR_FLAG, invalid_mlir,
        PTX_FLAG, ptx_text,
        SASS_FLAG, sass_text,
        JSON_FLAG,
    ]

    result = subprocess.run(
        args,
        capture_output=True,
        text=True,
        timeout=CLI_TIMEOUT_SECONDS,
    )

    # stdout is always a str when captured as text; require actual content so
    # an empty stdout fails with stderr attached instead of a bare JSONDecodeError.
    assert result.stdout.strip(), f"no output on stdout; stderr: {result.stderr}"

    output_json = json.loads(result.stdout)
    assert isinstance(output_json, dict)
    assert SUCCESS_KEY in output_json or ERROR_KEY in output_json
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def test_compiler_analyze_no_json_flag() -> None:
    """Test compiler-analyze without --json flag (should still work).

    Runs the command without ``--json`` and checks that it exits 0 and that
    stdout parses as a JSON object.

    NOTE(review): this expects JSON on stdout even without ``--json``; confirm
    that this is the CLI's default output format and not an accident.
    """
    mlir_text = "module { func @kernel() { } }"
    ptx_text = ".version 8.0\n.target sm_80"
    sass_text = "// SASS"

    args = [
        CLI_COMMAND,
        COMPILER_ANALYZE_SUBCOMMAND,
        MLIR_FLAG, mlir_text,
        PTX_FLAG, ptx_text,
        SASS_FLAG, sass_text,
    ]

    result = subprocess.run(
        args,
        capture_output=True,
        text=True,
        timeout=CLI_TIMEOUT_SECONDS,
    )

    # stdout is always a str when captured as text, so "is not None" is
    # vacuous; report stderr if the command fails.
    assert result.returncode == 0, result.stderr

    output_json = json.loads(result.stdout)
    assert isinstance(output_json, dict)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Integration tests for config loading and environment resolution.
|
|
2
|
+
|
|
3
|
+
Tests WaferConfig loading and environment resolution.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
|
|
10
|
+
from wafer.config import WaferConfig
|
|
11
|
+
from wafer.inference import resolve_environment
|
|
12
|
+
|
|
13
|
+
# Constants
# The user config is read from ~/<CONFIG_DIR>/<CONFIG_FILENAME>.
CONFIG_DIR = ".wafer"
CONFIG_FILENAME = "config.toml"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_config_and_environment_resolution() -> None:
    """Test config loading and environment resolution.

    Loads ``~/.wafer/config.toml`` (skipping if absent) and checks the
    following:

    - ``target``, ``ssh_key`` and ``environments`` are all populated.
    - Resolving with no name yields an environment with a docker image and a name.
    - Resolving the first configured environment by name returns that environment.
    """
    home = Path.home()
    assert home.exists()

    cfg_file = home / CONFIG_DIR / CONFIG_FILENAME
    if not cfg_file.exists():
        pytest.skip(f"Config not found: {cfg_file}")

    cfg = WaferConfig.from_toml(cfg_file)
    assert cfg is not None

    # Connection settings must be present and non-empty.
    for setting in (cfg.target, cfg.ssh_key):
        assert setting is not None
        assert len(setting) > 0
    assert len(cfg.environments) > 0

    # Passing None asks resolve_environment for the default environment.
    fallback = resolve_environment(cfg, None)
    assert fallback is not None
    for field_value in (fallback.docker, fallback.name):
        assert field_value is not None
        assert len(field_value) > 0

    names = list(cfg.environments.keys())
    assert len(names) > 0

    wanted = names[0]
    chosen = resolve_environment(cfg, wanted)
    assert chosen is not None
    assert chosen.name == wanted
|