metal-debug 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 imperatormk
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,229 @@
1
+ Metadata-Version: 2.4
2
+ Name: metal-debug
3
+ Version: 0.1.0
4
+ Summary: Printf-style debugging for Metal compute shaders
5
+ Author: imperatormk
6
+ License-Expression: MIT
7
+ Project-URL: Repository, https://github.com/imperatormk/metal-debug
8
+ Keywords: metal,gpu,debug,apple,shader,compute
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Operating System :: MacOS
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Topic :: Software Development :: Debuggers
13
+ Requires-Python: >=3.9
14
+ Description-Content-Type: text/markdown
15
+ License-File: LICENSE
16
+ Requires-Dist: numpy
17
+ Provides-Extra: torch
18
+ Requires-Dist: torch; extra == "torch"
19
+ Provides-Extra: tui
20
+ Requires-Dist: textual>=1.0; extra == "tui"
21
+ Dynamic: license-file
22
+
23
+ # metal-debug
24
+
25
+ Printf-style debugging for Metal compute shaders. No Xcode GPU debugger, no buffer dumps, no guessing.
26
+
27
+ Add `#include "metal_debug.h"` to your shader, drop in a debug buffer, see what every thread computed.
28
+
29
+ ```metal
30
+ #include "metal_debug.h"
31
+
32
+ kernel void my_kernel(
33
+ device float *A [[buffer(0)]],
34
+ device float *B [[buffer(1)]],
35
+ device float *C [[buffer(2)]],
36
+ device uint *dbg_buf [[buffer(30)]],
37
+ uint id [[thread_position_in_grid]]
38
+ ) {
39
+ float a = A[id], b = B[id];
40
+
41
+ dbg_printf(dbg_buf, id, 0, a); // log input A
42
+ dbg_printf(dbg_buf, id, 1, b); // log input B
43
+
44
+ float result = a * b;
45
+ dbg_watch_nan(dbg_buf, id, 2, result); // only logs if NaN/Inf
46
+ dbg_assert(dbg_buf, id, 3, result > 0); // GPU-side assertion
47
+
48
+ C[id] = result;
49
+ }
50
+ ```
51
+
52
+ Host side (ObjC):
53
+ ```objc
54
+ MetalDebugSession *dbg = [[MetalDebugSession alloc] initWithDevice:device maxEntries:4096];
55
+ [encoder setBuffer:dbg.buffer offset:0 atIndex:30];
56
+ // ... dispatch ...
57
+ [dbg dump];
58
+ ```
59
+
60
+ Output:
61
+ ```
62
+ [metal-debug] 24 entries
63
+ thread[0] 0: 3.5
64
+ thread[0] 1: 2.0
65
+ thread[1] 0: 1.2
66
+ thread[1] 1: -0.5
67
+ thread[1] 3: ASSERTION FAILED
68
+ ```
69
+
70
+ ## Features
71
+
72
+ | Feature | GPU API | Description |
73
+ |---------|---------|-------------|
74
+ | Printf | `dbg_printf(buf, tid, tag, val)` | Log float/int/uint/half/vec values |
75
+ | Conditional | `dbg_printf_if(buf, cond, tid, tag, val)` | Only log when condition is true |
76
+ | NaN watchpoint | `dbg_watch_nan(buf, tid, tag, val)` | Log only NaN/Inf values |
77
+ | Range watchpoint | `dbg_watch_range(buf, tid, tag, val, lo, hi)` | Log values outside range |
78
+ | Assertions | `dbg_assert(buf, tid, tag, cond)` | Record assertion failures |
79
+ | Breakpoints | `dbg_break(buf, tid, tag, cond)` | Set flag for host to detect |
80
+ | Stats | `dbg_stats(buf, tag, val)` | Cross-thread min/max/mean/count |
81
+ | Histogram | `dbg_histogram(buf, tag, val, lo, hi)` | Value distribution with bar chart |
82
+ | Named tags | Preprocessor or host-side | `"loss"` instead of `tag=42` |
83
+ | 2D grid view | Host-side | Display values as threadgroup grid |
84
+ | Diff mode | Host-side | Compare two kernel runs |
85
+ | Zero-overhead disable | `#define METAL_DEBUG_DISABLE` | Compiles out all debug calls |
86
+
87
+ ## How it works
88
+
89
+ 1. **GPU side**: `metal_debug.h` is a single header. Debug calls write `(thread_id, tag, type, value)` entries into a device buffer using atomic counters.
90
+
91
+ 2. **Host side**: `MetalDebugSession` allocates the buffer, binds it at slot 30, and reads/formats entries after execution.
92
+
93
+ 3. **No recompilation needed** when changing buffer size — `max_entries` is stored in the buffer itself and read by the GPU at runtime.
94
+
95
+ ## Build & test
96
+
97
+ ```bash
98
+ git clone <this repo>
99
+ cd metal-debug
100
+ make test # compiles + runs 9 test kernels on your GPU
101
+ ```
102
+
103
+ ## Integration
104
+
105
+ ### ObjC / C++
106
+
107
+ Copy `src/metal_debug.h` into your project. Link `runtime/MetalDebugSession.{h,m}` into your app.
108
+
109
+ ```objc
110
+ #import "MetalDebugSession.h"
111
+
112
+ MetalDebugSession *dbg = [[MetalDebugSession alloc]
113
+ initWithDevice:device maxEntries:4096];
114
+ [dbg setName:@"loss" forTag:0];
115
+
116
+ [encoder setBuffer:dbg.buffer offset:0 atIndex:30];
117
+ // dispatch kernel...
118
+
119
+ [dbg dump]; // all entries, sorted by thread
120
+ [dbg dumpTag:0]; // filter by tag
121
+ [dbg dumpGrid:0 width:8 height:8]; // 2D threadgroup view
122
+ [dbg dumpStats:0]; // min/max/mean
123
+ [dbg dumpHistogram:0 lo:0 hi:1]; // value distribution
124
+
125
+ if ([dbg breakpointHit])
126
+ [dbg dumpBreakpoint]; // what went wrong
127
+
128
+ [dbg reset]; // reuse for next dispatch
129
+ ```
130
+
131
+ ### Swift
132
+
133
+ ```swift
134
+ import Metal
135
+
136
+ let dbg = MetalDebugSession(device: device, maxEntries: 4096)
137
+ encoder.setBuffer(dbg.buffer, offset: 0, index: 30)
138
+ // dispatch kernel...
139
+ dbg.dump()
140
+ ```
141
+
142
+ See `examples/SwiftDemo/` for a complete Swift example.
143
+
144
+ ### Python (PyTorch MPS / Triton)
145
+
146
+ ```python
147
+ from metal_debug import MetalDebugSession
148
+
149
+ dbg = MetalDebugSession(max_entries=4096)
150
+ # pass dbg.tensor as buffer(30) to your Metal/Triton kernel
151
+ torch.mps.synchronize()
152
+ dbg.dump()
153
+ ```
154
+
155
+ ### Interactive TUI
156
+
157
+ Explore debug traces interactively — filter, navigate, see grid views and stats live:
158
+
159
+ ```bash
160
+ pip install textual
161
+
162
+ # Launch with demo data
163
+ python python/tui.py --demo
164
+
165
+ # Launch with a debug buffer dump
166
+ python python/tui.py trace.bin
167
+ ```
168
+
169
+ Or from Python after a kernel dispatch:
170
+ ```python
171
+ dbg.explore(grid_width=8, grid_height=8)
172
+ ```
173
+
174
+ Keyboard shortcuts:
175
+ | Key | Action |
176
+ |-----|--------|
177
+ | `↑/↓` | Navigate entries |
178
+ | `g` | Show 2D grid for selected tag |
179
+ | `a` | Show assertions only |
180
+ | `b` | Jump to breakpoint thread |
181
+ | `c` | Clear filters |
182
+ | `m` | Toggle mouse (enable copy/paste) |
183
+ | `escape` | Focus table from filter input |
184
+ | `q` | Quit |
185
+
186
+ ### Source preprocessor
187
+
188
+ Auto-inject the debug buffer parameter into kernel signatures and use string tags:
189
+
190
+ ```bash
191
+ python3 src/metal_debug_preprocess.py my_kernel.metal -o my_kernel_debug.metal
192
+ xcrun metal -I path/to/metal-debug/src -o out.metallib my_kernel_debug.metal
193
+ ```
194
+
195
+ Before:
196
+ ```metal
197
+ kernel void foo(device float *A [[buffer(0)]], uint id [[thread_position_in_grid]]) {
198
+ dbg(id, "value", A[id]);
199
+ }
200
+ ```
201
+
202
+ After preprocessing:
203
+ ```metal
204
+ kernel void foo(device float *A [[buffer(0)]], uint id [[thread_position_in_grid]],
205
+ device uint *_dbg_buf [[buffer(30)]]) {
206
+ dbg(id, 47248/*value*/, A[id]);
207
+ }
208
+ ```
209
+
210
+ ## Convenience macros
211
+
212
+ If you use `DBG_PARAM` in your kernel signature, the short macros work:
213
+
214
+ ```metal
215
+ kernel void my_kernel(device float *A [[buffer(0)]], DBG_PARAM,
216
+ uint id [[thread_position_in_grid]]) {
217
+ dbg(id, 0, A[id]); // printf
218
+ dbg_if(id == 0, id, 1, A[id]); // conditional
219
+ dbg_nan(id, 2, A[id]); // NaN watchpoint
220
+ dbg_check(id, 3, A[id] > 0); // assertion
221
+ dbg_stat(0, A[id]); // stats accumulator
222
+ dbg_hist(0, A[id], 0, 100); // histogram
223
+ dbg_brk(id, 4, A[id] < 0); // breakpoint
224
+ }
225
+ ```
226
+
227
+ ## License
228
+
229
+ MIT
@@ -0,0 +1,207 @@
1
+ # metal-debug
2
+
3
+ Printf-style debugging for Metal compute shaders. No Xcode GPU debugger, no buffer dumps, no guessing.
4
+
5
+ Add `#include "metal_debug.h"` to your shader, drop in a debug buffer, see what every thread computed.
6
+
7
+ ```metal
8
+ #include "metal_debug.h"
9
+
10
+ kernel void my_kernel(
11
+ device float *A [[buffer(0)]],
12
+ device float *B [[buffer(1)]],
13
+ device float *C [[buffer(2)]],
14
+ device uint *dbg_buf [[buffer(30)]],
15
+ uint id [[thread_position_in_grid]]
16
+ ) {
17
+ float a = A[id], b = B[id];
18
+
19
+ dbg_printf(dbg_buf, id, 0, a); // log input A
20
+ dbg_printf(dbg_buf, id, 1, b); // log input B
21
+
22
+ float result = a * b;
23
+ dbg_watch_nan(dbg_buf, id, 2, result); // only logs if NaN/Inf
24
+ dbg_assert(dbg_buf, id, 3, result > 0); // GPU-side assertion
25
+
26
+ C[id] = result;
27
+ }
28
+ ```
29
+
30
+ Host side (ObjC):
31
+ ```objc
32
+ MetalDebugSession *dbg = [[MetalDebugSession alloc] initWithDevice:device maxEntries:4096];
33
+ [encoder setBuffer:dbg.buffer offset:0 atIndex:30];
34
+ // ... dispatch ...
35
+ [dbg dump];
36
+ ```
37
+
38
+ Output:
39
+ ```
40
+ [metal-debug] 24 entries
41
+ thread[0] 0: 3.5
42
+ thread[0] 1: 2.0
43
+ thread[1] 0: 1.2
44
+ thread[1] 1: -0.5
45
+ thread[1] 3: ASSERTION FAILED
46
+ ```
47
+
48
+ ## Features
49
+
50
+ | Feature | GPU API | Description |
51
+ |---------|---------|-------------|
52
+ | Printf | `dbg_printf(buf, tid, tag, val)` | Log float/int/uint/half/vec values |
53
+ | Conditional | `dbg_printf_if(buf, cond, tid, tag, val)` | Only log when condition is true |
54
+ | NaN watchpoint | `dbg_watch_nan(buf, tid, tag, val)` | Log only NaN/Inf values |
55
+ | Range watchpoint | `dbg_watch_range(buf, tid, tag, val, lo, hi)` | Log values outside range |
56
+ | Assertions | `dbg_assert(buf, tid, tag, cond)` | Record assertion failures |
57
+ | Breakpoints | `dbg_break(buf, tid, tag, cond)` | Set flag for host to detect |
58
+ | Stats | `dbg_stats(buf, tag, val)` | Cross-thread min/max/mean/count |
59
+ | Histogram | `dbg_histogram(buf, tag, val, lo, hi)` | Value distribution with bar chart |
60
+ | Named tags | Preprocessor or host-side | `"loss"` instead of `tag=42` |
61
+ | 2D grid view | Host-side | Display values as threadgroup grid |
62
+ | Diff mode | Host-side | Compare two kernel runs |
63
+ | Zero-overhead disable | `#define METAL_DEBUG_DISABLE` | Compiles out all debug calls |
64
+
65
+ ## How it works
66
+
67
+ 1. **GPU side**: `metal_debug.h` is a single header. Debug calls write `(thread_id, tag, type, value)` entries into a device buffer using atomic counters.
68
+
69
+ 2. **Host side**: `MetalDebugSession` allocates the buffer; you bind it at slot 30 on your encoder, and after execution the session reads and formats the entries.
70
+
71
+ 3. **No recompilation needed** when changing buffer size — `max_entries` is stored in the buffer itself and read by the GPU at runtime.
72
+
73
+ ## Build & test
74
+
75
+ ```bash
76
+ git clone https://github.com/imperatormk/metal-debug
77
+ cd metal-debug
78
+ make test # compiles + runs 9 test kernels on your GPU
79
+ ```
80
+
81
+ ## Integration
82
+
83
+ ### ObjC / C++
84
+
85
+ Copy `src/metal_debug.h` into your project. Link `runtime/MetalDebugSession.{h,m}` into your app.
86
+
87
+ ```objc
88
+ #import "MetalDebugSession.h"
89
+
90
+ MetalDebugSession *dbg = [[MetalDebugSession alloc]
91
+ initWithDevice:device maxEntries:4096];
92
+ [dbg setName:@"loss" forTag:0];
93
+
94
+ [encoder setBuffer:dbg.buffer offset:0 atIndex:30];
95
+ // dispatch kernel...
96
+
97
+ [dbg dump]; // all entries, sorted by thread
98
+ [dbg dumpTag:0]; // filter by tag
99
+ [dbg dumpGrid:0 width:8 height:8]; // 2D threadgroup view
100
+ [dbg dumpStats:0]; // min/max/mean
101
+ [dbg dumpHistogram:0 lo:0 hi:1]; // value distribution
102
+
103
+ if ([dbg breakpointHit])
104
+ [dbg dumpBreakpoint]; // what went wrong
105
+
106
+ [dbg reset]; // reuse for next dispatch
107
+ ```
108
+
109
+ ### Swift
110
+
111
+ ```swift
112
+ import Metal
113
+
114
+ let dbg = MetalDebugSession(device: device, maxEntries: 4096)
115
+ encoder.setBuffer(dbg.buffer, offset: 0, index: 30)
116
+ // dispatch kernel...
117
+ dbg.dump()
118
+ ```
119
+
120
+ See `examples/SwiftDemo/` for a complete Swift example.
121
+
122
+ ### Python (PyTorch MPS / Triton)
123
+
124
+ ```python
125
+ from metal_debug import MetalDebugSession
126
+
127
+ dbg = MetalDebugSession(max_entries=4096)
128
+ # pass dbg.tensor as buffer(30) to your Metal/Triton kernel
129
+ torch.mps.synchronize()
130
+ dbg.dump()
131
+ ```
132
+
133
+ ### Interactive TUI
134
+
135
+ Explore debug traces interactively — filter, navigate, see grid views and stats live:
136
+
137
+ ```bash
138
+ pip install textual
139
+
140
+ # Launch with demo data
141
+ python python/tui.py --demo
142
+
143
+ # Launch with a debug buffer dump
144
+ python python/tui.py trace.bin
145
+ ```
146
+
147
+ Or from Python after a kernel dispatch:
148
+ ```python
149
+ dbg.explore(grid_width=8, grid_height=8)
150
+ ```
151
+
152
+ Keyboard shortcuts:
153
+ | Key | Action |
154
+ |-----|--------|
155
+ | `↑/↓` | Navigate entries |
156
+ | `g` | Show 2D grid for selected tag |
157
+ | `a` | Show assertions only |
158
+ | `b` | Jump to breakpoint thread |
159
+ | `c` | Clear filters |
160
+ | `m` | Toggle mouse (enable copy/paste) |
161
+ | `escape` | Focus table from filter input |
162
+ | `q` | Quit |
163
+
164
+ ### Source preprocessor
165
+
166
+ Auto-inject the debug buffer parameter into kernel signatures and use string tags:
167
+
168
+ ```bash
169
+ python3 src/metal_debug_preprocess.py my_kernel.metal -o my_kernel_debug.metal
170
+ xcrun metal -I path/to/metal-debug/src -o out.metallib my_kernel_debug.metal
171
+ ```
172
+
173
+ Before:
174
+ ```metal
175
+ kernel void foo(device float *A [[buffer(0)]], uint id [[thread_position_in_grid]]) {
176
+ dbg(id, "value", A[id]);
177
+ }
178
+ ```
179
+
180
+ After preprocessing:
181
+ ```metal
182
+ kernel void foo(device float *A [[buffer(0)]], uint id [[thread_position_in_grid]],
183
+ device uint *_dbg_buf [[buffer(30)]]) {
184
+ dbg(id, 47248/*value*/, A[id]);
185
+ }
186
+ ```
187
+
188
+ ## Convenience macros
189
+
190
+ If you use `DBG_PARAM` in your kernel signature, the short macros work:
191
+
192
+ ```metal
193
+ kernel void my_kernel(device float *A [[buffer(0)]], DBG_PARAM,
194
+ uint id [[thread_position_in_grid]]) {
195
+ dbg(id, 0, A[id]); // printf
196
+ dbg_if(id == 0, id, 1, A[id]); // conditional
197
+ dbg_nan(id, 2, A[id]); // NaN watchpoint
198
+ dbg_check(id, 3, A[id] > 0); // assertion
199
+ dbg_stat(0, A[id]); // stats accumulator
200
+ dbg_hist(0, A[id], 0, 100); // histogram
201
+ dbg_brk(id, 4, A[id] < 0); // breakpoint
202
+ }
203
+ ```
204
+
205
+ ## License
206
+
207
+ MIT
@@ -0,0 +1,35 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "metal-debug"
7
+ version = "0.1.0"
8
+ description = "Printf-style debugging for Metal compute shaders"
9
+ readme = "README.md"
10
+ license = "MIT"
11
+ requires-python = ">=3.9"
12
+ authors = [{ name = "imperatormk" }]
13
+ keywords = ["metal", "gpu", "debug", "apple", "shader", "compute"]
14
+ classifiers = [
15
+ "Development Status :: 3 - Alpha",
16
+ "Operating System :: MacOS",
17
+ "Programming Language :: Python :: 3",
18
+ "Topic :: Software Development :: Debuggers",
19
+ ]
20
+ dependencies = [
21
+ "numpy",
22
+ ]
23
+
24
+ [project.optional-dependencies]
25
+ torch = ["torch"]
26
+ tui = ["textual>=1.0"]
27
+
28
+ [project.urls]
29
+ Repository = "https://github.com/imperatormk/metal-debug"
30
+
31
+ [tool.setuptools]
32
+ py-modules = ["metal_debug"]
33
+
34
+ [tool.setuptools.package-dir]
35
+ "" = "python"
@@ -0,0 +1,229 @@
1
+ Metadata-Version: 2.4
2
+ Name: metal-debug
3
+ Version: 0.1.0
4
+ Summary: Printf-style debugging for Metal compute shaders
5
+ Author: imperatormk
6
+ License-Expression: MIT
7
+ Project-URL: Repository, https://github.com/imperatormk/metal-debug
8
+ Keywords: metal,gpu,debug,apple,shader,compute
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Operating System :: MacOS
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Topic :: Software Development :: Debuggers
13
+ Requires-Python: >=3.9
14
+ Description-Content-Type: text/markdown
15
+ License-File: LICENSE
16
+ Requires-Dist: numpy
17
+ Provides-Extra: torch
18
+ Requires-Dist: torch; extra == "torch"
19
+ Provides-Extra: tui
20
+ Requires-Dist: textual>=1.0; extra == "tui"
21
+ Dynamic: license-file
22
+
23
+ # metal-debug
24
+
25
+ Printf-style debugging for Metal compute shaders. No Xcode GPU debugger, no buffer dumps, no guessing.
26
+
27
+ Add `#include "metal_debug.h"` to your shader, drop in a debug buffer, see what every thread computed.
28
+
29
+ ```metal
30
+ #include "metal_debug.h"
31
+
32
+ kernel void my_kernel(
33
+ device float *A [[buffer(0)]],
34
+ device float *B [[buffer(1)]],
35
+ device float *C [[buffer(2)]],
36
+ device uint *dbg_buf [[buffer(30)]],
37
+ uint id [[thread_position_in_grid]]
38
+ ) {
39
+ float a = A[id], b = B[id];
40
+
41
+ dbg_printf(dbg_buf, id, 0, a); // log input A
42
+ dbg_printf(dbg_buf, id, 1, b); // log input B
43
+
44
+ float result = a * b;
45
+ dbg_watch_nan(dbg_buf, id, 2, result); // only logs if NaN/Inf
46
+ dbg_assert(dbg_buf, id, 3, result > 0); // GPU-side assertion
47
+
48
+ C[id] = result;
49
+ }
50
+ ```
51
+
52
+ Host side (ObjC):
53
+ ```objc
54
+ MetalDebugSession *dbg = [[MetalDebugSession alloc] initWithDevice:device maxEntries:4096];
55
+ [encoder setBuffer:dbg.buffer offset:0 atIndex:30];
56
+ // ... dispatch ...
57
+ [dbg dump];
58
+ ```
59
+
60
+ Output:
61
+ ```
62
+ [metal-debug] 24 entries
63
+ thread[0] 0: 3.5
64
+ thread[0] 1: 2.0
65
+ thread[1] 0: 1.2
66
+ thread[1] 1: -0.5
67
+ thread[1] 3: ASSERTION FAILED
68
+ ```
69
+
70
+ ## Features
71
+
72
+ | Feature | GPU API | Description |
73
+ |---------|---------|-------------|
74
+ | Printf | `dbg_printf(buf, tid, tag, val)` | Log float/int/uint/half/vec values |
75
+ | Conditional | `dbg_printf_if(buf, cond, tid, tag, val)` | Only log when condition is true |
76
+ | NaN watchpoint | `dbg_watch_nan(buf, tid, tag, val)` | Log only NaN/Inf values |
77
+ | Range watchpoint | `dbg_watch_range(buf, tid, tag, val, lo, hi)` | Log values outside range |
78
+ | Assertions | `dbg_assert(buf, tid, tag, cond)` | Record assertion failures |
79
+ | Breakpoints | `dbg_break(buf, tid, tag, cond)` | Set flag for host to detect |
80
+ | Stats | `dbg_stats(buf, tag, val)` | Cross-thread min/max/mean/count |
81
+ | Histogram | `dbg_histogram(buf, tag, val, lo, hi)` | Value distribution with bar chart |
82
+ | Named tags | Preprocessor or host-side | `"loss"` instead of `tag=42` |
83
+ | 2D grid view | Host-side | Display values as threadgroup grid |
84
+ | Diff mode | Host-side | Compare two kernel runs |
85
+ | Zero-overhead disable | `#define METAL_DEBUG_DISABLE` | Compiles out all debug calls |
86
+
87
+ ## How it works
88
+
89
+ 1. **GPU side**: `metal_debug.h` is a single header. Debug calls write `(thread_id, tag, type, value)` entries into a device buffer using atomic counters.
90
+
91
+ 2. **Host side**: `MetalDebugSession` allocates the buffer, binds it at slot 30, and reads/formats entries after execution.
92
+
93
+ 3. **No recompilation needed** when changing buffer size — `max_entries` is stored in the buffer itself and read by the GPU at runtime.
94
+
95
+ ## Build & test
96
+
97
+ ```bash
98
+ git clone <this repo>
99
+ cd metal-debug
100
+ make test # compiles + runs 9 test kernels on your GPU
101
+ ```
102
+
103
+ ## Integration
104
+
105
+ ### ObjC / C++
106
+
107
+ Copy `src/metal_debug.h` into your project. Link `runtime/MetalDebugSession.{h,m}` into your app.
108
+
109
+ ```objc
110
+ #import "MetalDebugSession.h"
111
+
112
+ MetalDebugSession *dbg = [[MetalDebugSession alloc]
113
+ initWithDevice:device maxEntries:4096];
114
+ [dbg setName:@"loss" forTag:0];
115
+
116
+ [encoder setBuffer:dbg.buffer offset:0 atIndex:30];
117
+ // dispatch kernel...
118
+
119
+ [dbg dump]; // all entries, sorted by thread
120
+ [dbg dumpTag:0]; // filter by tag
121
+ [dbg dumpGrid:0 width:8 height:8]; // 2D threadgroup view
122
+ [dbg dumpStats:0]; // min/max/mean
123
+ [dbg dumpHistogram:0 lo:0 hi:1]; // value distribution
124
+
125
+ if ([dbg breakpointHit])
126
+ [dbg dumpBreakpoint]; // what went wrong
127
+
128
+ [dbg reset]; // reuse for next dispatch
129
+ ```
130
+
131
+ ### Swift
132
+
133
+ ```swift
134
+ import Metal
135
+
136
+ let dbg = MetalDebugSession(device: device, maxEntries: 4096)
137
+ encoder.setBuffer(dbg.buffer, offset: 0, index: 30)
138
+ // dispatch kernel...
139
+ dbg.dump()
140
+ ```
141
+
142
+ See `examples/SwiftDemo/` for a complete Swift example.
143
+
144
+ ### Python (PyTorch MPS / Triton)
145
+
146
+ ```python
147
+ from metal_debug import MetalDebugSession
148
+
149
+ dbg = MetalDebugSession(max_entries=4096)
150
+ # pass dbg.tensor as buffer(30) to your Metal/Triton kernel
151
+ torch.mps.synchronize()
152
+ dbg.dump()
153
+ ```
154
+
155
+ ### Interactive TUI
156
+
157
+ Explore debug traces interactively — filter, navigate, see grid views and stats live:
158
+
159
+ ```bash
160
+ pip install textual
161
+
162
+ # Launch with demo data
163
+ python python/tui.py --demo
164
+
165
+ # Launch with a debug buffer dump
166
+ python python/tui.py trace.bin
167
+ ```
168
+
169
+ Or from Python after a kernel dispatch:
170
+ ```python
171
+ dbg.explore(grid_width=8, grid_height=8)
172
+ ```
173
+
174
+ Keyboard shortcuts:
175
+ | Key | Action |
176
+ |-----|--------|
177
+ | `↑/↓` | Navigate entries |
178
+ | `g` | Show 2D grid for selected tag |
179
+ | `a` | Show assertions only |
180
+ | `b` | Jump to breakpoint thread |
181
+ | `c` | Clear filters |
182
+ | `m` | Toggle mouse (enable copy/paste) |
183
+ | `escape` | Focus table from filter input |
184
+ | `q` | Quit |
185
+
186
+ ### Source preprocessor
187
+
188
+ Auto-inject the debug buffer parameter into kernel signatures and use string tags:
189
+
190
+ ```bash
191
+ python3 src/metal_debug_preprocess.py my_kernel.metal -o my_kernel_debug.metal
192
+ xcrun metal -I path/to/metal-debug/src -o out.metallib my_kernel_debug.metal
193
+ ```
194
+
195
+ Before:
196
+ ```metal
197
+ kernel void foo(device float *A [[buffer(0)]], uint id [[thread_position_in_grid]]) {
198
+ dbg(id, "value", A[id]);
199
+ }
200
+ ```
201
+
202
+ After preprocessing:
203
+ ```metal
204
+ kernel void foo(device float *A [[buffer(0)]], uint id [[thread_position_in_grid]],
205
+ device uint *_dbg_buf [[buffer(30)]]) {
206
+ dbg(id, 47248/*value*/, A[id]);
207
+ }
208
+ ```
209
+
210
+ ## Convenience macros
211
+
212
+ If you use `DBG_PARAM` in your kernel signature, the short macros work:
213
+
214
+ ```metal
215
+ kernel void my_kernel(device float *A [[buffer(0)]], DBG_PARAM,
216
+ uint id [[thread_position_in_grid]]) {
217
+ dbg(id, 0, A[id]); // printf
218
+ dbg_if(id == 0, id, 1, A[id]); // conditional
219
+ dbg_nan(id, 2, A[id]); // NaN watchpoint
220
+ dbg_check(id, 3, A[id] > 0); // assertion
221
+ dbg_stat(0, A[id]); // stats accumulator
222
+ dbg_hist(0, A[id], 0, 100); // histogram
223
+ dbg_brk(id, 4, A[id] < 0); // breakpoint
224
+ }
225
+ ```
226
+
227
+ ## License
228
+
229
+ MIT
@@ -0,0 +1,9 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ python/metal_debug.py
5
+ python/metal_debug.egg-info/PKG-INFO
6
+ python/metal_debug.egg-info/SOURCES.txt
7
+ python/metal_debug.egg-info/dependency_links.txt
8
+ python/metal_debug.egg-info/requires.txt
9
+ python/metal_debug.egg-info/top_level.txt
@@ -0,0 +1,7 @@
1
+ numpy
2
+
3
+ [torch]
4
+ torch
5
+
6
+ [tui]
7
+ textual>=1.0
@@ -0,0 +1 @@
1
+ metal_debug
@@ -0,0 +1,377 @@
1
+ """
2
+ metal_debug — Python/PyTorch wrapper for Metal compute shader debugging.
3
+
4
+ Usage:
5
+ from metal_debug import MetalDebugSession
6
+
7
+ dbg = MetalDebugSession(max_entries=1024)
8
+
9
+ # Your kernel gets the debug buffer as an extra argument at buffer(30)
10
+ # Pass dbg.tensor as that argument
11
+ my_kernel(A, B, C, ..., dbg.tensor, threads=..., group_size=...)
12
+ torch.mps.synchronize()
13
+
14
+ dbg.dump() # print all entries
15
+ dbg.dump(tag=2) # filter by tag
16
+ dbg.dump(thread=0) # filter by thread
17
+ dbg.stats(tag=0) # min/max/mean/count
18
+ dbg.histogram(tag=0, lo=0, hi=100)
19
+ dbg.grid(tag=50, width=4, height=4)
20
+ dbg.diff(other_session) # compare two runs
21
+
22
+ if dbg.breakpoint_hit:
23
+ dbg.dump_breakpoint()
24
+
25
+ Compile your .metal with:
26
+ xcrun metal -DMETAL_DEBUG_MAX_ENTRIES=1024 -I path/to/metal-debug/src ...
27
+ """
28
+
29
+ import torch
30
+ import struct
31
+ import json
32
+ from pathlib import Path
33
+ from typing import Optional
34
+
35
+
36
+ # Must match GPU-side constants
37
+ METAL_DEBUG_HIST_BINS = 32
38
+ METAL_DEBUG_STATS_TAGS = 256
39
+
40
+ TYPE_FLOAT = 0
41
+ TYPE_INT = 1
42
+ TYPE_UINT = 2
43
+ TYPE_HALF = 3
44
+ TYPE_ASSERT_FAIL = 4
45
+
46
+
47
def _decode_value(type_id: int, bits: int) -> str:
    """Render one debug-entry payload as human-readable text.

    Args:
        type_id: One of the ``TYPE_*`` codes stored with the entry.
        bits: The raw 32-bit payload as an unsigned integer.

    Returns:
        The formatted value. Type codes that are not recognised fall back to
        the payload in zero-padded hexadecimal.
    """
    if type_id == TYPE_FLOAT:
        (as_float,) = struct.unpack('f', struct.pack('I', bits))
        return f"{as_float:.6g}"

    if type_id == TYPE_INT:
        # Reinterpret the unsigned word as a signed 32-bit integer.
        (as_int,) = struct.unpack('i', struct.pack('I', bits))
        return str(as_int)

    if type_id == TYPE_UINT:
        return str(bits)

    if type_id == TYPE_HALF:
        # Only the low 16 bits carry the half-precision payload.
        (as_half,) = struct.unpack('e', struct.pack('H', bits & 0xFFFF))
        return f"{as_half:.4g} (half)"

    if type_id == TYPE_ASSERT_FAIL:
        if bits == 0:
            return "ASSERTION FAILED"
        if bits == 0xDEAD:
            return "BREAKPOINT"
        # Any other payload is shown as the float it encodes.
        (observed,) = struct.unpack('f', struct.pack('I', bits))
        return f"ASSERTION FAILED (val={observed:.6g})"

    return f"0x{bits:08x}"
