stile-verifier 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. stile_verifier-0.1.0/.gitignore +72 -0
  2. stile_verifier-0.1.0/LICENSE +21 -0
  3. stile_verifier-0.1.0/PKG-INFO +193 -0
  4. stile_verifier-0.1.0/README.md +156 -0
  5. stile_verifier-0.1.0/pyproject.toml +76 -0
  6. stile_verifier-0.1.0/stile/__init__.py +65 -0
  7. stile_verifier-0.1.0/stile/frozen_counter.py +135 -0
  8. stile_verifier-0.1.0/stile/indexing.py +738 -0
  9. stile_verifier-0.1.0/stile/jax/__init__.py +20 -0
  10. stile_verifier-0.1.0/stile/jax/_core.py +1100 -0
  11. stile_verifier-0.1.0/stile/jax/pallas/__init__.py +6 -0
  12. stile_verifier-0.1.0/stile/jax/pallas/_core.py +240 -0
  13. stile_verifier-0.1.0/stile/jax/random.py +29 -0
  14. stile_verifier-0.1.0/stile/numpy/__init__.py +13 -0
  15. stile_verifier-0.1.0/stile/numpy/_core.py +199 -0
  16. stile_verifier-0.1.0/stile/numpy/random.py +14 -0
  17. stile_verifier-0.1.0/stile/specification.py +842 -0
  18. stile_verifier-0.1.0/stile/torch/__init__.py +13 -0
  19. stile_verifier-0.1.0/stile/torch/_core.py +280 -0
  20. stile_verifier-0.1.0/stile/torch/random.py +23 -0
  21. stile_verifier-0.1.0/stile/tracing.py +444 -0
  22. stile_verifier-0.1.0/stile/triton/__init__.py +5 -0
  23. stile_verifier-0.1.0/stile/triton/_core.py +1804 -0
  24. stile_verifier-0.1.0/stile/type.py +655 -0
  25. stile_verifier-0.1.0/stile/verification.md +146 -0
  26. stile_verifier-0.1.0/stile/verification.py +2713 -0
  27. stile_verifier-0.1.0/tests/test_affine_intervals.py +114 -0
  28. stile_verifier-0.1.0/tests/test_buggy_kernels.py +393 -0
  29. stile_verifier-0.1.0/tests/test_causal_attention.py +216 -0
  30. stile_verifier-0.1.0/tests/test_dim_annotation_predicate.py +113 -0
  31. stile_verifier-0.1.0/tests/test_fused_moe.py +365 -0
  32. stile_verifier-0.1.0/tests/test_loop_invariants.py +408 -0
  33. stile_verifier-0.1.0/tests/test_mask_bias_convergence.py +97 -0
  34. stile_verifier-0.1.0/tests/test_max_tag.py +102 -0
  35. stile_verifier-0.1.0/tests/test_normalization_inequivalence.py +136 -0
  36. stile_verifier-0.1.0/tests/test_paged_flash_attention.py +303 -0
  37. stile_verifier-0.1.0/tests/test_pallas_gpu.py +154 -0
  38. stile_verifier-0.1.0/tests/test_parametric_reduce.py +305 -0
  39. stile_verifier-0.1.0/tests/test_qwen_resblock_components.py +159 -0
  40. stile_verifier-0.1.0/tests/test_rolled_loops.py +277 -0
  41. stile_verifier-0.1.0/tests/test_runtime_gather.py +90 -0
  42. stile_verifier-0.1.0/tests/test_runtime_scatter_moe.py +233 -0
  43. stile_verifier-0.1.0/tests/test_tagged_tensors.py +198 -0
  44. stile_verifier-0.1.0/tests/test_tjax_mask.py +87 -0
  45. stile_verifier-0.1.0/tests/test_typed_jax.py +300 -0
  46. stile_verifier-0.1.0/tests/test_typed_numpy.py +288 -0
  47. stile_verifier-0.1.0/tests/test_typed_pallas.py +569 -0
  48. stile_verifier-0.1.0/tests/test_typed_result_symbolic.py +85 -0
  49. stile_verifier-0.1.0/tests/test_typed_torch.py +291 -0
  50. stile_verifier-0.1.0/tests/test_typed_triton.py +940 -0
  51. stile_verifier-0.1.0/tests/test_typed_triton.py.bak +828 -0
  52. stile_verifier-0.1.0/tests/test_where_clause.py +154 -0
