wave-gpu 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,46 @@
1
+ Metadata-Version: 2.4
2
+ Name: wave-gpu
3
+ Version: 0.1.0
4
+ Summary: Write GPU kernels in Python, run on any GPU
5
+ Author: Ojima Abraham
6
+ License: Apache-2.0
7
+ Project-URL: Homepage, https://wave.ojima.me
8
+ Project-URL: Repository, https://github.com/Oabraham1/wave
9
+ Keywords: gpu,cuda,metal,hip,sycl,compute
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: License :: OSI Approved :: Apache Software License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Requires-Python: >=3.9
14
+ Description-Content-Type: text/markdown
15
+
16
+ # WAVE Python SDK
17
+
18
+ Write GPU kernels in Python, run on any GPU.
19
+
20
+ ## Install
21
+
22
+ ```bash
23
+ pip install wave-gpu
24
+ ```
25
+
26
+ ## Usage
27
+
28
+ ```python
29
+ import wave_gpu
30
+
31
+ @wave_gpu.kernel
32
+ def vector_add(a: wave_gpu.f32[:], b: wave_gpu.f32[:], out: wave_gpu.f32[:], n: wave_gpu.u32):
33
+ gid = wave_gpu.thread_id()
34
+ if gid < n:
35
+ out[gid] = a[gid] + b[gid]
36
+
37
+ a = wave_gpu.array([1.0, 2.0, 3.0, 4.0])
38
+ b = wave_gpu.array([5.0, 6.0, 7.0, 8.0])
39
+ out = wave_gpu.zeros(4)
40
+ vector_add(a, b, out, len(a))
41
+ print(out.to_list()) # [6.0, 8.0, 10.0, 12.0]
42
+ ```
43
+
44
+ ## License
45
+
46
+ Apache 2.0 - see [LICENSE](../../LICENSE)
@@ -0,0 +1,31 @@
1
+ # WAVE Python SDK
2
+
3
+ Write GPU kernels in Python, run on any GPU.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install wave-gpu
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ import wave_gpu
15
+
16
+ @wave_gpu.kernel
17
+ def vector_add(a: wave_gpu.f32[:], b: wave_gpu.f32[:], out: wave_gpu.f32[:], n: wave_gpu.u32):
18
+ gid = wave_gpu.thread_id()
19
+ if gid < n:
20
+ out[gid] = a[gid] + b[gid]
21
+
22
+ a = wave_gpu.array([1.0, 2.0, 3.0, 4.0])
23
+ b = wave_gpu.array([5.0, 6.0, 7.0, 8.0])
24
+ out = wave_gpu.zeros(4)
25
+ vector_add(a, b, out, len(a))
26
+ print(out.to_list()) # [6.0, 8.0, 10.0, 12.0]
27
+ ```
28
+
29
+ ## License
30
+
31
+ Apache-2.0 — see the `LICENSE` file at the root of the [repository](https://github.com/Oabraham1/wave).
@@ -0,0 +1,26 @@
1
+ [project]
2
+ name = "wave-gpu"
3
+ version = "0.1.0"
4
+ description = "Write GPU kernels in Python, run on any GPU"
5
+ authors = [{name = "Ojima Abraham"}]
6
+ license = {text = "Apache-2.0"}
7
+ requires-python = ">=3.9"
8
+ readme = "README.md"
9
+ dependencies = []
10
+ keywords = ["gpu", "cuda", "metal", "hip", "sycl", "compute"]
11
+ classifiers = [
12
+ "Development Status :: 3 - Alpha",
13
+ "License :: OSI Approved :: Apache Software License",
14
+ "Programming Language :: Python :: 3",
15
+ ]
16
+
17
+ [build-system]
18
+ requires = ["setuptools>=68.0"]
19
+ build-backend = "setuptools.build_meta"
20
+
21
+ [tool.setuptools.packages.find]
22
+ where = ["src"]
23
+
24
+ [project.urls]
25
+ Homepage = "https://wave.ojima.me"
26
+ Repository = "https://github.com/Oabraham1/wave"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,55 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """WAVE GPU SDK for Python.
5
+
6
+ Write GPU kernels in Python, run on any GPU. Supports Apple Metal,
7
+ NVIDIA CUDA, AMD ROCm, Intel SYCL, and a built-in emulator.
8
+ """
9
+
10
+ from .array import WaveArray, array, ones, zeros
11
+ from .device import DeviceInfo, device
12
+ from .kernel import kernel
13
+ from .types import f16, f32, f64, i32, u32
14
+
15
+ __version__ = "0.1.0"
16
+
17
# Public API exposed by ``from wave_gpu import *``.  The kernel intrinsics
# defined further down in this module are listed as well, so a star-import
# provides the same names the README example relies on (``thread_id()`` etc.).
__all__ = [
    "WaveArray",
    "array",
    "ones",
    "zeros",
    "DeviceInfo",
    "device",
    "kernel",
    "f16",
    "f32",
    "f64",
    "i32",
    "u32",
    "thread_id",
    "workgroup_id",
    "lane_id",
    "wave_width",
    "barrier",
]
31
+
32
+
33
def thread_id() -> int:
    """Host-side stand-in for the ``thread_id()`` kernel intrinsic.

    Raises:
        RuntimeError: always; calling it from ordinary Python code is an error.
    """
    message = "thread_id() can only be called inside a @kernel function"
    raise RuntimeError(message)
36
+
37
+
38
def workgroup_id() -> int:
    """Host-side stand-in for the ``workgroup_id()`` kernel intrinsic.

    Raises:
        RuntimeError: always; calling it from ordinary Python code is an error.
    """
    message = "workgroup_id() can only be called inside a @kernel function"
    raise RuntimeError(message)
41
+
42
+
43
def lane_id() -> int:
    """Host-side stand-in for the ``lane_id()`` kernel intrinsic.

    Raises:
        RuntimeError: always; calling it from ordinary Python code is an error.
    """
    message = "lane_id() can only be called inside a @kernel function"
    raise RuntimeError(message)
46
+
47
+
48
def wave_width() -> int:
    """Host-side stand-in for the ``wave_width()`` kernel intrinsic.

    Raises:
        RuntimeError: always; calling it from ordinary Python code is an error.
    """
    message = "wave_width() can only be called inside a @kernel function"
    raise RuntimeError(message)
51
+
52
+
53
def barrier() -> None:
    """Host-side stand-in for the ``barrier()`` kernel intrinsic.

    Raises:
        RuntimeError: always; calling it from ordinary Python code is an error.
    """
    message = "barrier() can only be called inside a @kernel function"
    raise RuntimeError(message)
@@ -0,0 +1,42 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Array types for WAVE GPU kernel data."""
5
+
6
+ from typing import List, Sequence, Union
7
+
8
+
9
class WaveArray:
    """Host-side container for values handed to WAVE GPU kernels.

    Every element is converted to ``float`` on construction.  ``dtype`` is
    stored exactly as given; it is not validated by this class.
    """

    def __init__(self, data: Sequence[Union[int, float]], dtype: str = "f32") -> None:
        converted: List[float] = []
        for element in data:
            converted.append(float(element))
        self.data: List[float] = converted
        self.dtype: str = dtype

    def to_list(self) -> List[float]:
        """Return a new list holding the current elements."""
        return [*self.data]

    def __len__(self) -> int:
        # Length is simply the number of stored elements.
        count = len(self.data)
        return count

    def __getitem__(self, idx: int) -> float:
        # Indexing is delegated straight to the backing list.
        return self.data[idx]

    def __repr__(self) -> str:
        return "WaveArray({}, dtype='{}')".format(self.data, self.dtype)
28
+
29
+
30
def array(data: Sequence[Union[int, float]], dtype: str = "f32") -> WaveArray:
    """Build a ``WaveArray`` from any Python sequence of numbers."""
    wrapped = WaveArray(data, dtype)
    return wrapped
33
+
34
+
35
def zeros(n: int, dtype: str = "f32") -> WaveArray:
    """Return a ``WaveArray`` of ``n`` elements, each equal to 0.0."""
    fill = [0.0] * n
    return WaveArray(fill, dtype)
38
+
39
+
40
def ones(n: int, dtype: str = "f32") -> WaveArray:
    """Return a ``WaveArray`` of ``n`` elements, each equal to 1.0."""
    fill = [1.0] * n
    return WaveArray(fill, dtype)
@@ -0,0 +1,69 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """GPU detection for the WAVE Python SDK."""
5
+
6
+ import platform
7
+ import subprocess
8
+ from typing import Tuple
9
+
10
+
11
class DeviceInfo:
    """Vendor tag and display name of a detected GPU device."""

    def __init__(self, vendor: str, name: str) -> None:
        self.name: str = name
        self.vendor: str = vendor

    def __repr__(self) -> str:
        # The repr is just the human-readable device name.
        label = self.name
        return label
20
+
21
+
22
def detect_gpu() -> Tuple[str, str]:
    """Detect the best available GPU.

    Probing order: macOS (reported as Apple/Metal without running any
    tool), then ``nvidia-smi``, ``rocminfo`` and ``sycl-ls``.  The first
    probe that reports a GPU wins; if none does, the emulator is returned.

    Returns:
        A ``(vendor, name)`` tuple.  Vendor is one of: 'apple', 'nvidia',
        'amd', 'intel', 'emulator'.
    """
    if platform.system() == "Darwin":
        return ("apple", "Apple GPU (Metal)")

    # A probe tool that is missing, not executable, or otherwise cannot be
    # started raises an OSError subclass; a tool that exits non-zero raises
    # CalledProcessError (check=True).  Either way the probe just "found
    # nothing" and detection moves on to the next vendor.
    probe_failed = (OSError, subprocess.CalledProcessError)

    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True,
            check=True,
            text=True,
        )
        lines = result.stdout.strip().splitlines()
        name = lines[0].strip() if lines else ""
        # Empty output means no device was listed; do not report a nameless GPU.
        if name:
            return ("nvidia", f"{name} (CUDA)")
    except probe_failed:
        pass

    try:
        result = subprocess.run(
            ["rocminfo"], capture_output=True, check=True, text=True
        )
        if "gfx" in result.stdout:
            for line in result.stdout.splitlines():
                if "Marketing Name" not in line:
                    continue
                # partition() tolerates a line without ':' (name is then empty).
                name = line.partition(":")[2].strip()
                if name:
                    return ("amd", f"{name} (ROCm)")
            return ("amd", "AMD GPU (ROCm)")
    except probe_failed:
        pass

    try:
        result = subprocess.run(["sycl-ls"], capture_output=True, check=True, text=True)
        if "level_zero:gpu" in result.stdout or "opencl:gpu" in result.stdout:
            return ("intel", "Intel GPU (SYCL)")
    except probe_failed:
        pass

    return ("emulator", "WAVE Emulator (no GPU)")
64
+
65
+
66
def device() -> DeviceInfo:
    """Run GPU detection and wrap the result in a ``DeviceInfo``."""
    detected = detect_gpu()
    vendor, name = detected
    return DeviceInfo(vendor=vendor, name=name)
@@ -0,0 +1,64 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Kernel decorator and compilation for the WAVE Python SDK."""
5
+
6
+ import inspect
7
+ import textwrap
8
+ from typing import Any, Callable, List, Optional
9
+
10
+ from .array import WaveArray
11
+ from .runtime import CompiledKernel
12
+
13
+
14
class KernelWrapper:
    """Wraps a Python function as a WAVE GPU kernel.

    The function's source text is captured when the wrapper is created;
    compilation is deferred until the first launch.
    """

    def __init__(self, func: Callable[..., Any]) -> None:
        self._func = func
        # dedent so kernels defined inside other functions/classes start at column 0.
        self._source = textwrap.dedent(inspect.getsource(func))
        self._compiled: Optional[CompiledKernel] = None

    def _ensure_compiled(self) -> CompiledKernel:
        """Create the ``CompiledKernel`` on first use and reuse it afterwards."""
        if self._compiled is None:
            self._compiled = CompiledKernel(self._source, "python")
        return self._compiled

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """Launch the kernel with the given arguments.

        Buffer arguments (WaveArray) are passed as device buffers.
        Scalar arguments (int or float) are converted with ``int()`` and
        passed as kernel parameters.

        Keyword Args:
            grid: 3-tuple grid size; defaults to ``(1, 1, 1)``.
            workgroup: 3-tuple workgroup size; defaults to ``(256, 1, 1)``.

        When both are left at their defaults, the launch is sized from the
        longest buffer: up to 256 threads per workgroup and enough groups
        to cover every element.

        Raises:
            TypeError: if a positional argument is neither a WaveArray nor
                a numeric scalar.
        """
        grid = kwargs.get("grid", (1, 1, 1))
        workgroup = kwargs.get("workgroup", (256, 1, 1))

        buffers: List[WaveArray] = []
        scalars: List[int] = []

        # Arguments are validated before compiling so a bad call fails fast.
        for arg in args:
            if isinstance(arg, WaveArray):
                buffers.append(arg)
            elif isinstance(arg, (int, float)):
                scalars.append(int(arg))
            else:
                raise TypeError(
                    f"Unsupported argument type: {type(arg).__name__}. "
                    "Expected WaveArray or a numeric scalar (int or float)."
                )

        compiled = self._ensure_compiled()

        # Clamp to one thread: with only empty buffers the longest length is
        # 0, which would make the workgroup size 0 and divide by zero below.
        n_threads = max(1, max((len(b) for b in buffers), default=1))
        if grid == (1, 1, 1) and workgroup == (256, 1, 1):
            wg_size = min(256, n_threads)
            n_groups = (n_threads + wg_size - 1) // wg_size  # ceiling division
            grid = (n_groups, 1, 1)
            workgroup = (wg_size, 1, 1)

        compiled.launch(buffers, scalars, grid, workgroup)
60
+
61
+
62
def kernel(func: Callable[..., Any]) -> KernelWrapper:
    """Decorator: turn ``func`` into a launchable ``KernelWrapper``."""
    wrapper = KernelWrapper(func)
    return wrapper
@@ -0,0 +1,210 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Full pipeline orchestration for the WAVE Python SDK.
5
+
6
+ Handles compilation, backend translation, and kernel launch by calling
7
+ the WAVE tool chain via subprocess.
8
+ """
9
+
10
+ import os
11
+ import struct
12
+ import subprocess
13
+ import tempfile
14
+ from pathlib import Path
15
+ from typing import Dict, List, Optional, Tuple
16
+
17
+ from .array import WaveArray
18
+ from .device import detect_gpu
19
+
20
+ _VENDOR_TO_LANG = {
21
+ "apple": "metal",
22
+ "nvidia": "ptx",
23
+ "amd": "hip",
24
+ "intel": "sycl",
25
+ }
26
+
27
+
28
+ def _find_tool(name: str) -> str:
29
+ """Find a WAVE tool binary, checking target/release/ first, then PATH."""
30
+ repo_root = Path(__file__).resolve().parents[5]
31
+ release_path = repo_root / "target" / "release" / name
32
+ if release_path.exists():
33
+ return str(release_path)
34
+
35
+ debug_path = repo_root / "target" / "debug" / name
36
+ if debug_path.exists():
37
+ return str(debug_path)
38
+
39
+ for crate_dir in repo_root.iterdir():
40
+ if crate_dir.name.startswith("wave-") and crate_dir.is_dir():
41
+ candidate = crate_dir / "target" / "release" / name
42
+ if candidate.exists():
43
+ return str(candidate)
44
+ candidate = crate_dir / "target" / "debug" / name
45
+ if candidate.exists():
46
+ return str(candidate)
47
+
48
+ return name
49
+
50
+
51
def compile_kernel(source: str, language: str = "python") -> bytes:
    """Compile kernel source to WAVE binary (.wbin) bytes.

    The source is written to a temporary file, ``wave-compiler`` is run on
    it, and the resulting ``.wbin`` file is read back.  Both temporary
    files are removed afterwards, whether or not compilation succeeded.

    Args:
        source: Kernel source text.
        language: Source language passed to the compiler via ``-l``; also
            selects the temp-file extension (unknown values use ``.py``).

    Returns:
        The raw contents of the compiled ``.wbin`` file.

    Raises:
        RuntimeError: if the compiler exits with a non-zero status.
    """
    compiler = _find_tool("wave-compiler")

    ext_map = {"python": ".py", "rust": ".rs", "cpp": ".cpp", "typescript": ".ts"}
    ext = ext_map.get(language, ".py")

    # UTF-8 is written explicitly so the bytes on disk do not depend on the
    # platform's default encoding.  delete=False because the file must be
    # closed before the compiler subprocess reads it by path.
    src_file = tempfile.NamedTemporaryFile(
        suffix=ext, mode="w", encoding="utf-8", delete=False
    )
    src_path = src_file.name
    wbin_path = src_path + ".wbin"

    try:
        # Writing happens inside the try so a failed write still cleans up.
        with src_file:
            src_file.write(source)

        result = subprocess.run(
            [compiler, src_path, "-o", wbin_path, "-l", language],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            raise RuntimeError(f"Compilation failed:\n{result.stderr}\n{result.stdout}")
        with open(wbin_path, "rb") as wbin_file:
            return wbin_file.read()
    finally:
        # Best-effort cleanup: the .wbin may not exist if compilation failed.
        for path in (src_path, wbin_path):
            try:
                os.unlink(path)
            except OSError:
                pass
80
+
81
+
82
def translate_wbin(wbin: bytes, vendor: str) -> str:
    """Run the vendor backend tool on a .wbin image and return its stdout.

    Raises:
        ValueError: if ``vendor`` is ``"emulator"``.
        RuntimeError: if the backend tool exits with a non-zero status.
    """
    if vendor == "emulator":
        raise ValueError("Emulator does not need backend translation")

    tool = _find_tool(f"wave-{_VENDOR_TO_LANG[vendor]}")

    # The image is staged in a temp file because the tool takes a path.
    with tempfile.NamedTemporaryFile(suffix=".wbin", delete=False) as handle:
        handle.write(wbin)
        staged_path = handle.name

    try:
        proc = subprocess.run(
            [tool, staged_path],
            capture_output=True,
            text=True,
        )
        if proc.returncode == 0:
            return proc.stdout
        raise RuntimeError(
            f"Backend translation failed:\n{proc.stderr}\n{proc.stdout}"
        )
    finally:
        # Best-effort removal of the staged image.
        try:
            os.unlink(staged_path)
        except OSError:
            pass
110
+
111
+
112
def launch_emulator(
    wbin: bytes,
    buffers: List[WaveArray],
    scalars: List[int],
    grid: Tuple[int, int, int],
    workgroup: Tuple[int, int, int],
) -> None:
    """Run a compiled kernel through the ``wave-emu`` tool.

    All buffers are packed back to back as little-endian 32-bit floats into
    one memory file.  Each buffer's byte offset, followed by each scalar,
    is passed via consecutive ``--reg i=value`` options.  After the run the
    memory file is read back and every buffer's ``data`` is replaced with
    the values found at its offset.

    Raises:
        RuntimeError: if the emulator exits with a non-zero status.
    """
    emulator = _find_tool("wave-emu")

    with tempfile.TemporaryDirectory() as workdir:
        kernel_file = os.path.join(workdir, "kernel.wbin")
        with open(kernel_file, "wb") as out:
            out.write(wbin)

        memory_file = os.path.join(workdir, "memory.bin")

        # Build the memory image; a buffer's offset is the image size so far.
        image = bytearray()
        base_offsets: List[int] = []
        for buf in buffers:
            base_offsets.append(len(image))
            for value in buf.data:
                image += struct.pack("<f", value)

        with open(memory_file, "wb") as out:
            out.write(image)

        grid_arg = f"{grid[0]},{grid[1]},{grid[2]}"
        workgroup_arg = f"{workgroup[0]},{workgroup[1]},{workgroup[2]}"
        cmd = [
            emulator,
            kernel_file,
            "--memory-file",
            memory_file,
            "--grid",
            grid_arg,
            "--workgroup",
            workgroup_arg,
        ]

        # Registers: buffer offsets first, then scalars, numbered consecutively.
        for reg, reg_value in enumerate([*base_offsets, *scalars]):
            cmd += ["--reg", f"{reg}={reg_value}"]

        proc = subprocess.run(cmd, capture_output=True, text=True)
        if proc.returncode != 0:
            raise RuntimeError(
                f"Emulator execution failed:\n{proc.stderr}\n{proc.stdout}"
            )

        with open(memory_file, "rb") as src:
            result_image = bytearray(src.read())

        # Copy results back using the same back-to-back layout as above.
        cursor = 0
        for buf in buffers:
            nbytes = 4 * len(buf.data)
            window = result_image[cursor : cursor + nbytes]
            buf.data = [
                struct.unpack_from("<f", window, pos)[0]
                for pos in range(0, len(window), 4)
            ]
            cursor += nbytes
175
+
176
+
177
class CompiledKernel:
    """Kernel source plus its lazily compiled .wbin image.

    Compilation happens on the first launch; the resulting bytes and any
    vendor translations are cached on the instance.
    """

    def __init__(self, source: str, language: str = "python") -> None:
        self._source = source
        self._language = language
        self._wbin: Optional[bytes] = None
        self._vendor_cache: Dict[str, str] = {}

    def _ensure_compiled(self) -> bytes:
        """Return the cached .wbin bytes, compiling first if necessary."""
        image = self._wbin
        if image is None:
            image = compile_kernel(self._source, self._language)
            self._wbin = image
        return image

    def launch(
        self,
        buffers: List[WaveArray],
        scalars: List[int],
        grid: Tuple[int, int, int] = (1, 1, 1),
        workgroup: Tuple[int, int, int] = (256, 1, 1),
    ) -> None:
        """Compile if needed, then run the kernel on the detected device.

        Raises:
            NotImplementedError: if the detected vendor is anything other
                than the emulator (raised after the backend translation
                has been cached).
        """
        wbin = self._ensure_compiled()
        vendor, _name = detect_gpu()

        if vendor == "emulator":
            launch_emulator(wbin, buffers, scalars, grid, workgroup)
            return

        if vendor not in self._vendor_cache:
            self._vendor_cache[vendor] = translate_wbin(wbin, vendor)
        raise NotImplementedError(
            f"Direct {vendor} launch not yet implemented in Python SDK. "
            "Use the emulator or the Rust SDK."
        )
@@ -0,0 +1,26 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Type annotations for WAVE GPU kernel parameters."""
5
+
6
+ from typing import Any
7
+
8
+
9
+ class _TypeDescriptor:
10
+ """Base class for WAVE type annotations used in kernel signatures."""
11
+
12
+ def __init__(self, name: str) -> None:
13
+ self._name = name
14
+
15
+ def __repr__(self) -> str:
16
+ return self._name
17
+
18
+ def __getitem__(self, _key: Any) -> "_TypeDescriptor":
19
+ return _TypeDescriptor(f"{self._name}[]")
20
+
21
+
22
# Scalar type descriptors available for kernel parameter annotations.
f32, f64, f16, u32, i32 = (
    _TypeDescriptor(type_name) for type_name in ("f32", "f64", "f16", "u32", "i32")
)
@@ -0,0 +1,46 @@
1
+ Metadata-Version: 2.4
2
+ Name: wave-gpu
3
+ Version: 0.1.0
4
+ Summary: Write GPU kernels in Python, run on any GPU
5
+ Author: Ojima Abraham
6
+ License: Apache-2.0
7
+ Project-URL: Homepage, https://wave.ojima.me
8
+ Project-URL: Repository, https://github.com/Oabraham1/wave
9
+ Keywords: gpu,cuda,metal,hip,sycl,compute
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: License :: OSI Approved :: Apache Software License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Requires-Python: >=3.9
14
+ Description-Content-Type: text/markdown
15
+
16
+ # WAVE Python SDK
17
+
18
+ Write GPU kernels in Python, run on any GPU.
19
+
20
+ ## Install
21
+
22
+ ```bash
23
+ pip install wave-gpu
24
+ ```
25
+
26
+ ## Usage
27
+
28
+ ```python
29
+ import wave_gpu
30
+
31
+ @wave_gpu.kernel
32
+ def vector_add(a: wave_gpu.f32[:], b: wave_gpu.f32[:], out: wave_gpu.f32[:], n: wave_gpu.u32):
33
+ gid = wave_gpu.thread_id()
34
+ if gid < n:
35
+ out[gid] = a[gid] + b[gid]
36
+
37
+ a = wave_gpu.array([1.0, 2.0, 3.0, 4.0])
38
+ b = wave_gpu.array([5.0, 6.0, 7.0, 8.0])
39
+ out = wave_gpu.zeros(4)
40
+ vector_add(a, b, out, len(a))
41
+ print(out.to_list()) # [6.0, 8.0, 10.0, 12.0]
42
+ ```
43
+
44
+ ## License
45
+
46
+ Apache 2.0 - see [LICENSE](../../LICENSE)
@@ -0,0 +1,16 @@
1
+ README.md
2
+ pyproject.toml
3
+ src/wave_gpu/__init__.py
4
+ src/wave_gpu/array.py
5
+ src/wave_gpu/device.py
6
+ src/wave_gpu/kernel.py
7
+ src/wave_gpu/runtime.py
8
+ src/wave_gpu/types.py
9
+ src/wave_gpu.egg-info/PKG-INFO
10
+ src/wave_gpu.egg-info/SOURCES.txt
11
+ src/wave_gpu.egg-info/dependency_links.txt
12
+ src/wave_gpu.egg-info/top_level.txt
13
+ tests/test_array.py
14
+ tests/test_device.py
15
+ tests/test_kernel.py
16
+ tests/test_vector_add.py
@@ -0,0 +1 @@
1
+ wave_gpu
@@ -0,0 +1,47 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tests for the WaveArray type."""
5
+
6
+ import wave_gpu
7
+
8
+
9
def test_array_creation():
    """array() keeps element order and reports the right length."""
    arr = wave_gpu.array([1.0, 2.0, 3.0])
    assert len(arr) == 3
    for position, expected in enumerate((1.0, 2.0, 3.0)):
        assert arr[position] == expected
15
+
16
+
17
def test_array_to_list():
    """to_list() returns the elements as a plain Python list."""
    arr = wave_gpu.array([1.0, 2.0, 3.0, 4.0])
    as_list = arr.to_list()
    assert as_list == [1.0, 2.0, 3.0, 4.0]
20
+
21
+
22
def test_zeros():
    """zeros(n) yields n elements, all 0.0."""
    arr = wave_gpu.zeros(5)
    assert len(arr) == 5
    assert arr.to_list() == [0.0] * 5
26
+
27
+
28
def test_ones():
    """ones(n) yields n elements, all 1.0."""
    arr = wave_gpu.ones(3)
    assert len(arr) == 3
    assert arr.to_list() == [1.0] * 3
32
+
33
+
34
def test_array_dtype():
    """An explicit dtype is stored as given; values still come back as floats."""
    arr = wave_gpu.array([1, 2, 3], dtype="u32")
    assert arr.dtype == "u32"
    contents = arr.to_list()
    assert contents == [1.0, 2.0, 3.0]
38
+
39
+
40
def test_array_from_integers():
    """Integer inputs are converted to floats."""
    arr = wave_gpu.array([1, 2, 3])
    contents = arr.to_list()
    assert contents == [1.0, 2.0, 3.0]
43
+
44
+
45
def test_array_repr():
    """repr() of an array mentions the class name."""
    arr = wave_gpu.array([1.0, 2.0])
    text = repr(arr)
    assert "WaveArray" in text
@@ -0,0 +1,17 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tests for GPU device detection."""
5
+
6
+ import wave_gpu
7
+
8
+
9
def test_device_returns_info():
    """device() reports a known vendor and a non-empty name."""
    known_vendors = ("apple", "nvidia", "amd", "intel", "emulator")
    dev = wave_gpu.device()
    assert dev.vendor in known_vendors
    assert len(dev.name) > 0
13
+
14
+
15
def test_device_repr():
    """repr() of the detected device is a non-empty string."""
    text = repr(wave_gpu.device())
    assert len(text) > 0
@@ -0,0 +1,38 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tests for kernel compilation and the @kernel decorator."""
5
+
6
+ import pytest
7
+
8
+ import wave_gpu
9
+
10
+
11
def test_kernel_decorator():
    """@kernel captures the decorated function's source text."""

    @wave_gpu.kernel
    def vector_add(a, b, out, n):
        gid = wave_gpu.thread_id()
        if gid < n:
            out[gid] = a[gid] + b[gid]

    assert hasattr(vector_add, "_source")
    captured = vector_add._source
    assert "vector_add" in captured
20
+
21
+
22
def test_thread_id_outside_kernel():
    """Calling thread_id() from host code raises RuntimeError."""
    with pytest.raises(RuntimeError) as excinfo:
        wave_gpu.thread_id()
    assert "inside a @kernel" in str(excinfo.value)
25
+
26
+
27
def test_barrier_outside_kernel():
    """Calling barrier() from host code raises RuntimeError."""
    with pytest.raises(RuntimeError) as excinfo:
        wave_gpu.barrier()
    assert "inside a @kernel" in str(excinfo.value)
30
+
31
+
32
def test_kernel_bad_arg_type():
    """Passing a string to a kernel is rejected with TypeError."""

    @wave_gpu.kernel
    def noop(x):
        pass

    with pytest.raises(TypeError) as excinfo:
        noop("not a valid arg")
    assert "Unsupported argument type" in str(excinfo.value)
@@ -0,0 +1,39 @@
1
+ # Copyright 2026 Ojima Abraham
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """End-to-end vector addition test using the runtime compilation pipeline."""
5
+
6
+ import pytest
7
+
8
+ import wave_gpu
9
+ from wave_gpu.runtime import compile_kernel
10
+
11
+
12
def test_vector_add_e2e():
    """Full end-to-end: compile + launch vector_add on emulator.

    Skipped at run time (rather than unconditionally) when the WAVE tools
    cannot be found, or when a real GPU is detected, because the runtime
    only implements the emulator launch path.
    """
    import shutil

    from wave_gpu.runtime import _find_tool

    missing = [
        tool
        for tool in ("wave-compiler", "wave-emu")
        if shutil.which(_find_tool(tool)) is None
    ]
    if missing:
        pytest.skip(
            "Requires wave-compiler binary on PATH or in target/ "
            f"(missing: {', '.join(missing)})"
        )
    if wave_gpu.device().vendor != "emulator":
        pytest.skip("Direct GPU launch is not implemented; emulator-only test")

    a = wave_gpu.array([1.0, 2.0, 3.0, 4.0])
    b = wave_gpu.array([5.0, 6.0, 7.0, 8.0])
    out = wave_gpu.zeros(4)

    @wave_gpu.kernel
    def vector_add(
        a: wave_gpu.f32[:],
        b: wave_gpu.f32[:],
        out: wave_gpu.f32[:],
        n: wave_gpu.u32,
    ):
        gid = wave_gpu.thread_id()
        if gid < n:
            out[gid] = a[gid] + b[gid]

    vector_add(a, b, out, len(a))
    assert out.to_list() == [6.0, 8.0, 10.0, 12.0]
35
+
36
+
37
def test_compile_kernel_source():
    """Smoke check: compile_kernel is importable and callable."""
    is_callable = callable(compile_kernel)
    assert is_callable