81
+
82
+
83
+ def _decode_float(bits: int) -> float:
84
+ return struct.unpack('f', struct.pack('I', bits))[0]
85
+
86
+
87
+ class MetalDebugSession:
88
+ """Host-side debug session for Metal compute shaders."""
89
+
90
def __init__(self, max_entries: int = 1024, device: str = "mps"):
    """Lay out and allocate the debug buffer, then seed it.

    Args:
        max_entries: Capacity of the entry log; ``_init_buffer`` also
            stores it in the buffer header.
        device: Torch device string the buffer tensor is created on.
    """
    self.max_entries = max_entries
    self._device = device
    self._tag_names: dict[int, str] = {}

    # Region sizes, in 32-bit words.
    header_words = 2                           # [0]=counter, [1]=max_entries
    entry_words = max_entries * 4              # four words per entry
    stats_words = METAL_DEBUG_STATS_TAGS * 4   # four words per stats slot
    hist_words = 16 * METAL_DEBUG_HIST_BINS    # room for 16 histograms
    break_words = 3                            # trailing words at _break_base

    # Each region starts where the previous one ends.
    self._entries_base = header_words
    self._stats_base = self._entries_base + entry_words
    self._hist_base = self._stats_base + stats_words
    self._break_base = self._hist_base + hist_words

    # Stored as int32; readers view the same words as uint32.
    self.tensor = torch.zeros(
        self._break_base + break_words, dtype=torch.int32, device=device
    )
    self._init_buffer()