@@ -0,0 +1,72 @@
1
+ /target
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ .pytest_cache/
6
+ *.py[cod]
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ .venv/
14
+ env/
15
+ bin/
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ include/
26
+ man/
27
+ venv/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+
32
+ # Installer logs
33
+ pip-log.txt
34
+ pip-delete-this-directory.txt
35
+ pip-selfcheck.json
36
+
37
+ # Unit test / coverage reports
38
+ htmlcov/
39
+ .tox/
40
+ .coverage
41
+ .cache
42
+ nosetests.xml
43
+ coverage.xml
44
+
45
+ # Translations
46
+ *.mo
47
+
48
+ # Mr Developer
49
+ .mr.developer.cfg
50
+ .project
51
+ .pydevproject
52
+
53
+ # Rope
54
+ .ropeproject
55
+
56
+ # Django stuff:
57
+ *.log
58
+ *.pot
59
+
60
+ .DS_Store
61
+
62
+ # Sphinx documentation
63
+ docs/_build/
64
+
65
+ # PyCharm
66
+ .idea/
67
+
68
+ # VSCode
69
+ .vscode/
70
+
71
+ # Pyenv
72
+ .python-version
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Sasha Krassovsky
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,193 @@
1
+ Metadata-Version: 2.4
2
+ Name: stile-verifier
3
+ Version: 0.1.0
4
+ Summary: A type system for numerical programs: write a spec, get a structural proof your code computes it.
5
+ Project-URL: Homepage, https://github.com/save-buffer/stile
6
+ Project-URL: Repository, https://github.com/save-buffer/stile
7
+ Project-URL: Issues, https://github.com/save-buffer/stile/issues
8
+ Author-email: Sasha Krassovsky <krassovskysasha@gmail.com>
9
+ License-Expression: MIT
10
+ License-File: LICENSE
11
+ Keywords: einsum,jax,kernels,numerical,pallas,pytorch,triton,type-system,verification
12
+ Classifier: Development Status :: 3 - Alpha
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: Implementation :: CPython
19
+ Classifier: Topic :: Scientific/Engineering
20
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
21
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
22
+ Requires-Python: >=3.11
23
+ Requires-Dist: einops>=0.8.1
24
+ Requires-Dist: numpy>=1.24.4
25
+ Provides-Extra: all
26
+ Requires-Dist: jax>=0.4.0; extra == 'all'
27
+ Requires-Dist: torch>=2.0.0; extra == 'all'
28
+ Requires-Dist: triton>=3.0.0; extra == 'all'
29
+ Provides-Extra: jax
30
+ Requires-Dist: jax>=0.4.0; extra == 'jax'
31
+ Provides-Extra: torch
32
+ Requires-Dist: torch>=2.0.0; extra == 'torch'
33
+ Provides-Extra: triton
34
+ Requires-Dist: torch>=2.0.0; extra == 'triton'
35
+ Requires-Dist: triton>=3.0.0; extra == 'triton'
36
+ Description-Content-Type: text/markdown
37
+
38
+ # Stile: A Type System for Numerical Programs
39
+
40
+ **Formally Verify Your Numerical Programs**
41
+
42
+ Stile is a type system for numerical programs. Describe what your program should compute in a lightweight specification language, and Stile will structurally prove the program's adherence to the spec.
43
+
44
+ A short demo of what this lets you write today:
45
+
46
+ ```python
47
+ # Spec: causal flash attention, one line.
48
+ SPEC = (
49
+ "(softmax[nctx where nctx <= qctx]"
50
+ "((qctx dhead, nctx dhead -> qctx nctx) / sqrt(16)), "
51
+ "nctx dhead -> qctx dhead)"
52
+ )
53
+
54
+ # Kernel: tile-walk only the lower triangle, online softmax, bias mask
55
+ # on the diagonal tile. Per-qctx-tile structural proof.
56
+ for iqctx in range(0, qctx.size, qctx_tile_size):
57
+ running_max, running_l, o = -jnp.inf, 0, 0
58
+ for ictx in range(0, iqctx + qctx_tile_size, nctx_tile_size):
59
+ q_tile = Q.slice(qctx, iqctx, iqctx + qctx_tile_size)
60
+ k_tile = K.slice(nctx, ictx, ictx + nctx_tile_size)
61
+ qk = tjax.einsum(q_tile, k_tile, "qctx dhead, nctx dhead -> nctx qctx") / jnp.sqrt(dhead.size)
62
+ # Bias-form mask: 0 inside causal region, -inf outside.
63
+ qk = qk + tjax.mask(qk.type.st, "nctx <= qctx", 0.0, -jnp.inf)
64
+ # ... online-softmax accumulator over qk, V ...
65
+ o.assert_equivalent(SPEC, nctx[:(iqctx + qctx_tile_size)]) # proof
66
+ ```
67
+
68
+ The verifier proves that the kernel's online softmax, with a `-inf` bias on the diagonal tile and the upper-triangle tiles skipped, normalizes to the explicit causal-attention spec.
69
+
70
+ ## How it works
71
+
72
+ A Stile-typed tensor has two types:
73
+ - **`ShapeType`**: which logical dims the tensor has, and how each is sliced.
74
+ - **`ExprType`**: the AST of operations performed up to this point.
75
+
76
+ `ShapeType` is enforced eagerly: you can only multiply two tensors with matching dim signatures, you must reduce the entire dim before assigning the result, etc.
77
+
78
+ `ExprType` is checked at assignment time. The verifier normalizes both your kernel's accumulated `ExprType` and the spec's parsed `ExprType` into a canonical form and compares. The canonical form folds:
79
+ - algebraic identities (`exp(0) = 1`, `0 + x = x`, `0 * x = 0`, `exp(a-b) = exp(a)/exp(b)`, `exp(-inf) = 0`, …),
80
+ - distribution through tagged tensors (`*` and `+` push through `Cond(D, …)` branches),
81
+ - iteration-domain folding (`sum(body * Cond(D, 1, 0))` collapses to a sum over `[0, N) ∩ D`),
82
+ - adjacent-tile interval merging on sum and max reductions, including with shared cross-variable predicates,
83
+ - max push-through and `-inf` absorption (so bias-form masks converge with multiplicative-form masks),
84
+ - post-fold invariant hoisting (the piece that lets the kernel's `exp(max)` rescaling factor cancel between numerator and denominator).
85
+
86
+ If the two normalize to the same expression, they compute the same function.
87
+
88
+ ## The specification language
89
+
90
+ A specification is a small expression over named dims. Tensors are written as the sequence of their dims (`Q D` is a tensor of shape `Q × D`). Slices use `D[a:b]`. Einsums use `,` to separate operands and `->` to give the output shape.
91
+
92
+ ```
93
+ Q D K # a tensor with three dims
94
+ Q D[0:8] # tensor with D sliced to [0, 8)
95
+ 2 * Q D # scaled tensor
96
+ A B + A B # addition (same-shape required)
97
+ exp(Q D) # elementwise unary
98
+ (Q D, K D -> Q K) # einsum: sum over D
99
+ sum[K](Q K) # explicit reduction
100
+ sum[K where K <= Q](Q K) # iteration-restricted sum (mult-mask)
101
+ max[K where K <= Q](Q K) # iteration-restricted max (bias-mask)
102
+ softmax[K](Q K) # softmax along K
103
+ softmax[K where K <= Q](Q K) # causal softmax — restricts both num and den
104
+ sum[N](Q N where N >= 4) -> Q # mult-mask `where`-clause inside sum
105
+ ```
106
+
107
+ **`[d where P]`** restricts the iteration of a reduction by the affine predicate `P`. Lowers based on the surrounding op:
108
+ - `sum[d where P]` uses a multiplicative mask (`1` where `P` holds, `0` elsewhere).
109
+ - `max[d where P]` and `softmax[d where P]` use a bias mask (`-inf` elsewhere) — masked positions vanish through `max`'s identity and `exp`'s zero, restricting both the numerator's exp and the denominator's sum.
110
+
111
+ **`body where P`** (outside a dim annotation) is always a multiplicative mask on `body`. Use it for non-reduction sparsity and bias-on-output patterns. Inside a `sum` it folds into the reduce's domain via mask-extraction; outside it stays as a `Cond` tag.
112
+
113
+ **Predicates** are conjunctions of affine inequalities over dim names: `<=`, `<`, `>=`, `>`, `==`, plus `+`, `-`, and `int * dim`. Cross-dim predicates (`nctx <= qctx`) are first-class — they ride along in the reduce's domain and survive interval merging.
114
+
115
+ ## Kernel primitives
116
+
117
+ The kernel side mirrors a small slice of JAX/Numpy. Today's main backend is `stile.jax` (`tjax`); a numpy backend exists for prototyping.
118
+
119
+ ```python
120
+ from stile import dim
121
+ import stile.jax as tjax
122
+
123
+ # Dims live in a global registry so specs and kernels share names.
124
+ qctx, nctx, dhead = dim('qctx', 128), dim('nctx', 512), dim('dhead', 16)
125
+
126
+ # Wrap concrete arrays with their dim signature.
127
+ Q = tjax.random.normal(key, qctx, dhead) # has ShapeType (qctx, dhead)
128
+
129
+ # Slice. Result's st remembers it's [iqctx, iqctx+T).
130
+ q_tile = Q.slice(qctx, iqctx, iqctx + 32)
131
+
132
+ # Einsum. Both shapes and the AST track the contraction.
133
+ qk = tjax.einsum(Q, K, "qctx dhead, nctx dhead -> qctx nctx")
134
+
135
+ # Reductions. Either a method or via einsum.
136
+ m = qk.max(nctx)
137
+ s = qk.sum(nctx)
138
+
139
+ # Unary functions on TypedJaxArrays.
140
+ e = tjax.exp(qk - m.repeat(nctx))
141
+
142
+ # Multiplicative mask sugar — score * Cond(P, 1, 0).
143
+ masked = score.where("nctx <= qctx")
144
+
145
+ # Tagged-constant tensor — picks 0/1, 0/-inf, etc.
146
+ mult_mask = tjax.mask(score.type.st, "nctx <= qctx") # 1 inside / 0 outside
147
+ bias_mask = tjax.mask(score.type.st, "nctx <= qctx", 0.0, -jnp.inf) # 0 / -inf
148
+
149
+ # Rolled loops. Concrete bounds unroll; symbolic bounds emit a parametric reduce.
150
+ total = tjax.fori_loop(0, n, lambda i, acc: acc + body(i), init_val=0.0)
151
+
152
+ # Verify against a spec.
153
+ result = tjax.TypedResult(SPEC)
154
+ result.assign(o) # full-coverage type check
155
+ o.assert_equivalent(SPEC, nctx[:K]) # per-tile check with a slice override
156
+ result.done() # tile-coverage check (no gaps/overlaps)
157
+ ```
158
+
159
+ ## Status
160
+
161
+ Working today, with full structural verification:
162
+ - **Backends**: `stile.jax` (primary), `stile.numpy` (prototype).
163
+ - **Verified kernels**: matmul, online softmax, full flash attention, **tile-walking causal flash attention** (online softmax with bias-mask; structurally proven equivalent to a one-line `softmax[k where k<=q]` spec).
164
+ - **Spec features**: einsums, slices, reductions (`sum`, `max`, `softmax`), unary (`exp`, `sin`, `cos`, `sqrt`), multiplicative `where`-clauses, iteration-restricted `[d where P]` annotations, affine predicates with cross-dim references.
165
+ - **Kernel features**: slicing, einsum, all the unary/binary ops, `repeat`, `rearrange`, `fori_loop` (concrete-unroll path; symbolic-loop path with parametric reductions), `mask` intrinsic and `.where(...)` sugar.
166
+
167
+ In progress / future:
168
+ - **TypedPallas**: same type discipline, lowering to Pallas for actual GPU/TPU codegen.
169
+ - **TypedTorch**: a Torch backend exists but lags the JAX one.
170
+
171
+ ## Why a type system?
172
+
173
+ Way back in the 1950s, before high-level languages, programs took big arrays of bytes as input and produced other arrays of bytes as output. It was up to the programmer to keep track in their head of which bytes corresponded
174
+ to which semantic piece of the program. The invention of type annotations, structures, etc., was a step change in programming productivity because it gave semantic meanings to specific regions of memory.
175
+
176
+ The current state of the art for numerical programs is not unlike the 1950s mode of programming. We take big multidimensional arrays of floats, and output big multidimensional arrays of floats. The dimensions
177
+ are not semantically enforced to be different, and it's up to the programmer to remember the order of dimensions at all times. Mixing them up is all too easy. Then, assuming the program actually completes, you
178
+ have another big array of floats, and you have no idea if it's right or not. If it doesn't give you the expected result, debugging a numerical program is a massive time-sink. The solution is therefore to
179
+ add guardrails to prevent making stupid mistakes, in other words a type system.
180
+
181
+
182
+ ## Running
183
+
184
+ ```bash
185
+ uv run pytest tests/
186
+ ```
187
+
188
+ Backend extras:
189
+
190
+ ```bash
191
+ uv pip install -e ".[jax]" # JAX backend
192
+ uv pip install -e ".[torch]" # Torch backend (lagging)
193
+ ```
@@ -0,0 +1,156 @@
1
+ # Stile: A Type System for Numerical Programs
2
+
3
+ **Formally Verify Your Numerical Programs**
4
+
5
+ Stile is a type system for numerical programs. Describe what your program should compute in a lightweight specification language, and Stile will structurally prove the program's adherence to the spec.
6
+
7
+ A short demo of what this lets you write today:
8
+
9
+ ```python
10
+ # Spec: causal flash attention, one line.
11
+ SPEC = (
12
+ "(softmax[nctx where nctx <= qctx]"
13
+ "((qctx dhead, nctx dhead -> qctx nctx) / sqrt(16)), "
14
+ "nctx dhead -> qctx dhead)"
15
+ )
16
+
17
+ # Kernel: tile-walk only the lower triangle, online softmax, bias mask
18
+ # on the diagonal tile. Per-qctx-tile structural proof.
19
+ for iqctx in range(0, qctx.size, qctx_tile_size):
20
+ running_max, running_l, o = -jnp.inf, 0, 0
21
+ for ictx in range(0, iqctx + qctx_tile_size, nctx_tile_size):
22
+ q_tile = Q.slice(qctx, iqctx, iqctx + qctx_tile_size)
23
+ k_tile = K.slice(nctx, ictx, ictx + nctx_tile_size)
24
+ qk = tjax.einsum(q_tile, k_tile, "qctx dhead, nctx dhead -> nctx qctx") / jnp.sqrt(dhead.size)
25
+ # Bias-form mask: 0 inside causal region, -inf outside.
26
+ qk = qk + tjax.mask(qk.type.st, "nctx <= qctx", 0.0, -jnp.inf)
27
+ # ... online-softmax accumulator over qk, V ...
28
+ o.assert_equivalent(SPEC, nctx[:(iqctx + qctx_tile_size)]) # proof
29
+ ```
30
+
31
+ The verifier proves that the kernel's online softmax, with a `-inf` bias on the diagonal tile and the upper-triangle tiles skipped, normalizes to the explicit causal-attention spec.
32
+
33
+ ## How it works
34
+
35
+ A Stile-typed tensor has two types:
36
+ - **`ShapeType`**: which logical dims the tensor has, and how each is sliced.
37
+ - **`ExprType`**: the AST of operations performed up to this point.
38
+
39
+ `ShapeType` is enforced eagerly: you can only multiply two tensors with matching dim signatures, you must reduce the entire dim before assigning the result, etc.
40
+
41
+ `ExprType` is checked at assignment time. The verifier normalizes both your kernel's accumulated `ExprType` and the spec's parsed `ExprType` into a canonical form and compares. The canonical form folds:
42
+ - algebraic identities (`exp(0) = 1`, `0 + x = x`, `0 * x = 0`, `exp(a-b) = exp(a)/exp(b)`, `exp(-inf) = 0`, …),
43
+ - distribution through tagged tensors (`*` and `+` push through `Cond(D, …)` branches),
44
+ - iteration-domain folding (`sum(body * Cond(D, 1, 0))` collapses to a sum over `[0, N) ∩ D`),
45
+ - adjacent-tile interval merging on sum and max reductions, including with shared cross-variable predicates,
46
+ - max push-through and `-inf` absorption (so bias-form masks converge with multiplicative-form masks),
47
+ - post-fold invariant hoisting (the piece that lets the kernel's `exp(max)` rescaling factor cancel between numerator and denominator).
48
+
49
+ If the two normalize to the same expression, they compute the same function.
50
+
51
+ ## The specification language
52
+
53
+ A specification is a small expression over named dims. Tensors are written as the sequence of their dims (`Q D` is a tensor of shape `Q × D`). Slices use `D[a:b]`. Einsums use `,` to separate operands and `->` to give the output shape.
54
+
55
+ ```
56
+ Q D K # a tensor with three dims
57
+ Q D[0:8] # tensor with D sliced to [0, 8)
58
+ 2 * Q D # scaled tensor
59
+ A B + A B # addition (same-shape required)
60
+ exp(Q D) # elementwise unary
61
+ (Q D, K D -> Q K) # einsum: sum over D
62
+ sum[K](Q K) # explicit reduction
63
+ sum[K where K <= Q](Q K) # iteration-restricted sum (mult-mask)
64
+ max[K where K <= Q](Q K) # iteration-restricted max (bias-mask)
65
+ softmax[K](Q K) # softmax along K
66
+ softmax[K where K <= Q](Q K) # causal softmax — restricts both num and den
67
+ sum[N](Q N where N >= 4) -> Q # mult-mask `where`-clause inside sum
68
+ ```
69
+
70
+ **`[d where P]`** restricts the iteration of a reduction by the affine predicate `P`. Lowers based on the surrounding op:
71
+ - `sum[d where P]` uses a multiplicative mask (`1` where `P` holds, `0` elsewhere).
72
+ - `max[d where P]` and `softmax[d where P]` use a bias mask (`-inf` elsewhere) — masked positions vanish through `max`'s identity and `exp`'s zero, restricting both the numerator's exp and the denominator's sum.
73
+
74
+ **`body where P`** (outside a dim annotation) is always a multiplicative mask on `body`. Use it for non-reduction sparsity and bias-on-output patterns. Inside a `sum` it folds into the reduce's domain via mask-extraction; outside it stays as a `Cond` tag.
75
+
76
+ **Predicates** are conjunctions of affine inequalities over dim names: `<=`, `<`, `>=`, `>`, `==`, plus `+`, `-`, and `int * dim`. Cross-dim predicates (`nctx <= qctx`) are first-class — they ride along in the reduce's domain and survive interval merging.
77
+
78
+ ## Kernel primitives
79
+
80
+ The kernel side mirrors a small slice of JAX/Numpy. Today's main backend is `stile.jax` (`tjax`); a numpy backend exists for prototyping.
81
+
82
+ ```python
83
+ from stile import dim
84
+ import stile.jax as tjax
85
+
86
+ # Dims live in a global registry so specs and kernels share names.
87
+ qctx, nctx, dhead = dim('qctx', 128), dim('nctx', 512), dim('dhead', 16)
88
+
89
+ # Wrap concrete arrays with their dim signature.
90
+ Q = tjax.random.normal(key, qctx, dhead) # has ShapeType (qctx, dhead)
91
+
92
+ # Slice. Result's st remembers it's [iqctx, iqctx+T).
93
+ q_tile = Q.slice(qctx, iqctx, iqctx + 32)
94
+
95
+ # Einsum. Both shapes and the AST track the contraction.
96
+ qk = tjax.einsum(Q, K, "qctx dhead, nctx dhead -> qctx nctx")
97
+
98
+ # Reductions. Either a method or via einsum.
99
+ m = qk.max(nctx)
100
+ s = qk.sum(nctx)
101
+
102
+ # Unary functions on TypedJaxArrays.
103
+ e = tjax.exp(qk - m.repeat(nctx))
104
+
105
+ # Multiplicative mask sugar — score * Cond(P, 1, 0).
106
+ masked = score.where("nctx <= qctx")
107
+
108
+ # Tagged-constant tensor — picks 0/1, 0/-inf, etc.
109
+ mult_mask = tjax.mask(score.type.st, "nctx <= qctx") # 1 inside / 0 outside
110
+ bias_mask = tjax.mask(score.type.st, "nctx <= qctx", 0.0, -jnp.inf) # 0 / -inf
111
+
112
+ # Rolled loops. Concrete bounds unroll; symbolic bounds emit a parametric reduce.
113
+ total = tjax.fori_loop(0, n, lambda i, acc: acc + body(i), init_val=0.0)
114
+
115
+ # Verify against a spec.
116
+ result = tjax.TypedResult(SPEC)
117
+ result.assign(o) # full-coverage type check
118
+ o.assert_equivalent(SPEC, nctx[:K]) # per-tile check with a slice override
119
+ result.done() # tile-coverage check (no gaps/overlaps)
120
+ ```
121
+
122
+ ## Status
123
+
124
+ Working today, with full structural verification:
125
+ - **Backends**: `stile.jax` (primary), `stile.numpy` (prototype).
126
+ - **Verified kernels**: matmul, online softmax, full flash attention, **tile-walking causal flash attention** (online softmax with bias-mask; structurally proven equivalent to a one-line `softmax[k where k<=q]` spec).
127
+ - **Spec features**: einsums, slices, reductions (`sum`, `max`, `softmax`), unary (`exp`, `sin`, `cos`, `sqrt`), multiplicative `where`-clauses, iteration-restricted `[d where P]` annotations, affine predicates with cross-dim references.
128
+ - **Kernel features**: slicing, einsum, all the unary/binary ops, `repeat`, `rearrange`, `fori_loop` (concrete-unroll path; symbolic-loop path with parametric reductions), `mask` intrinsic and `.where(...)` sugar.
129
+
130
+ In progress / future:
131
+ - **TypedPallas**: same type discipline, lowering to Pallas for actual GPU/TPU codegen.
132
+ - **TypedTorch**: a Torch backend exists but lags the JAX one.
133
+
134
+ ## Why a type system?
135
+
136
+ Way back in the 1950s, before high-level languages, programs took big arrays of bytes as input and produced other arrays of bytes as output. It was up to the programmer to keep track in their head of which bytes corresponded
137
+ to which semantic piece of the program. The invention of type annotations, structures, etc., was a step change in programming productivity because it gave semantic meanings to specific regions of memory.
138
+
139
+ The current state of the art for numerical programs is not unlike the 1950s mode of programming. We take big multidimensional arrays of floats, and output big multidimensional arrays of floats. The dimensions
140
+ are not semantically enforced to be different, and it's up to the programmer to remember the order of dimensions at all times. Mixing them up is all too easy. Then, assuming the program actually completes, you
141
+ have another big array of floats, and you have no idea if it's right or not. If it doesn't give you the expected result, debugging a numerical program is a massive time-sink. The solution is therefore to
142
+ add guardrails to prevent making stupid mistakes, in other words a type system.
143
+
144
+
145
+ ## Running
146
+
147
+ ```bash
148
+ uv run pytest tests/
149
+ ```
150
+
151
+ Backend extras:
152
+
153
+ ```bash
154
+ uv pip install -e ".[jax]" # JAX backend
155
+ uv pip install -e ".[torch]" # Torch backend (lagging)
156
+ ```
@@ -0,0 +1,76 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "stile-verifier"
7
+ version = "0.1.0"
8
+ description = "A type system for numerical programs: write a spec, get a structural proof your code computes it."
9
+ readme = "README.md"
10
+ requires-python = ">=3.11"
11
+ license = "MIT"
12
+ license-files = ["LICENSE"]
13
+ authors = [
14
+ { name = "Sasha Krassovsky", email = "krassovskysasha@gmail.com" },
15
+ ]
16
+ keywords = [
17
+ "verification", "type-system", "numerical", "jax", "pytorch",
18
+ "triton", "pallas", "kernels", "einsum",
19
+ ]
20
+ classifiers = [
21
+ "Development Status :: 3 - Alpha",
22
+ "Intended Audience :: Developers",
23
+ "Intended Audience :: Science/Research",
24
+ "Topic :: Scientific/Engineering",
25
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
26
+ "Topic :: Software Development :: Libraries :: Python Modules",
27
+ "Programming Language :: Python :: 3",
28
+ "Programming Language :: Python :: 3.11",
29
+ "Programming Language :: Python :: 3.12",
30
+ "Programming Language :: Python :: Implementation :: CPython",
31
+ ]
32
+ dependencies = [
33
+ "einops>=0.8.1",
34
+ "numpy>=1.24.4",
35
+ ]
36
+
37
+ [project.optional-dependencies]
38
+ jax = ["jax>=0.4.0"]
39
+ torch = ["torch>=2.0.0"]
40
+ triton = [
41
+ "torch>=2.0.0",
42
+ "triton>=3.0.0",
43
+ ]
44
+ all = [
45
+ "jax>=0.4.0",
46
+ "torch>=2.0.0",
47
+ "triton>=3.0.0",
48
+ ]
49
+
50
+ [project.urls]
51
+ Homepage = "https://github.com/save-buffer/stile"
52
+ Repository = "https://github.com/save-buffer/stile"
53
+ Issues = "https://github.com/save-buffer/stile/issues"
54
+
55
+ [tool.hatch.build.targets.wheel]
56
+ packages = ["stile"]
57
+
58
+ [tool.hatch.build.targets.sdist]
59
+ include = [
60
+ "stile/",
61
+ "tests/",
62
+ "README.md",
63
+ "LICENSE",
64
+ "pyproject.toml",
65
+ ]
66
+ exclude = [
67
+ "**/__pycache__",
68
+ "*.pyc",
69
+ ]
70
+
71
+ [dependency-groups]
72
+ dev = [
73
+ "pytest>=9.0.2",
74
+ "build>=1.2.0",
75
+ "twine>=5.0.0",
76
+ ]
@@ -0,0 +1,65 @@
1
+ from .verification import verify_exprs_equivalent
2
+ from .specification import parse_spec_into_type
3
+ from .type import (
4
+ Type, FullDim, g_dim_registry, Tensor, Constant, TagCond,
5
+ _reset_tensor_counter,
6
+ )
7
+ from .indexing import (
8
+ SymbolicInt, LoopVariable, AffineExpr, SymbolicIndex, Domain, range_domain,
9
+ LoopScope, loop, active_loop_domain, _active_loop_scopes,
10
+ RuntimeScalar, runtime_scalar, _g_runtime_scalars,
11
+ SymInfo, symint_info, _g_symint_metadata,
12
+ tensor_element,
13
+ declare_index_properties, index_has_property, _g_index_properties,
14
+ declare_block_pairing, paired_index_for_offsets, _g_block_pairings,
15
+ declare_tensor_boundary, tensor_boundary, resolve_symbolic_index,
16
+ _g_tensor_boundaries,
17
+ )
18
+ from .tracing import _g_runtime_arrs
19
+
20
def dim(name : str, size : int) -> FullDim:
    """
    Declare a named logical dimension of `size` elements.

    Convenience wrapper that simply constructs and returns `FullDim(name, size)`.
    NOTE(review): any registration into `g_dim_registry` (which `reset_stile`
    clears) presumably happens inside `FullDim` itself — confirm there.
    """
    full_dim = FullDim(name, size)
    return full_dim
22
+
23
def expr_simplifies(
    expr : object,
    spec : str,
) -> bool:
    """
    Return whether `expr` normalizes to the same canonical expression as `spec`.

    Parameters:
        expr: either a `Type` (which exposes its expression type as `.et`,
            like the value returned by `parse_spec_into_type`) or a typed
            tensor/array that carries its `Type` on a `.type` attribute.
        spec: a specification string in Stile's spec language; it is parsed
            with `parse_spec_into_type`.

    Returns:
        The result of `verify_exprs_equivalent` on the two expression types.
    """
    # Parse the spec first so spec-parsing errors surface before any
    # attribute errors on `expr`, as before.
    spec_type = parse_spec_into_type(spec)
    # Accept a bare `Type` as well as a typed tensor wrapping one; previously
    # only the latter worked even though the parameter was annotated `Type`.
    expr_type = expr if isinstance(expr, Type) else expr.type
    return verify_exprs_equivalent(expr_type.et, spec_type.et)
29
+
30
def reset_stile():
    """
    Clear all of Stile's module-level global state.

    Empties, in order: the dim registry, the active loop scopes, the runtime
    scalar table, the symbolic-int metadata, the index-property table, the
    block pairings, the tensor boundaries, and the traced runtime arrays.
    It then calls `_reset_tensor_counter()`.
    """
    global_tables = (
        g_dim_registry,
        _active_loop_scopes,
        _g_runtime_scalars,
        _g_symint_metadata,
        _g_index_properties,
        _g_block_pairings,
        _g_tensor_boundaries,
        _g_runtime_arrs,
    )
    for table in global_tables:
        table.clear()
    _reset_tensor_counter()
40
+
41
+
42
def mask_expr(
    dims : tuple[FullDim, ...],
    domain : Domain,
) -> Tensor:
    """
    Build a tagged `Tensor` over `dims` that is `1` inside `domain`, `0` outside.

    The tag is `TagCond(domain, Constant(1.0), Constant(0.0))`. The
    constraints in `domain` should reference `LoopVariable`s named after the
    tensor's dims; those act as the symbolic per-dim indices.

    Common masks (causal, band, block-diagonal) are library wrappers over
    this primitive that build the appropriate `Domain`.
    """
    inside, outside = Constant(1.0), Constant(0.0)
    # All masks share the fixed name "_mask" and the tag carries the
    # identifying info, so two calls with the same dims and predicate
    # produce equal tensors.
    return Tensor(
        dims=dims,
        tag=TagCond(domain=domain, if_true=inside, if_false=outside),
        name="_mask",
    )
64
+
65
+