stile-verifier 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stile_verifier-0.1.0/.gitignore +72 -0
- stile_verifier-0.1.0/LICENSE +21 -0
- stile_verifier-0.1.0/PKG-INFO +193 -0
- stile_verifier-0.1.0/README.md +156 -0
- stile_verifier-0.1.0/pyproject.toml +76 -0
- stile_verifier-0.1.0/stile/__init__.py +65 -0
- stile_verifier-0.1.0/stile/frozen_counter.py +135 -0
- stile_verifier-0.1.0/stile/indexing.py +738 -0
- stile_verifier-0.1.0/stile/jax/__init__.py +20 -0
- stile_verifier-0.1.0/stile/jax/_core.py +1100 -0
- stile_verifier-0.1.0/stile/jax/pallas/__init__.py +6 -0
- stile_verifier-0.1.0/stile/jax/pallas/_core.py +240 -0
- stile_verifier-0.1.0/stile/jax/random.py +29 -0
- stile_verifier-0.1.0/stile/numpy/__init__.py +13 -0
- stile_verifier-0.1.0/stile/numpy/_core.py +199 -0
- stile_verifier-0.1.0/stile/numpy/random.py +14 -0
- stile_verifier-0.1.0/stile/specification.py +842 -0
- stile_verifier-0.1.0/stile/torch/__init__.py +13 -0
- stile_verifier-0.1.0/stile/torch/_core.py +280 -0
- stile_verifier-0.1.0/stile/torch/random.py +23 -0
- stile_verifier-0.1.0/stile/tracing.py +444 -0
- stile_verifier-0.1.0/stile/triton/__init__.py +5 -0
- stile_verifier-0.1.0/stile/triton/_core.py +1804 -0
- stile_verifier-0.1.0/stile/type.py +655 -0
- stile_verifier-0.1.0/stile/verification.md +146 -0
- stile_verifier-0.1.0/stile/verification.py +2713 -0
- stile_verifier-0.1.0/tests/test_affine_intervals.py +114 -0
- stile_verifier-0.1.0/tests/test_buggy_kernels.py +393 -0
- stile_verifier-0.1.0/tests/test_causal_attention.py +216 -0
- stile_verifier-0.1.0/tests/test_dim_annotation_predicate.py +113 -0
- stile_verifier-0.1.0/tests/test_fused_moe.py +365 -0
- stile_verifier-0.1.0/tests/test_loop_invariants.py +408 -0
- stile_verifier-0.1.0/tests/test_mask_bias_convergence.py +97 -0
- stile_verifier-0.1.0/tests/test_max_tag.py +102 -0
- stile_verifier-0.1.0/tests/test_normalization_inequivalence.py +136 -0
- stile_verifier-0.1.0/tests/test_paged_flash_attention.py +303 -0
- stile_verifier-0.1.0/tests/test_pallas_gpu.py +154 -0
- stile_verifier-0.1.0/tests/test_parametric_reduce.py +305 -0
- stile_verifier-0.1.0/tests/test_qwen_resblock_components.py +159 -0
- stile_verifier-0.1.0/tests/test_rolled_loops.py +277 -0
- stile_verifier-0.1.0/tests/test_runtime_gather.py +90 -0
- stile_verifier-0.1.0/tests/test_runtime_scatter_moe.py +233 -0
- stile_verifier-0.1.0/tests/test_tagged_tensors.py +198 -0
- stile_verifier-0.1.0/tests/test_tjax_mask.py +87 -0
- stile_verifier-0.1.0/tests/test_typed_jax.py +300 -0
- stile_verifier-0.1.0/tests/test_typed_numpy.py +288 -0
- stile_verifier-0.1.0/tests/test_typed_pallas.py +569 -0
- stile_verifier-0.1.0/tests/test_typed_result_symbolic.py +85 -0
- stile_verifier-0.1.0/tests/test_typed_torch.py +291 -0
- stile_verifier-0.1.0/tests/test_typed_triton.py +940 -0
- stile_verifier-0.1.0/tests/test_typed_triton.py.bak +828 -0
- stile_verifier-0.1.0/tests/test_where_clause.py +154 -0
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
/target
|
|
2
|
+
|
|
3
|
+
# Byte-compiled / optimized / DLL files
|
|
4
|
+
__pycache__/
|
|
5
|
+
.pytest_cache/
|
|
6
|
+
*.py[cod]
|
|
7
|
+
|
|
8
|
+
# C extensions
|
|
9
|
+
*.so
|
|
10
|
+
|
|
11
|
+
# Distribution / packaging
|
|
12
|
+
.Python
|
|
13
|
+
.venv/
|
|
14
|
+
env/
|
|
15
|
+
bin/
|
|
16
|
+
build/
|
|
17
|
+
develop-eggs/
|
|
18
|
+
dist/
|
|
19
|
+
eggs/
|
|
20
|
+
lib/
|
|
21
|
+
lib64/
|
|
22
|
+
parts/
|
|
23
|
+
sdist/
|
|
24
|
+
var/
|
|
25
|
+
include/
|
|
26
|
+
man/
|
|
27
|
+
venv/
|
|
28
|
+
*.egg-info/
|
|
29
|
+
.installed.cfg
|
|
30
|
+
*.egg
|
|
31
|
+
|
|
32
|
+
# Installer logs
|
|
33
|
+
pip-log.txt
|
|
34
|
+
pip-delete-this-directory.txt
|
|
35
|
+
pip-selfcheck.json
|
|
36
|
+
|
|
37
|
+
# Unit test / coverage reports
|
|
38
|
+
htmlcov/
|
|
39
|
+
.tox/
|
|
40
|
+
.coverage
|
|
41
|
+
.cache
|
|
42
|
+
nosetests.xml
|
|
43
|
+
coverage.xml
|
|
44
|
+
|
|
45
|
+
# Translations
|
|
46
|
+
*.mo
|
|
47
|
+
|
|
48
|
+
# Mr Developer
|
|
49
|
+
.mr.developer.cfg
|
|
50
|
+
.project
|
|
51
|
+
.pydevproject
|
|
52
|
+
|
|
53
|
+
# Rope
|
|
54
|
+
.ropeproject
|
|
55
|
+
|
|
56
|
+
# Django stuff:
|
|
57
|
+
*.log
|
|
58
|
+
*.pot
|
|
59
|
+
|
|
60
|
+
.DS_Store
|
|
61
|
+
|
|
62
|
+
# Sphinx documentation
|
|
63
|
+
docs/_build/
|
|
64
|
+
|
|
65
|
+
# PyCharm
|
|
66
|
+
.idea/
|
|
67
|
+
|
|
68
|
+
# VSCode
|
|
69
|
+
.vscode/
|
|
70
|
+
|
|
71
|
+
# Pyenv
|
|
72
|
+
.python-version
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Sasha Krassovsky
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: stile-verifier
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A type system for numerical programs: write a spec, get a structural proof your code computes it.
|
|
5
|
+
Project-URL: Homepage, https://github.com/save-buffer/stile
|
|
6
|
+
Project-URL: Repository, https://github.com/save-buffer/stile
|
|
7
|
+
Project-URL: Issues, https://github.com/save-buffer/stile/issues
|
|
8
|
+
Author-email: Sasha Krassovsky <krassovskysasha@gmail.com>
|
|
9
|
+
License-Expression: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Keywords: einsum,jax,kernels,numerical,pallas,pytorch,triton,type-system,verification
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
22
|
+
Requires-Python: >=3.11
|
|
23
|
+
Requires-Dist: einops>=0.8.1
|
|
24
|
+
Requires-Dist: numpy>=1.24.4
|
|
25
|
+
Provides-Extra: all
|
|
26
|
+
Requires-Dist: jax>=0.4.0; extra == 'all'
|
|
27
|
+
Requires-Dist: torch>=2.0.0; extra == 'all'
|
|
28
|
+
Requires-Dist: triton>=3.0.0; extra == 'all'
|
|
29
|
+
Provides-Extra: jax
|
|
30
|
+
Requires-Dist: jax>=0.4.0; extra == 'jax'
|
|
31
|
+
Provides-Extra: torch
|
|
32
|
+
Requires-Dist: torch>=2.0.0; extra == 'torch'
|
|
33
|
+
Provides-Extra: triton
|
|
34
|
+
Requires-Dist: torch>=2.0.0; extra == 'triton'
|
|
35
|
+
Requires-Dist: triton>=3.0.0; extra == 'triton'
|
|
36
|
+
Description-Content-Type: text/markdown
|
|
37
|
+
|
|
38
|
+
# Stile: A Type System for Numerical Programs
|
|
39
|
+
|
|
40
|
+
**Formally Verify Your Numerical Programs**
|
|
41
|
+
|
|
42
|
+
Stile is a type system for numerical programs. Describe what your program should compute in a lightweight specification language and Stile will structurally prove the program's adherence to the spec.
|
|
43
|
+
|
|
44
|
+
A short demo of what this lets you write today:
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
# Spec: causal flash attention, one line.
|
|
48
|
+
SPEC = (
|
|
49
|
+
"(softmax[nctx where nctx <= qctx]"
|
|
50
|
+
"((qctx dhead, nctx dhead -> qctx nctx) / sqrt(16)), "
|
|
51
|
+
"nctx dhead -> qctx dhead)"
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
# Kernel: tile-walk only the lower triangle, online softmax, bias mask
|
|
55
|
+
# on the diagonal tile. Per-qctx-tile structural proof.
|
|
56
|
+
for iqctx in range(0, qctx.size, qctx_tile_size):
|
|
57
|
+
running_max, running_l, o = -jnp.inf, 0, 0
|
|
58
|
+
for ictx in range(0, iqctx + qctx_tile_size, nctx_tile_size):
|
|
59
|
+
q_tile = Q.slice(qctx, iqctx, iqctx + qctx_tile_size)
|
|
60
|
+
k_tile = K.slice(nctx, ictx, ictx + nctx_tile_size)
|
|
61
|
+
qk = tjax.einsum(q_tile, k_tile, "qctx dhead, nctx dhead -> nctx qctx") / jnp.sqrt(dhead.size)
|
|
62
|
+
# Bias-form mask: 0 inside causal region, -inf outside.
|
|
63
|
+
qk = qk + tjax.mask(qk.type.st, "nctx <= qctx", 0.0, -jnp.inf)
|
|
64
|
+
# ... online-softmax accumulator over qk, V ...
|
|
65
|
+
o.assert_equivalent(SPEC, nctx[:(iqctx + qctx_tile_size)]) # proof
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
The verifier proves that the kernel's online softmax, with `-inf` bias on the diagonal tile and skipped upper-triangle tiles, normalizes to the explicit causal-attention spec.
|
|
69
|
+
|
|
70
|
+
## How it works
|
|
71
|
+
|
|
72
|
+
A Stile-typed tensor has two types:
|
|
73
|
+
- **`ShapeType`**: which logical dims the tensor has, and how each is sliced.
|
|
74
|
+
- **`ExprType`**: the AST of operations performed up to this point.
|
|
75
|
+
|
|
76
|
+
`ShapeType` is enforced eagerly: you can only multiply two tensors with matching dim signatures, you must reduce the entire dim before assigning the result, etc.
|
|
77
|
+
|
|
78
|
+
`ExprType` is checked at assignment time. The verifier normalizes both your kernel's accumulated `ExprType` and the spec's parsed `ExprType` into a canonical form and compares. The canonical form folds:
|
|
79
|
+
- algebraic identities (`exp(0) = 1`, `0 + x = x`, `0 * x = 0`, `exp(a-b) = exp(a)/exp(b)`, `exp(-inf) = 0`, …),
|
|
80
|
+
- distribution through tagged tensors (`*` and `+` push through `Cond(D, …)` branches),
|
|
81
|
+
- iteration-domain folding (`sum(body * Cond(D, 1, 0))` collapses to a sum over `[0, N) ∩ D`),
|
|
82
|
+
- adjacent-tile interval merging on sum and max reductions, including with shared cross-variable predicates,
|
|
83
|
+
- max push-through and `-inf` absorption (so bias-form masks converge with multiplicative-form masks),
|
|
84
|
+
- post-fold invariant hoisting (the piece that lets the kernel's `exp(max)` rescaling factor cancel between numerator and denominator).
|
|
85
|
+
|
|
86
|
+
If the two normalize to the same expression, they compute the same function.
|
|
87
|
+
|
|
88
|
+
## The specification language
|
|
89
|
+
|
|
90
|
+
A specification is a small expression over named dims. Tensors are written as the sequence of their dims (`Q D` is a tensor of shape `Q × D`). Slices use `D[a:b]`. Einsums use `,` to separate operands and `->` to give the output shape.
|
|
91
|
+
|
|
92
|
+
```
|
|
93
|
+
Q D K # a tensor with three dims
|
|
94
|
+
Q D[0:8] # tensor with D sliced to [0, 8)
|
|
95
|
+
2 * Q D # scaled tensor
|
|
96
|
+
A B + A B # addition (same-shape required)
|
|
97
|
+
exp(Q D) # elementwise unary
|
|
98
|
+
(Q D, K D -> Q K) # einsum: sum over D
|
|
99
|
+
sum[K](Q K) # explicit reduction
|
|
100
|
+
sum[K where K <= Q](Q K) # iteration-restricted sum (mult-mask)
|
|
101
|
+
max[K where K <= Q](Q K) # iteration-restricted max (bias-mask)
|
|
102
|
+
softmax[K](Q K) # softmax along K
|
|
103
|
+
softmax[K where K <= Q](Q K) # causal softmax — restricts both num and den
|
|
104
|
+
sum[N](Q N where N >= 4) -> Q # mult-mask `where`-clause inside sum
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
**`[d where P]`** restricts the iteration of a reduction by the affine predicate `P`. Lowers based on the surrounding op:
|
|
108
|
+
- `sum[d where P]` uses a multiplicative mask (`1` where `P` holds, `0` elsewhere).
|
|
109
|
+
- `max[d where P]` and `softmax[d where P]` use a bias mask (`-inf` elsewhere) — masked positions vanish because `-inf` is the identity of `max` and `exp(-inf) = 0`, restricting both the numerator's exp and the denominator's sum.
|
|
110
|
+
|
|
111
|
+
**`body where P`** (outside a dim annotation) is always a multiplicative mask on `body`. Use it for non-reduction sparsity and bias-on-output patterns. Inside a `sum` it folds into the reduce's domain via mask-extraction; outside it stays as a `Cond` tag.
|
|
112
|
+
|
|
113
|
+
**Predicates** are conjunctions of affine inequalities over dim names: `<=`, `<`, `>=`, `>`, `==`, plus `+`, `-`, and `int * dim`. Cross-dim predicates (`nctx <= qctx`) are first-class — they ride along in the reduce's domain and survive interval merging.
|
|
114
|
+
|
|
115
|
+
## Kernel primitives
|
|
116
|
+
|
|
117
|
+
The kernel side mirrors a small slice of JAX/Numpy. Today's main backend is `stile.jax` (`tjax`); a numpy backend exists for prototyping.
|
|
118
|
+
|
|
119
|
+
```python
|
|
120
|
+
from stile import dim
|
|
121
|
+
import stile.jax as tjax
|
|
122
|
+
|
|
123
|
+
# Dims live in a global registry so specs and kernels share names.
|
|
124
|
+
qctx, nctx, dhead = dim('qctx', 128), dim('nctx', 512), dim('dhead', 16)
|
|
125
|
+
|
|
126
|
+
# Wrap concrete arrays with their dim signature.
|
|
127
|
+
Q = tjax.random.normal(key, qctx, dhead) # has ShapeType (qctx, dhead)
|
|
128
|
+
|
|
129
|
+
# Slice. Result's st remembers it's [iqctx, iqctx + 32).
|
|
130
|
+
q_tile = Q.slice(qctx, iqctx, iqctx + 32)
|
|
131
|
+
|
|
132
|
+
# Einsum. Both shapes and the AST track the contraction.
|
|
133
|
+
qk = tjax.einsum(Q, K, "qctx dhead, nctx dhead -> qctx nctx")
|
|
134
|
+
|
|
135
|
+
# Reductions. Either a method or via einsum.
|
|
136
|
+
m = qk.max(nctx)
|
|
137
|
+
s = qk.sum(nctx)
|
|
138
|
+
|
|
139
|
+
# Unary functions on TypedJaxArrays.
|
|
140
|
+
e = tjax.exp(qk - m.repeat(nctx))
|
|
141
|
+
|
|
142
|
+
# Multiplicative mask sugar — score * Cond(P, 1, 0).
|
|
143
|
+
masked = score.where("nctx <= qctx")
|
|
144
|
+
|
|
145
|
+
# Tagged-constant tensor — picks 0/1, 0/-inf, etc.
|
|
146
|
+
mult_mask = tjax.mask(score.type.st, "nctx <= qctx") # 1 inside / 0 outside
|
|
147
|
+
bias_mask = tjax.mask(score.type.st, "nctx <= qctx", 0.0, -jnp.inf) # 0 / -inf
|
|
148
|
+
|
|
149
|
+
# Rolled loops. Concrete bounds unroll; symbolic bounds emit a parametric reduce.
|
|
150
|
+
total = tjax.fori_loop(0, n, lambda i, acc: acc + body(i), init_val=0.0)
|
|
151
|
+
|
|
152
|
+
# Verify against a spec.
|
|
153
|
+
result = tjax.TypedResult(SPEC)
|
|
154
|
+
result.assign(o) # full-coverage type check
|
|
155
|
+
o.assert_equivalent(SPEC, nctx[:K]) # per-tile check with a slice override
|
|
156
|
+
result.done() # tile-coverage check (no gaps/overlaps)
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
## Status
|
|
160
|
+
|
|
161
|
+
Working today, with full structural verification:
|
|
162
|
+
- **Backends**: `stile.jax` (primary), `stile.numpy` (prototype).
|
|
163
|
+
- **Verified kernels**: matmul, online softmax, full flash attention, **tile-walking causal flash attention** (online softmax with bias-mask; structurally proven equivalent to a one-line `softmax[k where k<=q]` spec).
|
|
164
|
+
- **Spec features**: einsums, slices, reductions (`sum`, `max`, `softmax`), unary (`exp`, `sin`, `cos`, `sqrt`), multiplicative `where`-clauses, iteration-restricted `[d where P]` annotations, affine predicates with cross-dim references.
|
|
165
|
+
- **Kernel features**: slicing, einsum, all the unary/binary ops, `repeat`, `rearrange`, `fori_loop` (concrete-unroll path; symbolic-loop path with parametric reductions), `mask` intrinsic and `.where(...)` sugar.
|
|
166
|
+
|
|
167
|
+
In progress / future:
|
|
168
|
+
- **TypedPallas**: same type discipline, lowering to Pallas for actual GPU/TPU codegen.
|
|
169
|
+
- **TypedTorch**: a Torch backend exists but lags the JAX one.
|
|
170
|
+
|
|
171
|
+
## Why a type system?
|
|
172
|
+
|
|
173
|
+
Way back in the 1950s, before high-level languages, programs took big arrays of bytes as input and outputted other arrays of bytes. It was up to the programmer to remember in his head which bytes corresponded
|
|
174
|
+
to which semantic piece of the program. The invention of type annotations, structures, etc., was a step change in programming productivity because it gave semantic meanings to specific regions of memory.
|
|
175
|
+
|
|
176
|
+
The current state of the art for numerical programs is not unlike the 1950s mode of programming. We take big multidimensional arrays of floats, and output big multidimensional arrays of floats. The dimensions
|
|
177
|
+
are not semantically enforced to be different, and it's up to the programmer to remember the order of dimensions at all times. Mixing them is all too easy. Then assuming a program actually completed, you
|
|
178
|
+
have another big array of floats, and you have no idea if it's right or not. If it doesn't give you the expected result, debugging a numerical program is a massive time-sink. The solution is therefore to
|
|
179
|
+
add guardrails to prevent making stupid mistakes, in other words a type system.
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
## Running
|
|
183
|
+
|
|
184
|
+
```bash
|
|
185
|
+
uv run pytest tests/
|
|
186
|
+
```
|
|
187
|
+
|
|
188
|
+
Backend extras:
|
|
189
|
+
|
|
190
|
+
```bash
|
|
191
|
+
uv pip install -e ".[jax]" # JAX backend
|
|
192
|
+
uv pip install -e ".[torch]" # Torch backend (lagging)
|
|
193
|
+
```
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
# Stile: A Type System for Numerical Programs
|
|
2
|
+
|
|
3
|
+
**Formally Verify Your Numerical Programs**
|
|
4
|
+
|
|
5
|
+
Stile is a type system for numerical programs. Describe what your program should compute in a lightweight specification language and Stile will structurally prove the program's adherence to the spec.
|
|
6
|
+
|
|
7
|
+
A short demo of what this lets you write today:
|
|
8
|
+
|
|
9
|
+
```python
|
|
10
|
+
# Spec: causal flash attention, one line.
|
|
11
|
+
SPEC = (
|
|
12
|
+
"(softmax[nctx where nctx <= qctx]"
|
|
13
|
+
"((qctx dhead, nctx dhead -> qctx nctx) / sqrt(16)), "
|
|
14
|
+
"nctx dhead -> qctx dhead)"
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Kernel: tile-walk only the lower triangle, online softmax, bias mask
|
|
18
|
+
# on the diagonal tile. Per-qctx-tile structural proof.
|
|
19
|
+
for iqctx in range(0, qctx.size, qctx_tile_size):
|
|
20
|
+
running_max, running_l, o = -jnp.inf, 0, 0
|
|
21
|
+
for ictx in range(0, iqctx + qctx_tile_size, nctx_tile_size):
|
|
22
|
+
q_tile = Q.slice(qctx, iqctx, iqctx + qctx_tile_size)
|
|
23
|
+
k_tile = K.slice(nctx, ictx, ictx + nctx_tile_size)
|
|
24
|
+
qk = tjax.einsum(q_tile, k_tile, "qctx dhead, nctx dhead -> nctx qctx") / jnp.sqrt(dhead.size)
|
|
25
|
+
# Bias-form mask: 0 inside causal region, -inf outside.
|
|
26
|
+
qk = qk + tjax.mask(qk.type.st, "nctx <= qctx", 0.0, -jnp.inf)
|
|
27
|
+
# ... online-softmax accumulator over qk, V ...
|
|
28
|
+
o.assert_equivalent(SPEC, nctx[:(iqctx + qctx_tile_size)]) # proof
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
The verifier proves that the kernel's online softmax, with `-inf` bias on the diagonal tile and skipped upper-triangle tiles, normalizes to the explicit causal-attention spec.
|
|
32
|
+
|
|
33
|
+
## How it works
|
|
34
|
+
|
|
35
|
+
A Stile-typed tensor has two types:
|
|
36
|
+
- **`ShapeType`**: which logical dims the tensor has, and how each is sliced.
|
|
37
|
+
- **`ExprType`**: the AST of operations performed up to this point.
|
|
38
|
+
|
|
39
|
+
`ShapeType` is enforced eagerly: you can only multiply two tensors with matching dim signatures, you must reduce the entire dim before assigning the result, etc.
|
|
40
|
+
|
|
41
|
+
`ExprType` is checked at assignment time. The verifier normalizes both your kernel's accumulated `ExprType` and the spec's parsed `ExprType` into a canonical form and compares. The canonical form folds:
|
|
42
|
+
- algebraic identities (`exp(0) = 1`, `0 + x = x`, `0 * x = 0`, `exp(a-b) = exp(a)/exp(b)`, `exp(-inf) = 0`, …),
|
|
43
|
+
- distribution through tagged tensors (`*` and `+` push through `Cond(D, …)` branches),
|
|
44
|
+
- iteration-domain folding (`sum(body * Cond(D, 1, 0))` collapses to a sum over `[0, N) ∩ D`),
|
|
45
|
+
- adjacent-tile interval merging on sum and max reductions, including with shared cross-variable predicates,
|
|
46
|
+
- max push-through and `-inf` absorption (so bias-form masks converge with multiplicative-form masks),
|
|
47
|
+
- post-fold invariant hoisting (the piece that lets the kernel's `exp(max)` rescaling factor cancel between numerator and denominator).
|
|
48
|
+
|
|
49
|
+
If the two normalize to the same expression, they compute the same function.
|
|
50
|
+
|
|
51
|
+
## The specification language
|
|
52
|
+
|
|
53
|
+
A specification is a small expression over named dims. Tensors are written as the sequence of their dims (`Q D` is a tensor of shape `Q × D`). Slices use `D[a:b]`. Einsums use `,` to separate operands and `->` to give the output shape.
|
|
54
|
+
|
|
55
|
+
```
|
|
56
|
+
Q D K # a tensor with three dims
|
|
57
|
+
Q D[0:8] # tensor with D sliced to [0, 8)
|
|
58
|
+
2 * Q D # scaled tensor
|
|
59
|
+
A B + A B # addition (same-shape required)
|
|
60
|
+
exp(Q D) # elementwise unary
|
|
61
|
+
(Q D, K D -> Q K) # einsum: sum over D
|
|
62
|
+
sum[K](Q K) # explicit reduction
|
|
63
|
+
sum[K where K <= Q](Q K) # iteration-restricted sum (mult-mask)
|
|
64
|
+
max[K where K <= Q](Q K) # iteration-restricted max (bias-mask)
|
|
65
|
+
softmax[K](Q K) # softmax along K
|
|
66
|
+
softmax[K where K <= Q](Q K) # causal softmax — restricts both num and den
|
|
67
|
+
sum[N](Q N where N >= 4) -> Q # mult-mask `where`-clause inside sum
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
**`[d where P]`** restricts the iteration of a reduction by the affine predicate `P`. Lowers based on the surrounding op:
|
|
71
|
+
- `sum[d where P]` uses a multiplicative mask (`1` where `P` holds, `0` elsewhere).
|
|
72
|
+
- `max[d where P]` and `softmax[d where P]` use a bias mask (`-inf` elsewhere) — masked positions vanish because `-inf` is the identity of `max` and `exp(-inf) = 0`, restricting both the numerator's exp and the denominator's sum.
|
|
73
|
+
|
|
74
|
+
**`body where P`** (outside a dim annotation) is always a multiplicative mask on `body`. Use it for non-reduction sparsity and bias-on-output patterns. Inside a `sum` it folds into the reduce's domain via mask-extraction; outside it stays as a `Cond` tag.
|
|
75
|
+
|
|
76
|
+
**Predicates** are conjunctions of affine inequalities over dim names: `<=`, `<`, `>=`, `>`, `==`, plus `+`, `-`, and `int * dim`. Cross-dim predicates (`nctx <= qctx`) are first-class — they ride along in the reduce's domain and survive interval merging.
|
|
77
|
+
|
|
78
|
+
## Kernel primitives
|
|
79
|
+
|
|
80
|
+
The kernel side mirrors a small slice of JAX/Numpy. Today's main backend is `stile.jax` (`tjax`); a numpy backend exists for prototyping.
|
|
81
|
+
|
|
82
|
+
```python
|
|
83
|
+
from stile import dim
|
|
84
|
+
import stile.jax as tjax
|
|
85
|
+
|
|
86
|
+
# Dims live in a global registry so specs and kernels share names.
|
|
87
|
+
qctx, nctx, dhead = dim('qctx', 128), dim('nctx', 512), dim('dhead', 16)
|
|
88
|
+
|
|
89
|
+
# Wrap concrete arrays with their dim signature.
|
|
90
|
+
Q = tjax.random.normal(key, qctx, dhead) # has ShapeType (qctx, dhead)
|
|
91
|
+
|
|
92
|
+
# Slice. Result's st remembers it's [iqctx, iqctx + 32).
|
|
93
|
+
q_tile = Q.slice(qctx, iqctx, iqctx + 32)
|
|
94
|
+
|
|
95
|
+
# Einsum. Both shapes and the AST track the contraction.
|
|
96
|
+
qk = tjax.einsum(Q, K, "qctx dhead, nctx dhead -> qctx nctx")
|
|
97
|
+
|
|
98
|
+
# Reductions. Either a method or via einsum.
|
|
99
|
+
m = qk.max(nctx)
|
|
100
|
+
s = qk.sum(nctx)
|
|
101
|
+
|
|
102
|
+
# Unary functions on TypedJaxArrays.
|
|
103
|
+
e = tjax.exp(qk - m.repeat(nctx))
|
|
104
|
+
|
|
105
|
+
# Multiplicative mask sugar — score * Cond(P, 1, 0).
|
|
106
|
+
masked = score.where("nctx <= qctx")
|
|
107
|
+
|
|
108
|
+
# Tagged-constant tensor — picks 0/1, 0/-inf, etc.
|
|
109
|
+
mult_mask = tjax.mask(score.type.st, "nctx <= qctx") # 1 inside / 0 outside
|
|
110
|
+
bias_mask = tjax.mask(score.type.st, "nctx <= qctx", 0.0, -jnp.inf) # 0 / -inf
|
|
111
|
+
|
|
112
|
+
# Rolled loops. Concrete bounds unroll; symbolic bounds emit a parametric reduce.
|
|
113
|
+
total = tjax.fori_loop(0, n, lambda i, acc: acc + body(i), init_val=0.0)
|
|
114
|
+
|
|
115
|
+
# Verify against a spec.
|
|
116
|
+
result = tjax.TypedResult(SPEC)
|
|
117
|
+
result.assign(o) # full-coverage type check
|
|
118
|
+
o.assert_equivalent(SPEC, nctx[:K]) # per-tile check with a slice override
|
|
119
|
+
result.done() # tile-coverage check (no gaps/overlaps)
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
## Status
|
|
123
|
+
|
|
124
|
+
Working today, with full structural verification:
|
|
125
|
+
- **Backends**: `stile.jax` (primary), `stile.numpy` (prototype).
|
|
126
|
+
- **Verified kernels**: matmul, online softmax, full flash attention, **tile-walking causal flash attention** (online softmax with bias-mask; structurally proven equivalent to a one-line `softmax[k where k<=q]` spec).
|
|
127
|
+
- **Spec features**: einsums, slices, reductions (`sum`, `max`, `softmax`), unary (`exp`, `sin`, `cos`, `sqrt`), multiplicative `where`-clauses, iteration-restricted `[d where P]` annotations, affine predicates with cross-dim references.
|
|
128
|
+
- **Kernel features**: slicing, einsum, all the unary/binary ops, `repeat`, `rearrange`, `fori_loop` (concrete-unroll path; symbolic-loop path with parametric reductions), `mask` intrinsic and `.where(...)` sugar.
|
|
129
|
+
|
|
130
|
+
In progress / future:
|
|
131
|
+
- **TypedPallas**: same type discipline, lowering to Pallas for actual GPU/TPU codegen.
|
|
132
|
+
- **TypedTorch**: a Torch backend exists but lags the JAX one.
|
|
133
|
+
|
|
134
|
+
## Why a type system?
|
|
135
|
+
|
|
136
|
+
Way back in the 1950s, before high-level languages, programs took big arrays of bytes as input and outputted other arrays of bytes. It was up to the programmer to remember in his head which bytes corresponded
|
|
137
|
+
to which semantic piece of the program. The invention of type annotations, structures, etc., was a step change in programming productivity because it gave semantic meanings to specific regions of memory.
|
|
138
|
+
|
|
139
|
+
The current state of the art for numerical programs is not unlike the 1950s mode of programming. We take big multidimensional arrays of floats, and output big multidimensional arrays of floats. The dimensions
|
|
140
|
+
are not semantically enforced to be different, and it's up to the programmer to remember the order of dimensions at all times. Mixing them is all too easy. Then assuming a program actually completed, you
|
|
141
|
+
have another big array of floats, and you have no idea if it's right or not. If it doesn't give you the expected result, debugging a numerical program is a massive time-sink. The solution is therefore to
|
|
142
|
+
add guardrails to prevent making stupid mistakes, in other words a type system.
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
## Running
|
|
146
|
+
|
|
147
|
+
```bash
|
|
148
|
+
uv run pytest tests/
|
|
149
|
+
```
|
|
150
|
+
|
|
151
|
+
Backend extras:
|
|
152
|
+
|
|
153
|
+
```bash
|
|
154
|
+
uv pip install -e ".[jax]" # JAX backend
|
|
155
|
+
uv pip install -e ".[torch]" # Torch backend (lagging)
|
|
156
|
+
```
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "stile-verifier"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "A type system for numerical programs: write a spec, get a structural proof your code computes it."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
license = "MIT"
|
|
12
|
+
license-files = ["LICENSE"]
|
|
13
|
+
authors = [
|
|
14
|
+
{ name = "Sasha Krassovsky", email = "krassovskysasha@gmail.com" },
|
|
15
|
+
]
|
|
16
|
+
keywords = [
|
|
17
|
+
"verification", "type-system", "numerical", "jax", "pytorch",
|
|
18
|
+
"triton", "pallas", "kernels", "einsum",
|
|
19
|
+
]
|
|
20
|
+
classifiers = [
|
|
21
|
+
"Development Status :: 3 - Alpha",
|
|
22
|
+
"Intended Audience :: Developers",
|
|
23
|
+
"Intended Audience :: Science/Research",
|
|
24
|
+
"Topic :: Scientific/Engineering",
|
|
25
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
26
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
27
|
+
"Programming Language :: Python :: 3",
|
|
28
|
+
"Programming Language :: Python :: 3.11",
|
|
29
|
+
"Programming Language :: Python :: 3.12",
|
|
30
|
+
"Programming Language :: Python :: Implementation :: CPython",
|
|
31
|
+
]
|
|
32
|
+
dependencies = [
|
|
33
|
+
"einops>=0.8.1",
|
|
34
|
+
"numpy>=1.24.4",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
[project.optional-dependencies]
|
|
38
|
+
jax = ["jax>=0.4.0"]
|
|
39
|
+
torch = ["torch>=2.0.0"]
|
|
40
|
+
triton = [
|
|
41
|
+
"torch>=2.0.0",
|
|
42
|
+
"triton>=3.0.0",
|
|
43
|
+
]
|
|
44
|
+
all = [
|
|
45
|
+
"jax>=0.4.0",
|
|
46
|
+
"torch>=2.0.0",
|
|
47
|
+
"triton>=3.0.0",
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
[project.urls]
|
|
51
|
+
Homepage = "https://github.com/save-buffer/stile"
|
|
52
|
+
Repository = "https://github.com/save-buffer/stile"
|
|
53
|
+
Issues = "https://github.com/save-buffer/stile/issues"
|
|
54
|
+
|
|
55
|
+
[tool.hatch.build.targets.wheel]
|
|
56
|
+
packages = ["stile"]
|
|
57
|
+
|
|
58
|
+
[tool.hatch.build.targets.sdist]
|
|
59
|
+
include = [
|
|
60
|
+
"stile/",
|
|
61
|
+
"tests/",
|
|
62
|
+
"README.md",
|
|
63
|
+
"LICENSE",
|
|
64
|
+
"pyproject.toml",
|
|
65
|
+
]
|
|
66
|
+
exclude = [
|
|
67
|
+
"**/__pycache__",
|
|
68
|
+
"*.pyc",
|
|
69
|
+
]
|
|
70
|
+
|
|
71
|
+
[dependency-groups]
|
|
72
|
+
dev = [
|
|
73
|
+
"pytest>=9.0.2",
|
|
74
|
+
"build>=1.2.0",
|
|
75
|
+
"twine>=5.0.0",
|
|
76
|
+
]
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from .verification import verify_exprs_equivalent
|
|
2
|
+
from .specification import parse_spec_into_type
|
|
3
|
+
from .type import (
|
|
4
|
+
Type, FullDim, g_dim_registry, Tensor, Constant, TagCond,
|
|
5
|
+
_reset_tensor_counter,
|
|
6
|
+
)
|
|
7
|
+
from .indexing import (
|
|
8
|
+
SymbolicInt, LoopVariable, AffineExpr, SymbolicIndex, Domain, range_domain,
|
|
9
|
+
LoopScope, loop, active_loop_domain, _active_loop_scopes,
|
|
10
|
+
RuntimeScalar, runtime_scalar, _g_runtime_scalars,
|
|
11
|
+
SymInfo, symint_info, _g_symint_metadata,
|
|
12
|
+
tensor_element,
|
|
13
|
+
declare_index_properties, index_has_property, _g_index_properties,
|
|
14
|
+
declare_block_pairing, paired_index_for_offsets, _g_block_pairings,
|
|
15
|
+
declare_tensor_boundary, tensor_boundary, resolve_symbolic_index,
|
|
16
|
+
_g_tensor_boundaries,
|
|
17
|
+
)
|
|
18
|
+
from .tracing import _g_runtime_arrs
|
|
19
|
+
|
|
20
|
+
def dim(name : str, size : int) -> FullDim:
    """Create a named logical dimension of the given size.

    Convenience wrapper that forwards `name` and `size` positionally to the
    `FullDim` constructor, e.g. `qctx = dim('qctx', 128)`.

    Returns:
        The newly constructed `FullDim`.
    """
    full_dim = FullDim(name, size)
    return full_dim
|
|
22
|
+
|
|
23
|
+
def expr_simplifies(
    expr : Type,
    spec : str,
) -> bool:
    """Check whether `expr`'s accumulated expression matches a spec string.

    The spec is parsed with `parse_spec_into_type`, then `expr.type.et` and
    the parsed spec's `.et` are passed to `verify_exprs_equivalent`, whose
    result is returned unchanged.

    NOTE(review): the body reads `expr.type.et`, so `expr` appears to be a
    typed-array wrapper exposing a `.type` attribute rather than a bare
    `Type` as annotated — confirm and fix the annotation.
    """
    parsed_spec = parse_spec_into_type(spec)
    kernel_et = expr.type.et
    return verify_exprs_equivalent(kernel_et, parsed_spec.et)
|
|
29
|
+
|
|
30
|
+
def reset_stile():
    """Reset Stile's process-global state.

    Calls `.clear()` on each module-level registry/table imported above
    (dim registry, active loop scopes, runtime scalars, symbolic-int
    metadata, index properties, block pairings, tensor boundaries and
    runtime arrays), in that order, then calls `_reset_tensor_counter()`.
    Presumably meant to be invoked between independent verifications so
    state does not leak across them — confirm with callers.
    """
    global_tables = (
        g_dim_registry,
        _active_loop_scopes,
        _g_runtime_scalars,
        _g_symint_metadata,
        _g_index_properties,
        _g_block_pairings,
        _g_tensor_boundaries,
        _g_runtime_arrs,
    )
    for table in global_tables:
        table.clear()
    _reset_tensor_counter()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def mask_expr(
    dims : tuple[FullDim, ...],
    domain : Domain,
) -> Tensor:
    """
    Build a tagged mask `Tensor` over `dims` whose value is `1` on positions
    in `domain` and `0` elsewhere.

    The tag is `TagCond(domain=domain, if_true=Constant(1.0),
    if_false=Constant(0.0))`. `domain`'s constraints should reference
    `LoopVariable`s named after the tensor's dims — those are the symbolic
    dim-indices.

    Common masks (causal, band, block-diagonal) are library wrappers over
    this primitive, produced by constructing the appropriate `Domain`.
    """
    inside = Constant(1.0)
    outside = Constant(0.0)
    # Every mask carries the fixed name "_mask" so that two `mask_expr` calls
    # with the same dims and predicate produce equal tensors; the tag holds
    # the identifying information.
    return Tensor(
        dims=dims,
        tag=TagCond(domain=domain, if_true=inside, if_false=outside),
        name="_mask",
    )
|
|
64
|
+
|
|
65
|
+
|