106
+
107
def _init_buffer(self):
    """Seed the buffer header and every per-tag stats slot.

    Stores ``max_entries`` in word 1, sets word +1 of each stats slot to the
    bit pattern of the largest finite float32 and word +2 to that of its
    negation, then rebinds ``self.tensor`` to the updated data on
    ``self._device``.
    """
    words = self.tensor.cpu().numpy().view('uint32')

    # The GPU reads the capacity from the buffer at runtime.
    words[1] = self.max_entries

    # stats() reads word +1 as the minimum and word +2 as the maximum, so the
    # minimum starts at +FLT_MAX and the maximum at -FLT_MAX.
    largest_bits = struct.unpack('I', struct.pack('f', 3.4028235e+38))[0]
    smallest_bits = struct.unpack('I', struct.pack('f', -3.4028235e+38))[0]
    stats_end = self._stats_base + METAL_DEBUG_STATS_TAGS * 4
    for slot in range(self._stats_base, stats_end, 4):
        words[slot + 1] = largest_bits
        words[slot + 2] = smallest_bits

    self.tensor = torch.from_numpy(words.view('int32').copy()).to(self._device)
121
+
122
def reset(self):
    """Zero the buffer and re-seed it so the session can be reused.

    ``self.tensor`` is rebound to a new tensor object in the process.
    """
    blank = torch.zeros_like(self.tensor)
    self.tensor = blank
    self._init_buffer()
126
+
127
+ # ── Tag names ────────────────────────────────────────────────────────────
128
+
129
+ def set_tag_name(self, tag: int, name: str):
130
+ self._tag_names[tag] = name
131
+
132
+ def load_tag_names(self, path: str):
133
+ """Load tag names from a .tags.json file (generated by preprocessor)."""
134
+ with open(path) as f:
135
+ data = json.load(f)
136
+ for tag_str, name in data.items():
137
+ self._tag_names[int(tag_str)] = name
138
+
139
+ def _tag_label(self, tag: int) -> str:
140
+ name = self._tag_names.get(tag)
141
+ return f"{name}({tag})" if name else str(tag)
142
+
143
+ # ── Read entries ─────────────────────────────────────────────────────────
144
+
145
+ def _buf(self):
146
+ """Get buffer as CPU uint32 numpy array."""
147
+ return self.tensor.cpu().numpy().view('uint32')
148
+
149
+ def entry_count(self) -> int:
150
+ buf = self._buf()
151
+ return min(int(buf[0]), self.max_entries)
152
+
153
def entries(self) -> list[dict]:
    """Decode every stored entry into a dictionary.

    Returns:
        One dict per entry, in buffer order, with the keys ``thread``,
        ``tag``, ``type``, ``value_bits``, ``value`` (formatted text) and
        ``tag_name`` (display label).
    """
    words = self._buf()
    # The counter can exceed the capacity; only max_entries records exist.
    stored = min(int(words[0]), self.max_entries)

    decoded = []
    for index in range(stored):
        offset = self._entries_base + index * 4
        # Each record is four consecutive words.
        thread_id, tag, type_id, bits = (int(words[offset + k]) for k in range(4))
        decoded.append({
            'thread': thread_id,
            'tag': tag,
            'type': type_id,
            'value_bits': bits,
            'value': _decode_value(type_id, bits),
            'tag_name': self._tag_label(tag),
        })
    return decoded
169
+
170
+ # ── Dump ─────────────────────────────────────────────────────────────────
171
+
172
def dump(self, tag: Optional[int] = None, thread: Optional[int] = None,
         thread_range: Optional[tuple[int, int]] = None):
    """Print debug entries, optionally filtered.

    Args:
        tag: Keep only entries carrying this tag.
        thread: Keep only entries written by this thread.
        thread_range: Inclusive ``(lo, hi)`` bounds on the thread id.

    All given filters are combined. Output is ordered by thread, then tag,
    and the header notes an overflow when the counter exceeds ``max_entries``.
    """
    selected = self.entries()

    if tag is not None:
        selected = list(filter(lambda e: e['tag'] == tag, selected))
    if thread is not None:
        selected = list(filter(lambda e: e['thread'] == thread, selected))
    if thread_range is not None:
        first, last = thread_range
        selected = list(filter(lambda e: first <= e['thread'] <= last, selected))

    selected = sorted(selected, key=lambda e: (e['thread'], e['tag']))

    # Word 0 counts every attempted write, including those that did not fit.
    attempted = self._buf()[0]
    header = f"[metal-debug] {len(selected)} entries"
    if attempted > self.max_entries:
        header += f" (OVERFLOW: {attempted} attempted)"
    print(header)

    for row in selected:
        print(f" thread[{row['thread']}] {row['tag_name']}: {row['value']}")
194
+
195
+ # ── Assertions ───────────────────────────────────────────────────────────
196
+
197
@property
def has_assertion_failures(self) -> bool:
    """True when at least one stored entry has type ``TYPE_ASSERT_FAIL``.

    NOTE(review): ``_decode_value`` renders a payload of 0xDEAD under this
    same type as "BREAKPOINT", so such entries count here too - confirm that
    is intended.
    """
    for entry in self.entries():
        if entry['type'] == TYPE_ASSERT_FAIL:
            return True
    return False
200
+
201
def dump_assertions(self):
    """Print every entry of type ``TYPE_ASSERT_FAIL``, or a notice when there are none."""
    failed = list(filter(lambda e: e['type'] == TYPE_ASSERT_FAIL, self.entries()))
    if len(failed) == 0:
        print("[metal-debug] No assertion failures.")
        return

    print("[metal-debug] ASSERTION FAILURES:")
    for row in failed:
        print(f" thread[{row['thread']}] {row['tag_name']}: {row['value']}")
209
+
210
+ # ── Stats ────────────────────────────────────────────────────────────────
211
+
212
def stats(self, tag: int) -> dict:
    """Read the running statistics stored for `tag` in the debug buffer.

    Each tag owns four consecutive words starting at
    self._stats_base + tag * 4: count, min bits, max bits and a sum word.

    Returns:
        A dict with keys 'count', 'min', 'max' and 'mean'. When the count
        word is zero, every field is 0 and no other word is read.
    """
    words = self._buf()
    slot = self._stats_base + tag * 4

    samples = int(words[slot])
    if not samples:
        return dict(count=0, min=0, max=0, mean=0)

    lowest = _decode_float(int(words[slot + 1]))
    highest = _decode_float(int(words[slot + 2]))
    # Reinterpret the unsigned sum word as a signed 32-bit integer.
    (signed_sum,) = struct.unpack('i', struct.pack('I', int(words[slot + 3])))
    # The sum is divided by 1024 before averaging (fixed-point scale used by the original code).
    average = signed_sum / 1024.0 / samples

    return {'count': samples, 'min': lowest, 'max': highest, 'mean': average}
225
+
226
def dump_stats(self, tag: int):
    """Print a one-line summary of the statistics returned by self.stats(tag).

    Prints "no data" when the reported count is zero; otherwise prints
    count, min, max and mean, the last three with 6 significant digits.
    """
    summary = self.stats(tag)
    name = self._tag_label(tag)
    prefix = f"[metal-debug] stats {name}: "

    if summary['count'] == 0:
        print(prefix + "no data")
        return

    details = "count={count}, min={min:.6g}, max={max:.6g}, mean={mean:.6g}".format(**summary)
    print(prefix + details)
235
+
236
+ # ── Histogram ────────────────────────────────────────────────────────────
237
+
238
def histogram(self, tag: int, lo: float, hi: float):
    """Print a text histogram of the bin counters stored for `tag`.

    The counters are the METAL_DEBUG_HIST_BINS consecutive buffer words
    starting at self._hist_base + tag * METAL_DEBUG_HIST_BINS.

    Args:
        tag: which tag's counters to read.
        lo: value labelled as the lower edge of the first bin.
        hi: value labelled as the upper edge of the last bin.

    Empty bins are omitted. Bars are scaled so the fullest bin is 40
    characters wide, and any printed bin gets at least one character.
    """
    words = self._buf()
    first = self._hist_base + tag * METAL_DEBUG_HIST_BINS

    counts = [int(words[first + k]) for k in range(METAL_DEBUG_HIST_BINS)]
    n_values = sum(counts)
    if n_values == 0:
        print(f"[metal-debug] histogram {self._tag_label(tag)}: no data")
        return

    tallest = max(counts)
    step = (hi - lo) / METAL_DEBUG_HIST_BINS
    full_bar = 40

    print(f"[metal-debug] histogram {self._tag_label(tag)} "
          f"({n_values} values, [{lo:.4g}, {hi:.4g}]):")

    for index, hits in enumerate(counts):
        if hits == 0:
            continue
        left = lo + index * step
        right = left + step
        if hits > 0:
            width = max(1, int(hits / tallest * full_bar))
        else:
            width = 0
        print(f" [{left:7.3f}, {right:7.3f}) {hits:6d} |{'█' * width}")
263
+
264
+ # ── Breakpoints ──────────────────────────────────────────────────────────
265
+
266
@property
def breakpoint_hit(self) -> bool:
    """True when the buffer word at self._break_base is non-zero.

    The comparison result is passed through bool() so the property returns
    a built-in bool, as annotated, even when the buffer yields a
    NumPy-style scalar whose comparison produces a non-builtin boolean.
    """
    flag = self._buf()[self._break_base]
    return bool(flag != 0)
270
+
271
def dump_breakpoint(self):
    """Print the first recorded breakpoint and the entries of the thread that hit it.

    The breakpoint record is read from three consecutive buffer words at
    self._break_base: a non-zero flag, the thread id, and the tag. When the
    flag is zero, only a "no breakpoint" line is printed.

    After the thread's entries (ordered by tag), a trailer line is printed
    if more than one entry has type TYPE_ASSERT_FAIL with value_bits 0xDEAD.
    """
    buf = self._buf()
    if buf[self._break_base] == 0:
        print("[metal-debug] No breakpoint hit.")
        return

    tid = int(buf[self._break_base + 1])
    tag = int(buf[self._break_base + 2])

    print("[metal-debug] *** BREAKPOINT HIT ***")
    print(f" First trigger: thread[{tid}] {self._tag_label(tag)}")
    print(" Debug state at break:")

    # Decode the entry list once and reuse it for both the per-thread listing
    # and the hit count, so the two views come from the same snapshot.
    all_entries = self.entries()

    thread_entries = sorted(
        (e for e in all_entries if e['thread'] == tid),
        key=lambda e: e['tag'],
    )
    for e in thread_entries:
        print(f" {e['tag_name']}: {e['value']}")

    # Entries of type TYPE_ASSERT_FAIL whose raw value is 0xDEAD are counted as breakpoint hits.
    hit_count = sum(1 for e in all_entries
                    if e['type'] == TYPE_ASSERT_FAIL and e['value_bits'] == 0xDEAD)
    if hit_count > 1:
        print(f" ({hit_count} threads hit this breakpoint)")
295
+
296
+ # ── Grid view ────────────────────────────────────────────────────────────
297
+
298
def grid(self, tag: int, width: int, height: int):
    """Print the values recorded for `tag` as a width x height table.

    Thread ids are mapped to cells row by row (thread = y * width + x).
    When a thread recorded the tag more than once, the last entry wins.
    Cells with no entry show "·". Column width follows the longest value,
    clamped to between 3 and 10 characters; longer values are truncated.
    """
    # A later entry for the same thread overwrites the earlier one.
    latest = {rec['thread']: rec['value'] for rec in self.entries() if rec['tag'] == tag}

    if not latest:
        print(f"[metal-debug] grid {self._tag_label(tag)}: no data")
        return

    longest = max(len(text) for text in latest.values())
    col_w = min(max(3, longest), 10)

    print(f"[metal-debug] grid {self._tag_label(tag)} ({width}x{height}):")

    columns = range(width)
    # Column-index header followed by a rule of the same cell widths.
    print(" " + "".join(f" {x:>{col_w}}" for x in columns))
    print(" " + "".join(" " + "─" * col_w for _ in columns))

    for y in range(height):
        cells = []
        for x in columns:
            text = latest.get(y * width + x, "·")
            # Slicing is a no-op for values that already fit the column.
            cells.append(f" {text[:col_w]:>{col_w}}")
        print(f" {y:3d}│" + "".join(cells))
330
+
331
+ # ── Diff ─────────────────────────────────────────────────────────────────
332
+
333
def snapshot(self, label: str = "") -> tuple[list[dict], str]:
    """Capture the current entries together with a label.

    Returns:
        A (entries, label) pair, where entries is whatever self.entries()
        returns at the time of the call.
    """
    captured = self.entries()
    return captured, label
335
+
336
+ # ── Interactive TUI ────────────────────────────────────────────────────
337
+
338
def explore(self, grid_width: int = 0, grid_height: int = 0,
            hist_tag: int = 0, hist_lo: float = 0, hist_hi: float = 100):
    """Launch interactive TUI explorer.

    Imports `explore` from the `tui` module at call time and hands it this
    object together with all five options as keyword arguments.
    """
    # Imported here, not at module load, so `tui` is only needed when this is called.
    from tui import explore as launch_tui

    options = {
        'grid_width': grid_width,
        'grid_height': grid_height,
        'hist_tag': hist_tag,
        'hist_lo': hist_lo,
        'hist_hi': hist_hi,
    }
    launch_tui(self, **options)
344
+
345
+ # ── Diff ─────────────────────────────────────────────────────────────────
346
+
347
@staticmethod
def diff(snap_a: tuple, snap_b: tuple):
    """Print a line-by-line comparison of two snapshots.

    Args:
        snap_a: (entries, label) pair, as returned by snapshot().
        snap_b: (entries, label) pair to compare against snap_a.

    Entries are matched on (thread, tag); when a snapshot holds several
    entries for the same pair, the last one is used. Each pair is reported
    as removed ("-"), added ("+") or changed ("~"); identical pairs are only
    counted. A summary line with the four counts is printed at the end.

    Presence is decided by key membership, not by the truthiness of the
    value, so falsy values such as "" or 0 are classified correctly.
    """
    entries_a, label_a = snap_a
    entries_b, label_b = snap_b

    map_a = {(e['thread'], e['tag']): e['value'] for e in entries_a}
    map_b = {(e['thread'], e['tag']): e['value'] for e in entries_b}

    all_keys = sorted(set(map_a) | set(map_b))

    print(f'[metal-debug] diff: "{label_a or "A"}" vs "{label_b or "B"}"')

    added = removed = changed = same = 0
    for key in all_keys:
        tid, tag = key
        in_a = key in map_a
        in_b = key in map_b
        if in_a and not in_b:
            print(f" - thread[{tid}] tag={tag}: {map_a[key]}")
            removed += 1
        elif in_b and not in_a:
            print(f" + thread[{tid}] tag={tag}: {map_b[key]}")
            added += 1
        elif map_a[key] != map_b[key]:
            print(f" ~ thread[{tid}] tag={tag}: {map_a[key]} → {map_b[key]}")
            changed += 1
        else:
            same += 1

    print(f"\n Summary: {same} same, {changed} changed, "
          f"{added} added, {removed} removed")
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+