sutra-dev 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sutra_compiler/__init__.py +49 -0
- sutra_compiler/__main__.py +514 -0
- sutra_compiler/ast_nodes.py +553 -0
- sutra_compiler/codegen.py +1811 -0
- sutra_compiler/codegen_base.py +2436 -0
- sutra_compiler/codegen_pytorch.py +1472 -0
- sutra_compiler/diagnostics.py +145 -0
- sutra_compiler/inliner.py +581 -0
- sutra_compiler/lexer.py +821 -0
- sutra_compiler/parser.py +2112 -0
- sutra_compiler/review.py +322 -0
- sutra_compiler/simplify.py +1046 -0
- sutra_compiler/simplify_egglog.py +674 -0
- sutra_compiler/stdlib/axons.su +53 -0
- sutra_compiler/stdlib/embed.su +48 -0
- sutra_compiler/stdlib/javascript_object.su +18 -0
- sutra_compiler/stdlib/logic.su +202 -0
- sutra_compiler/stdlib/math.su +12 -0
- sutra_compiler/stdlib/memory.su +82 -0
- sutra_compiler/stdlib/numbers.su +99 -0
- sutra_compiler/stdlib/rotation.su +83 -0
- sutra_compiler/stdlib/similarity.su +97 -0
- sutra_compiler/stdlib/strings.su +56 -0
- sutra_compiler/stdlib/tensor.su +82 -0
- sutra_compiler/stdlib/vectors.su +119 -0
- sutra_compiler/stdlib_loader.py +219 -0
- sutra_compiler/sutradb_embedded.py +273 -0
- sutra_compiler/trace.py +135 -0
- sutra_compiler/validator.py +552 -0
- sutra_compiler/workspace.py +655 -0
- sutra_dev-0.2.0.dist-info/METADATA +80 -0
- sutra_dev-0.2.0.dist-info/RECORD +36 -0
- sutra_dev-0.2.0.dist-info/WHEEL +5 -0
- sutra_dev-0.2.0.dist-info/entry_points.txt +2 -0
- sutra_dev-0.2.0.dist-info/licenses/LICENSE +201 -0
- sutra_dev-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1472 @@
|
|
|
1
|
+
"""AST -> PyTorch/CUDA Python source translator.
|
|
2
|
+
|
|
3
|
+
The GPU path. Emits self-contained Python modules that depend only on
|
|
4
|
+
torch (numpy is still imported for a single bridge at ingestion time —
|
|
5
|
+
Ollama hands us lists of floats and we construct tensors from them).
|
|
6
|
+
Ops run as torch tensors; when CUDA is available the module picks
|
|
7
|
+
`cuda` as its device automatically, falling back to `cpu` otherwise.
|
|
8
|
+
|
|
9
|
+
Relationship to the CPU codegen:
|
|
10
|
+
|
|
11
|
+
BaseCodegen ← backend-agnostic AST walker
|
|
12
|
+
└── Codegen ← canonical CPU path (numpy ndarrays)
|
|
13
|
+
└── PyTorchCodegen ← GPU path (torch tensors)
|
|
14
|
+
|
|
15
|
+
PyTorchCodegen inherits the translator from `Codegen` (same AST walk,
|
|
16
|
+
same bundle-of-binds fusion, same vector-accessor lowering, same
|
|
17
|
+
extended-state-vector layout) and only overrides the prelude so the
|
|
18
|
+
emitted runtime class is `_TorchVSA` operating on tensors. The fused
|
|
19
|
+
shapes that the simplifier and codegen produce (stacked Q matmul via
|
|
20
|
+
einsum, stacked candidate matmul for argmax_cosine) collapse O(N)
|
|
21
|
+
small kernel launches into O(1) large ones on GPU — which is the
|
|
22
|
+
reason this backend exists.
|
|
23
|
+
|
|
24
|
+
Extended state vector and canonical axis allocation are preserved
|
|
25
|
+
exactly: every tensor is `[semantic (semantic_dim) | synthetic
|
|
26
|
+
(synthetic_dim)]`, bind rotation is block-diagonal with identity on
|
|
27
|
+
the synthetic block, `synthetic[0..2]` are the canonical real/imag/
|
|
28
|
+
truth axes per the 2026-04-23 design.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
from . import ast_nodes as ast
|
|
34
|
+
from .codegen_base import CodegenNotSupported
|
|
35
|
+
from .codegen import Codegen
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class PyTorchCodegen(Codegen):
|
|
39
|
+
"""Emits a self-contained torch module.
|
|
40
|
+
|
|
41
|
+
Inherits the entire translator from `Codegen` and only overrides the
|
|
42
|
+
prelude. Vector accessor methods (`.component()`, `.real()`, etc.)
|
|
43
|
+
still route through `_VSA.*` calls — the runtime method names match
|
|
44
|
+
the CPU codegen so the translator needs no divergence.
|
|
45
|
+
|
|
46
|
+
Bool literal lowering is inherited from `Codegen` (true/false →
|
|
47
|
+
make_truth(±1)); logical ops (`!`, `&&`, `||`) likewise inherit the
|
|
48
|
+
base override and resolve against the torch runtime's make_truth /
|
|
49
|
+
_as_truth_vector.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
def _emit_select_helper(self) -> None:
|
|
53
|
+
"""Torch-based softmax for the Sutra `select` primitive.
|
|
54
|
+
|
|
55
|
+
Same numerical shape as the numpy version (subtract max for
|
|
56
|
+
stability, exp, normalize, weighted sum), all on tensors so the
|
|
57
|
+
whole path stays on the chosen device.
|
|
58
|
+
"""
|
|
59
|
+
self._emit("def _select_softmax(scores, options):")
|
|
60
|
+
self._indent += 1
|
|
61
|
+
self._emit('"""Softmax-weighted superposition of option vectors (torch)."""')
|
|
62
|
+
self._emit("s = _torch.as_tensor(scores, dtype=_DTYPE, device=_DEVICE)")
|
|
63
|
+
self._emit("s = s - _torch.amax(s)")
|
|
64
|
+
self._emit("w = _torch.exp(s)")
|
|
65
|
+
self._emit("w = w / _torch.sum(w)")
|
|
66
|
+
self._emit("opts = _torch.stack([")
|
|
67
|
+
self._indent += 1
|
|
68
|
+
self._emit("_torch.as_tensor(o, dtype=_DTYPE, device=_DEVICE)")
|
|
69
|
+
self._emit("for o in options")
|
|
70
|
+
self._indent -= 1
|
|
71
|
+
self._emit("])")
|
|
72
|
+
self._emit("return (w[:, None] * opts).sum(dim=0)")
|
|
73
|
+
self._indent -= 1
|
|
74
|
+
|
|
75
|
+
def _translate_var_decl_zero_init(self, decl): # pragma: no cover — helper
|
|
76
|
+
# Not actually used by the parent directly; the parent inlines
|
|
77
|
+
# the `_np.zeros(_VSA.dim)` string. We patch at translate time
|
|
78
|
+
# by string replacement below.
|
|
79
|
+
pass
|
|
80
|
+
|
|
81
|
+
def translate(self, module: ast.Module) -> str:
    """Translate *module* and post-process the parent's output for torch.

    Two fix-ups are applied to the text returned by the parent translator:

    1. Every literal `_np.zeros(_VSA.dim)` becomes
       `_torch.zeros(_VSA.dim, dtype=_DTYPE, device=_DEVICE)`. The parent
       class hard-codes the numpy form for uninitialized-vector
       declarations; the pytorch backend has no `_np` symbol in scope, so
       leaving it would crash the generated module at init.
    2. If any loop functions were recorded in `self._loop_decls`, a
       trailing block is appended that rebinds each `_loop_<name>` to
       `_torch.compile(_loop_<name>, backend=...)`.

    The appended block is opt-in at run time of the generated module:

    * `SUTRA_TORCH_COMPILE` enables it. Unset, empty, `0`, `false`, `no`
      and `off` (case-insensitive) all leave it disabled, so
      `SUTRA_TORCH_COMPILE=0` really means "off". It is off by default
      because the first call pays a graph-capture cost that dwarfs the
      runtime for tiny loops.
    * `SUTRA_TORCH_COMPILE_BACKEND` picks the backend and defaults to
      `eager`, which captures the graph without requiring Triton. Users
      with Triton available can set it to `inductor` for fused kernels.

    Args:
        module: Parsed Sutra module to translate.

    Returns:
        The complete source text of the generated Python module.
    """
    out = super().translate(module)
    out = out.replace(
        "_np.zeros(_VSA.dim)",
        "_torch.zeros(_VSA.dim, dtype=_DTYPE, device=_DEVICE)",
    )
    if not self._loop_decls:
        return out

    # Gate on an explicit "on" value rather than bare truthiness: a plain
    # `environ.get(...)` check would treat the string "0" as enabled.
    gate_line = (
        "if _sutra_compile_os.environ.get('SUTRA_TORCH_COMPILE', '')"
        ".strip().lower() not in ('', '0', 'false', 'no', 'off'):"
    )
    # The emitted block is real Python, so nesting must be spelled out in
    # the literals: `try`/`except` one level under the `if`, bodies two.
    wrap_lines = [
        "",
        "",
        "# Optional torch.compile wrapping for loop functions.",
        "# Enable via SUTRA_TORCH_COMPILE=1.",
        "import os as _sutra_compile_os",
        gate_line,
        "    try:",
    ]
    for loop_name in self._loop_decls:
        # Class-bodied loops have dotted registry keys (`Greeter.run`);
        # the emitted identifier mangles `.` to `_` so it is a valid
        # Python name.
        py_loop_name = f"_loop_{loop_name.replace('.', '_')}"
        wrap_lines.append(
            f"        {py_loop_name} = _torch.compile("
            f"{py_loop_name}, "
            "backend=_sutra_compile_os.environ.get("
            "'SUTRA_TORCH_COMPILE_BACKEND', 'eager'))"
        )
    wrap_lines.extend([
        "    except Exception:",
        "        pass  # torch.compile not available or trace failed",
        "",
    ])
    return out + "\n".join(wrap_lines)
|
|
141
|
+
|
|
142
|
+
def _emit_prelude(self) -> None:
|
|
143
|
+
self._emit('"""Generated by sutra_compiler.codegen_pytorch. Do not edit by hand."""')
|
|
144
|
+
self._emit("from __future__ import annotations")
|
|
145
|
+
self._emit()
|
|
146
|
+
self._emit("import torch as _torch")
|
|
147
|
+
self._emit()
|
|
148
|
+
self._emit("# Pick device and dtype once at module import. CUDA is preferred")
|
|
149
|
+
self._emit("# because the whole reason for this backend is to collapse the")
|
|
150
|
+
self._emit("# fused bind / bundle / argmax_cosine shapes into single big")
|
|
151
|
+
self._emit("# kernel launches on GPU. CPU fallback keeps the module usable")
|
|
152
|
+
self._emit("# on machines without CUDA — the numerics are identical.")
|
|
153
|
+
self._emit("_DEVICE = _torch.device('cuda' if _torch.cuda.is_available() else 'cpu')")
|
|
154
|
+
self._emit("# float32 on GPU is the fast path; keep dtype consistent across")
|
|
155
|
+
self._emit("# every tensor so einsum / matmul don't trigger implicit upcasts.")
|
|
156
|
+
self._emit("_DTYPE = _torch.float32")
|
|
157
|
+
self._emit()
|
|
158
|
+
self._emit()
|
|
159
|
+
self._emit("class _TorchVSA:")
|
|
160
|
+
self._indent += 1
|
|
161
|
+
self._emit('"""Torch-backed VSA runtime. Rotation binding, normalized bundle.')
|
|
162
|
+
self._emit('')
|
|
163
|
+
self._emit('State tensors carry the extended layout:')
|
|
164
|
+
self._emit('`[semantic (semantic_dim) | synthetic (synthetic_dim)]`. The')
|
|
165
|
+
self._emit('semantic block is filled by `embed()` from the frozen LLM; the')
|
|
166
|
+
self._emit('synthetic block is reserved computational/symbolic space with')
|
|
167
|
+
self._emit('canonical axes at synthetic[0..2] (real, imag, truth). See')
|
|
168
|
+
self._emit('planning/findings/2026-04-21-extended-state-and-rotation-binding.md.')
|
|
169
|
+
self._emit('')
|
|
170
|
+
self._emit('Bind is role-seeded Haar-random orthogonal rotation applied to')
|
|
171
|
+
self._emit('filler: bind(filler, role) = Q_role @ filler. The rotation is')
|
|
172
|
+
self._emit('block-diagonal — Haar in the semantic block, identity in the')
|
|
173
|
+
self._emit('synthetic block — so rotation acts only on semantic content and')
|
|
174
|
+
self._emit('the synthetic block is preserved through bind/unbind.')
|
|
175
|
+
self._emit('"""')
|
|
176
|
+
self._emit()
|
|
177
|
+
self._emit("# Canonical synthetic-axis allocation — real, imag, truth at")
|
|
178
|
+
self._emit("# synthetic[0..2], string-flag at synthetic[3], loop-done at")
|
|
179
|
+
self._emit("# synthetic[4]. Mirrored from the CPU runtime so the two agree")
|
|
180
|
+
self._emit("# bit-for-bit on layout. AXIS_LOOP_DONE is the substrate-side")
|
|
181
|
+
self._emit("# completion flag set by the RNN-style branchless loop.")
|
|
182
|
+
self._emit("# AXIS_STRING_FLAG marks a vector as a String value (a")
|
|
183
|
+
self._emit("# packed array of codepoints — 1-character strings are the")
|
|
184
|
+
self._emit("# new home for what was formerly the `char` type). See")
|
|
185
|
+
self._emit("# planning/sutra-spec/strings.md.")
|
|
186
|
+
self._emit("AXIS_REAL = 0")
|
|
187
|
+
self._emit("AXIS_IMAG = 1")
|
|
188
|
+
self._emit("AXIS_TRUTH = 2")
|
|
189
|
+
self._emit("AXIS_STRING_FLAG = 3")
|
|
190
|
+
self._emit("# Backwards-compat alias for code that still references")
|
|
191
|
+
self._emit("# AXIS_CHAR_FLAG. New code should use AXIS_STRING_FLAG.")
|
|
192
|
+
self._emit("AXIS_CHAR_FLAG = 3")
|
|
193
|
+
self._emit("AXIS_LOOP_DONE = 4")
|
|
194
|
+
self._emit()
|
|
195
|
+
self._emit("def __init__(self, semantic_dim, synthetic_dim, seed, llm_model):")
|
|
196
|
+
self._indent += 1
|
|
197
|
+
self._emit("self.semantic_dim = semantic_dim")
|
|
198
|
+
self._emit("self.synthetic_dim = synthetic_dim")
|
|
199
|
+
self._emit("self.dim = semantic_dim + synthetic_dim")
|
|
200
|
+
self._emit("self.seed = seed")
|
|
201
|
+
self._emit("self.llm_model = llm_model")
|
|
202
|
+
self._emit("self.device = _DEVICE")
|
|
203
|
+
self._emit("self.dtype = _DTYPE")
|
|
204
|
+
self._emit("self._codebook = {}")
|
|
205
|
+
self._emit("# Rotation matrix cache: role-hash -> tensor on self.device.")
|
|
206
|
+
self._emit("# Generating a 768x768 Haar rotation is O(d^3) on CPU (seeded")
|
|
207
|
+
self._emit("# via numpy for Haar-uniformity). Cached on the GPU after the")
|
|
208
|
+
self._emit("# first draw so repeated bind/unbind with the same role is a")
|
|
209
|
+
self._emit("# lookup + one matmul, no transfer.")
|
|
210
|
+
self._emit("self._rot_cache = {}")
|
|
211
|
+
self._emit("# On-disk embedding cache. Keyed by (model, dim) so switching")
|
|
212
|
+
self._emit("# embedding model OR changing the extended-state dim invalidates")
|
|
213
|
+
self._emit("# automatically (different filename). Torch cache uses .pt so")
|
|
214
|
+
self._emit("# it doesn't collide with the numpy backend's .npz.")
|
|
215
|
+
self._emit("import os as _os")
|
|
216
|
+
self._emit("self._cache_dir = _os.path.join(")
|
|
217
|
+
self._indent += 1
|
|
218
|
+
self._emit("_os.environ.get('XDG_CACHE_HOME', _os.path.expanduser('~/.cache')),")
|
|
219
|
+
self._emit("'sutra', 'embeddings')")
|
|
220
|
+
self._indent -= 1
|
|
221
|
+
self._emit("_os.makedirs(self._cache_dir, exist_ok=True)")
|
|
222
|
+
self._emit("_safe_model = llm_model.replace('/', '_').replace(':', '_')")
|
|
223
|
+
self._emit("self._cache_path = _os.path.join(")
|
|
224
|
+
self._indent += 1
|
|
225
|
+
self._emit("self._cache_dir, f'{_safe_model}-d{self.dim}.pt')")
|
|
226
|
+
self._indent -= 1
|
|
227
|
+
self._emit("self._load_disk_cache()")
|
|
228
|
+
self._indent -= 1
|
|
229
|
+
self._emit()
|
|
230
|
+
self._emit("def _load_disk_cache(self):")
|
|
231
|
+
self._indent += 1
|
|
232
|
+
self._emit('"""Populate self._codebook from disk if the cache file exists.')
|
|
233
|
+
self._emit('')
|
|
234
|
+
self._emit("Tolerant of missing or corrupt files — a failed load just leaves")
|
|
235
|
+
self._emit("the codebook empty and lets Ollama repopulate it.")
|
|
236
|
+
self._emit('"""')
|
|
237
|
+
self._emit("import os as _os")
|
|
238
|
+
self._emit("if not _os.path.exists(self._cache_path):")
|
|
239
|
+
self._indent += 1
|
|
240
|
+
self._emit("return")
|
|
241
|
+
self._indent -= 1
|
|
242
|
+
self._emit("try:")
|
|
243
|
+
self._indent += 1
|
|
244
|
+
self._emit("data = _torch.load(self._cache_path, map_location=self.device, weights_only=True)")
|
|
245
|
+
self._emit("for key, tensor in data.items():")
|
|
246
|
+
self._indent += 1
|
|
247
|
+
self._emit("self._codebook[key] = tensor.to(dtype=self.dtype)")
|
|
248
|
+
self._indent -= 1
|
|
249
|
+
self._indent -= 1
|
|
250
|
+
self._emit("except Exception:")
|
|
251
|
+
self._indent += 1
|
|
252
|
+
self._emit("# Corrupt cache: ignore and let Ollama repopulate.")
|
|
253
|
+
self._emit("self._codebook = {}")
|
|
254
|
+
self._indent -= 1
|
|
255
|
+
self._indent -= 1
|
|
256
|
+
self._emit()
|
|
257
|
+
self._emit("def _write_disk_cache(self):")
|
|
258
|
+
self._indent += 1
|
|
259
|
+
self._emit('"""Persist self._codebook to disk via tempfile + atomic rename.')
|
|
260
|
+
self._emit('')
|
|
261
|
+
self._emit("A partial write (crash, SIGKILL) leaves the old cache intact")
|
|
262
|
+
self._emit("rather than corrupted.")
|
|
263
|
+
self._emit('"""')
|
|
264
|
+
self._emit("import os as _os, tempfile as _tempfile")
|
|
265
|
+
self._emit("if not self._codebook:")
|
|
266
|
+
self._indent += 1
|
|
267
|
+
self._emit("return")
|
|
268
|
+
self._indent -= 1
|
|
269
|
+
self._emit("fd, tmp = _tempfile.mkstemp(")
|
|
270
|
+
self._indent += 1
|
|
271
|
+
self._emit("dir=self._cache_dir, prefix='.tmp-', suffix='.pt')")
|
|
272
|
+
self._indent -= 1
|
|
273
|
+
self._emit("_os.close(fd)")
|
|
274
|
+
self._emit("try:")
|
|
275
|
+
self._indent += 1
|
|
276
|
+
self._emit("# Save tensors on CPU so the cache file is portable — the")
|
|
277
|
+
self._emit("# next run can load on any device. Reload will move them.")
|
|
278
|
+
self._emit("cpu_codebook = {k: v.detach().cpu() for k, v in self._codebook.items()}")
|
|
279
|
+
self._emit("_torch.save(cpu_codebook, tmp)")
|
|
280
|
+
self._emit("_os.replace(tmp, self._cache_path)")
|
|
281
|
+
self._indent -= 1
|
|
282
|
+
self._emit("except Exception:")
|
|
283
|
+
self._indent += 1
|
|
284
|
+
self._emit("try:")
|
|
285
|
+
self._indent += 1
|
|
286
|
+
self._emit("_os.unlink(tmp)")
|
|
287
|
+
self._indent -= 1
|
|
288
|
+
self._emit("except OSError:")
|
|
289
|
+
self._indent += 1
|
|
290
|
+
self._emit("pass")
|
|
291
|
+
self._indent -= 1
|
|
292
|
+
self._indent -= 1
|
|
293
|
+
self._indent -= 1
|
|
294
|
+
self._emit()
|
|
295
|
+
self._emit("def embed(self, name):")
|
|
296
|
+
self._indent += 1
|
|
297
|
+
self._emit('"""Frozen-LLM embedding via Ollama. Returns a tensor on self.device.')
|
|
298
|
+
self._emit('')
|
|
299
|
+
self._emit("Extended-state layout: `[semantic (semantic_dim) | zeros (synthetic_dim)]`.")
|
|
300
|
+
self._emit("No random fallback — if Ollama is unavailable this raises.")
|
|
301
|
+
self._emit('"""')
|
|
302
|
+
self._emit("if name not in self._codebook:")
|
|
303
|
+
self._indent += 1
|
|
304
|
+
self._emit("import ollama")
|
|
305
|
+
self._emit("r = ollama.embed(model=self.llm_model, input=name)")
|
|
306
|
+
self._emit("v = _torch.tensor(r['embeddings'][0], dtype=self.dtype, device=self.device)")
|
|
307
|
+
self._emit("# Mean-center; raw LLM embeddings cluster in a cone and centering")
|
|
308
|
+
self._emit("# keeps rotation/bind algebra well-behaved.")
|
|
309
|
+
self._emit("v = v - _torch.mean(v)")
|
|
310
|
+
self._emit("n = _torch.linalg.norm(v)")
|
|
311
|
+
self._emit("if n > 0: v = v / n")
|
|
312
|
+
self._emit("# Fit to semantic block.")
|
|
313
|
+
self._emit("if v.shape[0] > self.semantic_dim:")
|
|
314
|
+
self._indent += 1
|
|
315
|
+
self._emit("v = v[:self.semantic_dim]")
|
|
316
|
+
self._indent -= 1
|
|
317
|
+
self._emit("elif v.shape[0] < self.semantic_dim:")
|
|
318
|
+
self._indent += 1
|
|
319
|
+
self._emit("pad = _torch.zeros(self.semantic_dim - v.shape[0], dtype=self.dtype, device=self.device)")
|
|
320
|
+
self._emit("v = _torch.cat([v, pad])")
|
|
321
|
+
self._indent -= 1
|
|
322
|
+
self._emit("# Append synthetic block — reserved, starts zero.")
|
|
323
|
+
self._emit("syn = _torch.zeros(self.synthetic_dim, dtype=self.dtype, device=self.device)")
|
|
324
|
+
self._emit("v = _torch.cat([v, syn])")
|
|
325
|
+
self._emit("n = _torch.linalg.norm(v)")
|
|
326
|
+
self._emit("if n > 0: v = v / n")
|
|
327
|
+
self._emit("self._codebook[name] = v")
|
|
328
|
+
self._emit("self._write_disk_cache()")
|
|
329
|
+
self._indent -= 1
|
|
330
|
+
self._emit("return self._codebook[name].clone()")
|
|
331
|
+
self._indent -= 1
|
|
332
|
+
self._emit()
|
|
333
|
+
self._emit("def embed_batch(self, names):")
|
|
334
|
+
self._indent += 1
|
|
335
|
+
self._emit('"""Batched Ollama embed: one HTTP round-trip for many names.')
|
|
336
|
+
self._emit('')
|
|
337
|
+
self._emit("Same layout as embed(). Writes back to disk once after all")
|
|
338
|
+
self._emit("fetches to amortize the save.")
|
|
339
|
+
self._emit('"""')
|
|
340
|
+
self._emit("missing = [n for n in names if n not in self._codebook]")
|
|
341
|
+
self._emit("if not missing:")
|
|
342
|
+
self._indent += 1
|
|
343
|
+
self._emit("return")
|
|
344
|
+
self._indent -= 1
|
|
345
|
+
self._emit("import ollama")
|
|
346
|
+
self._emit("r = ollama.embed(model=self.llm_model, input=missing)")
|
|
347
|
+
self._emit("for i, name in enumerate(missing):")
|
|
348
|
+
self._indent += 1
|
|
349
|
+
self._emit("v = _torch.tensor(r['embeddings'][i], dtype=self.dtype, device=self.device)")
|
|
350
|
+
self._emit("v = v - _torch.mean(v)")
|
|
351
|
+
self._emit("n = _torch.linalg.norm(v)")
|
|
352
|
+
self._emit("if n > 0: v = v / n")
|
|
353
|
+
self._emit("if v.shape[0] > self.semantic_dim:")
|
|
354
|
+
self._indent += 1
|
|
355
|
+
self._emit("v = v[:self.semantic_dim]")
|
|
356
|
+
self._indent -= 1
|
|
357
|
+
self._emit("elif v.shape[0] < self.semantic_dim:")
|
|
358
|
+
self._indent += 1
|
|
359
|
+
self._emit("pad = _torch.zeros(self.semantic_dim - v.shape[0], dtype=self.dtype, device=self.device)")
|
|
360
|
+
self._emit("v = _torch.cat([v, pad])")
|
|
361
|
+
self._indent -= 1
|
|
362
|
+
self._emit("syn = _torch.zeros(self.synthetic_dim, dtype=self.dtype, device=self.device)")
|
|
363
|
+
self._emit("v = _torch.cat([v, syn])")
|
|
364
|
+
self._emit("n = _torch.linalg.norm(v)")
|
|
365
|
+
self._emit("if n > 0: v = v / n")
|
|
366
|
+
self._emit("self._codebook[name] = v")
|
|
367
|
+
self._indent -= 1
|
|
368
|
+
self._emit("self._write_disk_cache()")
|
|
369
|
+
self._indent -= 1
|
|
370
|
+
self._emit()
|
|
371
|
+
self._emit("# ---- Embedded SutraDB (compile-time string codebook) ----")
|
|
372
|
+
self._emit("# Every embedded string in a Sutra program goes into SutraDB")
|
|
373
|
+
self._emit("# at compile time. The embeddings live in the .sdb file SutraDB")
|
|
374
|
+
self._emit("# manages, not in the Python module's data section. The runtime")
|
|
375
|
+
self._emit("# decodes a query vector back to a string via nearest_string()")
|
|
376
|
+
self._emit("# (the inverse of embed()). Strings declared but not used in")
|
|
377
|
+
self._emit("# expressions are still inserted so they remain decodable.")
|
|
378
|
+
self._emit()
|
|
379
|
+
self._emit("def _ensure_sutradb(self):")
|
|
380
|
+
self._indent += 1
|
|
381
|
+
self._emit('"""Lazy-init the SutraDB handle on first use. Returns None if the')
|
|
382
|
+
self._emit("FFI DLL isn't built (caller decides what to do).")
|
|
383
|
+
self._emit('')
|
|
384
|
+
self._emit("Path resolution:")
|
|
385
|
+
self._emit(" 1. env var SUTRA_DB_PATH if set (persistent across runs)")
|
|
386
|
+
self._emit(" 2. else a tempdir (ephemeral; freed at process exit)")
|
|
387
|
+
self._emit('')
|
|
388
|
+
self._emit("Full atman.toml [vector_db] section is deferred until there's a")
|
|
389
|
+
self._emit("concrete config requirement — env var covers the immediate")
|
|
390
|
+
self._emit("'persistent codebook' use case.")
|
|
391
|
+
self._emit('"""')
|
|
392
|
+
self._emit("if hasattr(self, '_sutradb') and self._sutradb is not None:")
|
|
393
|
+
self._indent += 1
|
|
394
|
+
self._emit("return self._sutradb")
|
|
395
|
+
self._indent -= 1
|
|
396
|
+
self._emit("try:")
|
|
397
|
+
self._indent += 1
|
|
398
|
+
self._emit("import importlib, tempfile, os as _os2")
|
|
399
|
+
self._emit("mod = importlib.import_module('sutra_compiler.sutradb_embedded')")
|
|
400
|
+
self._emit("env_path = _os2.environ.get('SUTRA_DB_PATH')")
|
|
401
|
+
self._emit("if env_path:")
|
|
402
|
+
self._indent += 1
|
|
403
|
+
self._emit("path = env_path")
|
|
404
|
+
self._emit("self._sutradb_tmpdir = None")
|
|
405
|
+
self._indent -= 1
|
|
406
|
+
self._emit("else:")
|
|
407
|
+
self._indent += 1
|
|
408
|
+
self._emit("self._sutradb_tmpdir = tempfile.mkdtemp(prefix='sutra_codebook_')")
|
|
409
|
+
self._emit("path = _os2.path.join(self._sutradb_tmpdir, 'codebook.sdb')")
|
|
410
|
+
self._indent -= 1
|
|
411
|
+
self._emit("self._sutradb = mod.SutraDBEmbedded(path)")
|
|
412
|
+
self._emit("return self._sutradb")
|
|
413
|
+
self._indent -= 1
|
|
414
|
+
self._emit("except Exception:")
|
|
415
|
+
self._indent += 1
|
|
416
|
+
self._emit("self._sutradb = None # mark attempted-and-failed")
|
|
417
|
+
self._emit("return None")
|
|
418
|
+
self._indent -= 1
|
|
419
|
+
self._indent -= 1
|
|
420
|
+
self._emit()
|
|
421
|
+
self._emit("def populate_sutradb(self):")
|
|
422
|
+
self._indent += 1
|
|
423
|
+
self._emit('"""Push every codebook entry into SutraDB.')
|
|
424
|
+
self._emit('')
|
|
425
|
+
self._emit("Called from the codegen prelude after embed_batch finishes")
|
|
426
|
+
self._emit("populating self._codebook. Each (name, vec) becomes a triple")
|
|
427
|
+
self._emit('<urn:sutra:label:NAME> <urn:sutra:embedding> "VEC"^^<f32vec> .')
|
|
428
|
+
self._emit('"""')
|
|
429
|
+
self._emit("db = self._ensure_sutradb()")
|
|
430
|
+
self._emit("if db is None:")
|
|
431
|
+
self._indent += 1
|
|
432
|
+
self._emit("return # FFI unavailable; nearest_string will return None")
|
|
433
|
+
self._indent -= 1
|
|
434
|
+
self._emit("for name, vec in self._codebook.items():")
|
|
435
|
+
self._indent += 1
|
|
436
|
+
self._emit("# Skip non-URL-safe characters in label by URL-quoting.")
|
|
437
|
+
self._emit("import urllib.parse as _urllib_parse")
|
|
438
|
+
self._emit("safe = _urllib_parse.quote(name, safe='')")
|
|
439
|
+
self._emit("vec_list = vec.tolist() if hasattr(vec, 'tolist') else list(vec)")
|
|
440
|
+
self._emit("try:")
|
|
441
|
+
self._indent += 1
|
|
442
|
+
self._emit("db.add(safe, vec_list)")
|
|
443
|
+
self._indent -= 1
|
|
444
|
+
self._emit("except Exception:")
|
|
445
|
+
self._indent += 1
|
|
446
|
+
self._emit("pass # one bad insert shouldn't kill the rest")
|
|
447
|
+
self._indent -= 1
|
|
448
|
+
self._indent -= 1
|
|
449
|
+
self._indent -= 1
|
|
450
|
+
self._emit()
|
|
451
|
+
self._emit("def prewarm_rotation_cache(self):")
|
|
452
|
+
self._indent += 1
|
|
453
|
+
self._emit('"""Pre-compute rotation matrices for every codebook entry.')
|
|
454
|
+
self._emit('')
|
|
455
|
+
self._emit("The runtime never pays the QR construction cost on the hot")
|
|
456
|
+
self._emit("path: pre-warming at module init means every bind/unbind hits")
|
|
457
|
+
self._emit("the cache. Conservative over the codebook (some entries are")
|
|
458
|
+
self._emit("fillers, not roles); the cost is one-time and proportional")
|
|
459
|
+
self._emit("to codebook size.")
|
|
460
|
+
self._emit('"""')
|
|
461
|
+
self._emit("for name, vec in self._codebook.items():")
|
|
462
|
+
self._indent += 1
|
|
463
|
+
self._emit("try:")
|
|
464
|
+
self._indent += 1
|
|
465
|
+
self._emit("self._rotation_for(vec)")
|
|
466
|
+
self._indent -= 1
|
|
467
|
+
self._emit("except Exception:")
|
|
468
|
+
self._indent += 1
|
|
469
|
+
self._emit("pass # one bad rotation shouldn't kill the rest")
|
|
470
|
+
self._indent -= 1
|
|
471
|
+
self._indent -= 1
|
|
472
|
+
self._indent -= 1
|
|
473
|
+
self._emit()
|
|
474
|
+
self._emit("def nearest_string(self, query):")
|
|
475
|
+
self._indent += 1
|
|
476
|
+
self._emit('"""Inverse of embed(): given a query vector, return the nearest')
|
|
477
|
+
self._emit("string from the compile-time-populated SutraDB codebook. None")
|
|
478
|
+
self._emit("if SutraDB is unavailable. The query vector is the full extended-")
|
|
479
|
+
self._emit("state vector; only the semantic block is consulted by SutraDB.")
|
|
480
|
+
self._emit('"""')
|
|
481
|
+
self._emit("db = self._ensure_sutradb()")
|
|
482
|
+
self._emit("if db is None:")
|
|
483
|
+
self._indent += 1
|
|
484
|
+
self._emit("return None")
|
|
485
|
+
self._indent -= 1
|
|
486
|
+
self._emit("q_list = query.tolist() if hasattr(query, 'tolist') else list(query)")
|
|
487
|
+
self._emit("try:")
|
|
488
|
+
self._indent += 1
|
|
489
|
+
self._emit("labels = db.nearest(q_list, k=1)")
|
|
490
|
+
self._indent -= 1
|
|
491
|
+
self._emit("except Exception:")
|
|
492
|
+
self._indent += 1
|
|
493
|
+
self._emit("return None")
|
|
494
|
+
self._indent -= 1
|
|
495
|
+
self._emit("if not labels:")
|
|
496
|
+
self._indent += 1
|
|
497
|
+
self._emit("return None")
|
|
498
|
+
self._indent -= 1
|
|
499
|
+
self._emit("import urllib.parse as _urllib_parse")
|
|
500
|
+
self._emit("return _urllib_parse.unquote(labels[0])")
|
|
501
|
+
self._indent -= 1
|
|
502
|
+
self._emit()
|
|
503
|
+
self._emit("def _role_hash(self, role_vec):")
|
|
504
|
+
self._indent += 1
|
|
505
|
+
self._emit('"""Deterministic uint32 seed from a role tensor.')
|
|
506
|
+
self._emit('')
|
|
507
|
+
self._emit("Computed from the CPU bytes of the tensor so numerical bit-")
|
|
508
|
+
self._emit("identity across runs gives the same rotation. Matches the")
|
|
509
|
+
self._emit("numpy backend's hash scheme exactly when the semantic content")
|
|
510
|
+
self._emit("is bit-for-bit equal.")
|
|
511
|
+
self._emit('"""')
|
|
512
|
+
self._emit("import hashlib")
|
|
513
|
+
self._emit("b = role_vec.detach().cpu().contiguous().numpy().tobytes()")
|
|
514
|
+
self._emit("h = hashlib.blake2b(b, digest_size=8).digest()")
|
|
515
|
+
self._emit("return int.from_bytes(h, 'little') & 0xFFFFFFFF")
|
|
516
|
+
self._indent -= 1
|
|
517
|
+
self._emit()
|
|
518
|
+
self._emit("def _rotation_for(self, role_vec):")
|
|
519
|
+
self._indent += 1
|
|
520
|
+
self._emit('"""Block-diagonal Haar rotation seeded by the role tensor.')
|
|
521
|
+
self._emit('')
|
|
522
|
+
self._emit("Haar-uniform in the semantic block, identity in the synthetic")
|
|
523
|
+
self._emit("block — same layout as the numpy backend so rotation-binding")
|
|
524
|
+
self._emit("semantics are identical. The Haar draw uses numpy because")
|
|
525
|
+
self._emit("numpy's RandomState(seed) is the canonical bit-reproducible")
|
|
526
|
+
self._emit("generator; we move the result to the torch device before")
|
|
527
|
+
self._emit("caching.")
|
|
528
|
+
self._emit('')
|
|
529
|
+
self._emit("Cached per role-hash so the same role always produces the same")
|
|
530
|
+
self._emit("rotation — required for bind/unbind round-trip.")
|
|
531
|
+
self._emit('"""')
|
|
532
|
+
self._emit("key = self._role_hash(role_vec)")
|
|
533
|
+
self._emit("if key not in self._rot_cache:")
|
|
534
|
+
self._indent += 1
|
|
535
|
+
self._emit("import numpy as _np_bridge")
|
|
536
|
+
self._emit("rng = _np_bridge.random.RandomState(key)")
|
|
537
|
+
self._emit("A = rng.randn(self.semantic_dim, self.semantic_dim)")
|
|
538
|
+
self._emit("Q_sem_np, R_np = _np_bridge.linalg.qr(A)")
|
|
539
|
+
self._emit("d = _np_bridge.sign(_np_bridge.diag(R_np))")
|
|
540
|
+
self._emit("d[d == 0] = 1.0")
|
|
541
|
+
self._emit("Q_sem_np = Q_sem_np * d")
|
|
542
|
+
self._emit("Q_sem = _torch.as_tensor(Q_sem_np, dtype=self.dtype, device=self.device)")
|
|
543
|
+
self._emit("# Block-diagonal embedding: Q_sem on the semantic block,")
|
|
544
|
+
self._emit("# identity everywhere else.")
|
|
545
|
+
self._emit("Q = _torch.eye(self.dim, dtype=self.dtype, device=self.device)")
|
|
546
|
+
self._emit("Q[:self.semantic_dim, :self.semantic_dim] = Q_sem")
|
|
547
|
+
self._emit("self._rot_cache[key] = Q")
|
|
548
|
+
self._indent -= 1
|
|
549
|
+
self._emit("return self._rot_cache[key]")
|
|
550
|
+
self._indent -= 1
|
|
551
|
+
self._emit()
|
|
552
|
+
self._emit("def bind(self, role, filler):")
|
|
553
|
+
self._indent += 1
|
|
554
|
+
self._emit("# Rotation binding. bind(role, filler) = Q_role @ filler. Role-")
|
|
555
|
+
self._emit("# first convention (matches numpy backend and the .su demos).")
|
|
556
|
+
self._emit("Q = self._rotation_for(role)")
|
|
557
|
+
self._emit("return Q @ filler")
|
|
558
|
+
self._indent -= 1
|
|
559
|
+
self._emit()
|
|
560
|
+
self._emit("def unbind(self, role, record):")
|
|
561
|
+
self._indent += 1
|
|
562
|
+
self._emit("# Q is orthogonal so unbind(role, record) = Q_role^T @ record.")
|
|
563
|
+
self._emit("# Round-trip: unbind(r, bind(r, v)) = Q^T @ Q @ v = v exactly.")
|
|
564
|
+
self._emit("Q = self._rotation_for(role)")
|
|
565
|
+
self._emit("return Q.T @ record")
|
|
566
|
+
self._indent -= 1
|
|
567
|
+
self._emit()
|
|
568
|
+
self._emit("def bundle(self, *vectors):")
|
|
569
|
+
self._indent += 1
|
|
570
|
+
self._emit("s = _torch.stack([")
|
|
571
|
+
self._indent += 1
|
|
572
|
+
self._emit("_torch.as_tensor(v, dtype=self.dtype, device=self.device)")
|
|
573
|
+
self._emit("for v in vectors")
|
|
574
|
+
self._indent -= 1
|
|
575
|
+
self._emit("]).sum(dim=0)")
|
|
576
|
+
self._emit("n = _torch.linalg.norm(s)")
|
|
577
|
+
self._emit("return s / n if n > 0 else s")
|
|
578
|
+
self._indent -= 1
|
|
579
|
+
self._emit()
|
|
580
|
+
self._emit("def zero_vector(self):")
|
|
581
|
+
self._indent += 1
|
|
582
|
+
self._emit('"""Zero vector in the runtime dim. Emitted by simplifier identities."""')
|
|
583
|
+
self._emit("return _torch.zeros(self.dim, dtype=self.dtype, device=self.device)")
|
|
584
|
+
self._indent -= 1
|
|
585
|
+
self._emit()
|
|
586
|
+
self._emit("def bundle_of_binds(self, *role_filler_pairs):")
|
|
587
|
+
self._indent += 1
|
|
588
|
+
self._emit('"""Fused bind+sum+normalize over N role-filler pairs.')
|
|
589
|
+
self._emit('')
|
|
590
|
+
self._emit("This is the GPU-shaped primitive: stack roles into (N, d, d),")
|
|
591
|
+
self._emit("stack fillers into (N, d), one batched einsum + reduce. On")
|
|
592
|
+
self._emit("CUDA, N small bind+bundle kernel launches collapse into O(1)")
|
|
593
|
+
self._emit("big launches. Same numerics as sequential bind + bundle.")
|
|
594
|
+
self._emit('"""')
|
|
595
|
+
self._emit("if not role_filler_pairs:")
|
|
596
|
+
self._indent += 1
|
|
597
|
+
self._emit("return self.zero_vector()")
|
|
598
|
+
self._indent -= 1
|
|
599
|
+
self._emit("roles = [rf[0] for rf in role_filler_pairs]")
|
|
600
|
+
self._emit("fillers = [rf[1] for rf in role_filler_pairs]")
|
|
601
|
+
self._emit("Q_stack = _torch.stack([self._rotation_for(r) for r in roles])")
|
|
602
|
+
self._emit("F_stack = _torch.stack([")
|
|
603
|
+
self._indent += 1
|
|
604
|
+
self._emit("_torch.as_tensor(f, dtype=self.dtype, device=self.device)")
|
|
605
|
+
self._emit("for f in fillers")
|
|
606
|
+
self._indent -= 1
|
|
607
|
+
self._emit("])")
|
|
608
|
+
self._emit("# Batched bind: element-i is Q_i @ f_i; shape (N, d).")
|
|
609
|
+
self._emit("bound = _torch.einsum('nij,nj->ni', Q_stack, F_stack)")
|
|
610
|
+
self._emit("s = bound.sum(dim=0)")
|
|
611
|
+
self._emit("n = _torch.linalg.norm(s)")
|
|
612
|
+
self._emit("return s / n if n > 0 else s")
|
|
613
|
+
self._indent -= 1
|
|
614
|
+
self._emit()
|
|
615
|
+
self._emit("# ---- Rotation-hashmap (same shape as numpy backend) ----")
|
|
616
|
+
self._emit()
|
|
617
|
+
self._emit("def hashmap_new(self):")
|
|
618
|
+
self._indent += 1
|
|
619
|
+
self._emit("return _torch.zeros(self.dim, dtype=self.dtype, device=self.device)")
|
|
620
|
+
self._indent -= 1
|
|
621
|
+
self._emit()
|
|
622
|
+
self._emit("def hashmap_set(self, acc, key_vec, val_vec):")
|
|
623
|
+
self._indent += 1
|
|
624
|
+
self._emit("return acc + self.bind(key_vec, val_vec)")
|
|
625
|
+
self._indent -= 1
|
|
626
|
+
self._emit()
|
|
627
|
+
self._emit("def hashmap_get(self, acc, key_vec):")
|
|
628
|
+
self._indent += 1
|
|
629
|
+
self._emit("return self.unbind(key_vec, acc)")
|
|
630
|
+
self._indent -= 1
|
|
631
|
+
self._emit()
|
|
632
|
+
# ---- Axon runtime methods ----
|
|
633
|
+
# Axons share the substrate operations of the rotation hashmap
|
|
634
|
+
# (an axon is a bundle of bind(role, value) terms over a
|
|
635
|
+
# codebook of role-by-string-name) but are a distinct
|
|
636
|
+
# user-facing class — see planning/sutra-spec/axons.md. The
|
|
637
|
+
# methods below implement the substrate operations the
|
|
638
|
+
# `Axon` stdlib class declares as `static intrinsic method`.
|
|
639
|
+
self._emit("# ---- Axon runtime methods ----")
|
|
640
|
+
self._emit("def axon_new(self):")
|
|
641
|
+
self._indent += 1
|
|
642
|
+
self._emit("return _torch.zeros(self.dim, dtype=self.dtype, device=self.device)")
|
|
643
|
+
self._indent -= 1
|
|
644
|
+
self._emit()
|
|
645
|
+
self._emit("def axon_add(self, axon, key, value):")
|
|
646
|
+
self._indent += 1
|
|
647
|
+
self._emit("# Key may arrive as a Python string (compile-time")
|
|
648
|
+
self._emit("# identifier) or as an already-embedded vector.")
|
|
649
|
+
self._emit("# Strings are auto-embedded into a basis vector.")
|
|
650
|
+
self._emit("key_vec = self.embed(key) if isinstance(key, str) else key")
|
|
651
|
+
self._emit("# Scalar fillers (Python int / float) are promoted to")
|
|
652
|
+
self._emit("# a real-axis vector via make_real so the bind matmul")
|
|
653
|
+
self._emit("# works. Per the axon spec, axons can carry values of")
|
|
654
|
+
self._emit("# any kind; on the substrate they all become vectors.")
|
|
655
|
+
self._emit("if isinstance(value, (int, float)):")
|
|
656
|
+
self._indent += 1
|
|
657
|
+
self._emit("value = self.make_real(float(value))")
|
|
658
|
+
self._indent -= 1
|
|
659
|
+
self._emit("return axon + self.bind(key_vec, value)")
|
|
660
|
+
self._indent -= 1
|
|
661
|
+
self._emit()
|
|
662
|
+
self._emit("def axon_item(self, axon, key):")
|
|
663
|
+
self._indent += 1
|
|
664
|
+
self._emit("key_vec = self.embed(key) if isinstance(key, str) else key")
|
|
665
|
+
self._emit("return self.unbind(key_vec, axon)")
|
|
666
|
+
self._indent -= 1
|
|
667
|
+
self._emit()
|
|
668
|
+
# ---- 2D-Givens-per-slot rotation binding (synthetic subspace) ----
|
|
669
|
+
# Mirrors the numpy backend's slot block. See codegen.py for the
|
|
670
|
+
# block; this is the pytorch realization, with `_torch.zeros`
|
|
671
|
+
# and `tensor.clone()` instead of `_np.copy()`.
|
|
672
|
+
self._emit("# ---- 2D-Givens-per-slot rotation binding (synthetic subspace) ----")
|
|
673
|
+
self._emit("# Mirrors the numpy backend slot block; see codegen.py.")
|
|
674
|
+
self._emit("# SLOT_BASE = 5 to leave room for AXIS_LOOP_DONE at synthetic[4].")
|
|
675
|
+
self._emit("SLOT_BASE = 5")
|
|
676
|
+
self._emit()
|
|
677
|
+
self._emit("def _slot_plane(self, slot_idx):")
|
|
678
|
+
self._indent += 1
|
|
679
|
+
self._emit("n_planes = (self.synthetic_dim - self.SLOT_BASE) // 2")
|
|
680
|
+
self._emit("if n_planes <= 0:")
|
|
681
|
+
self._indent += 1
|
|
682
|
+
self._emit("raise RuntimeError(")
|
|
683
|
+
self._indent += 1
|
|
684
|
+
self._emit('"synthetic subspace has no room for slot planes; "')
|
|
685
|
+
self._emit('"increase synthetic_dim or SLOT_BASE budget")')
|
|
686
|
+
self._indent -= 1
|
|
687
|
+
self._indent -= 1
|
|
688
|
+
self._emit("s = int(slot_idx) % n_planes")
|
|
689
|
+
self._emit("base = self.semantic_dim + self.SLOT_BASE + 2 * s")
|
|
690
|
+
self._emit("return (base, base + 1)")
|
|
691
|
+
self._indent -= 1
|
|
692
|
+
self._emit()
|
|
693
|
+
self._emit("def slot_store(self, state, slot_idx, scalar):")
|
|
694
|
+
self._indent += 1
|
|
695
|
+
self._emit("i, j = self._slot_plane(slot_idx)")
|
|
696
|
+
self._emit("new = state.clone() if hasattr(state, 'clone') else state.copy()")
|
|
697
|
+
self._emit("new[i] = float(scalar)")
|
|
698
|
+
self._emit("new[j] = 0.0")
|
|
699
|
+
self._emit("return new")
|
|
700
|
+
self._indent -= 1
|
|
701
|
+
self._emit()
|
|
702
|
+
self._emit("def slot_load(self, state, slot_idx):")
|
|
703
|
+
self._indent += 1
|
|
704
|
+
self._emit('"""Read the slot scalar. Returns a torch 0-dim tensor.')
|
|
705
|
+
self._emit('')
|
|
706
|
+
self._emit("Substrate-pure: downstream arithmetic stays in tensor land. See")
|
|
707
|
+
self._emit("planning/findings/2026-04-30-substrate-purity-leak-enumeration.md.")
|
|
708
|
+
self._emit('"""')
|
|
709
|
+
self._emit("i, _j = self._slot_plane(slot_idx)")
|
|
710
|
+
self._emit("return state[i]")
|
|
711
|
+
self._indent -= 1
|
|
712
|
+
self._emit()
|
|
713
|
+
self._emit("# ---- Binding-array primitive (substrate-stored ordered list) ----")
|
|
714
|
+
self._emit("# Layout: arr[0] = length scalar, arr[1..length] = elements. Used by")
|
|
715
|
+
self._emit("# foreach_loop. Pure tensor reads/writes; no Python list, no heap")
|
|
716
|
+
self._emit("# allocation beyond the initial tensor.")
|
|
717
|
+
self._emit()
|
|
718
|
+
self._emit("def array_from_literal(self, *values):")
|
|
719
|
+
self._indent += 1
|
|
720
|
+
self._emit('"""Build an array from compile-time-known scalar values."""')
|
|
721
|
+
self._emit("arr = _torch.zeros(len(values) + 1, dtype=self.dtype, device=self.device)")
|
|
722
|
+
self._emit("arr[0] = float(len(values))")
|
|
723
|
+
self._emit("for i, v in enumerate(values):")
|
|
724
|
+
self._indent += 1
|
|
725
|
+
self._emit("arr[1 + i] = float(v)")
|
|
726
|
+
self._indent -= 1
|
|
727
|
+
self._emit("return arr")
|
|
728
|
+
self._indent -= 1
|
|
729
|
+
self._emit()
|
|
730
|
+
self._emit("def array_length(self, arr):")
|
|
731
|
+
self._indent += 1
|
|
732
|
+
self._emit('"""Read the length prefix as an int (used for Python loop bound)."""')
|
|
733
|
+
self._emit("return int(arr[0].item())")
|
|
734
|
+
self._indent -= 1
|
|
735
|
+
self._emit()
|
|
736
|
+
self._emit("def array_get(self, arr, i):")
|
|
737
|
+
self._indent += 1
|
|
738
|
+
self._emit('"""Read element at index i (0-based). Returns torch 0-dim tensor."""')
|
|
739
|
+
self._emit("return arr[1 + int(i)]")
|
|
740
|
+
self._indent -= 1
|
|
741
|
+
self._emit()
|
|
742
|
+
self._emit("# ---- Substrate scalar primitives (boundary-leak reductions) ----")
|
|
743
|
+
self._emit()
|
|
744
|
+
self._emit("def truth_axis(self, vec_or_scalar):")
|
|
745
|
+
self._indent += 1
|
|
746
|
+
self._emit('"""Read AXIS_TRUTH from a fuzzy-vector result, or pass scalars through.')
|
|
747
|
+
self._emit('')
|
|
748
|
+
self._emit("Returns a torch 0-dim tensor; substrate-pure loop halt checks consume")
|
|
749
|
+
self._emit("the result without crossing the Python boundary.")
|
|
750
|
+
self._emit('"""')
|
|
751
|
+
# Emitted guard for truth_axis: only index AXIS_TRUTH on sequence-like inputs.
# The `ndim` test runs first because a 0-dim torch tensor exposes `__len__`
# but raises TypeError on len(); such scalars must fall through to the
# pass-through branches emitted below instead of crashing here.
self._emit("if getattr(vec_or_scalar, 'ndim', 1) > 0 and hasattr(vec_or_scalar, '__len__') and len(vec_or_scalar) > 1:")
|
|
752
|
+
self._indent += 1
|
|
753
|
+
self._emit("return vec_or_scalar[self.semantic_dim + self.AXIS_TRUTH]")
|
|
754
|
+
self._indent -= 1
|
|
755
|
+
self._emit("if _torch.is_tensor(vec_or_scalar):")
|
|
756
|
+
self._indent += 1
|
|
757
|
+
self._emit("return vec_or_scalar")
|
|
758
|
+
self._indent -= 1
|
|
759
|
+
self._emit("return _torch.tensor(vec_or_scalar, dtype=self.dtype, device=self.device)")
|
|
760
|
+
self._indent -= 1
|
|
761
|
+
self._emit()
|
|
762
|
+
self._emit("def heaviside(self, x):")
|
|
763
|
+
self._indent += 1
|
|
764
|
+
self._emit('"""Step function: 1.0 where x > 0, else 0.0. Torch 0-dim tensor."""')
|
|
765
|
+
self._emit("if not _torch.is_tensor(x):")
|
|
766
|
+
self._indent += 1
|
|
767
|
+
self._emit("x = _torch.tensor(x, dtype=self.dtype, device=self.device)")
|
|
768
|
+
self._indent -= 1
|
|
769
|
+
self._emit("zero = _torch.zeros((), dtype=self.dtype, device=self.device)")
|
|
770
|
+
self._emit("return _torch.heaviside(x.to(self.dtype), zero)")
|
|
771
|
+
self._indent -= 1
|
|
772
|
+
self._emit()
|
|
773
|
+
self._emit("def saturate_unit(self, x):")
|
|
774
|
+
self._indent += 1
|
|
775
|
+
self._emit('"""min(x, 1.0) implemented as torch.minimum. Torch 0-dim tensor."""')
|
|
776
|
+
self._emit("if not _torch.is_tensor(x):")
|
|
777
|
+
self._indent += 1
|
|
778
|
+
self._emit("x = _torch.tensor(x, dtype=self.dtype, device=self.device)")
|
|
779
|
+
self._indent -= 1
|
|
780
|
+
self._emit("one = _torch.ones((), dtype=self.dtype, device=self.device)")
|
|
781
|
+
self._emit("return _torch.minimum(x, one)")
|
|
782
|
+
self._indent -= 1
|
|
783
|
+
self._emit()
|
|
784
|
+
self._emit("def rotate_slot(self, state, slot_idx, angle):")
|
|
785
|
+
self._indent += 1
|
|
786
|
+
self._emit("import math as _math")
|
|
787
|
+
self._emit("i, j = self._slot_plane(slot_idx)")
|
|
788
|
+
self._emit("c, s = _math.cos(float(angle)), _math.sin(float(angle))")
|
|
789
|
+
self._emit("new = state.clone() if hasattr(state, 'clone') else state.copy()")
|
|
790
|
+
self._emit("xi, xj = float(state[i]), float(state[j])")
|
|
791
|
+
self._emit("new[i] = c * xi - s * xj")
|
|
792
|
+
self._emit("new[j] = s * xi + c * xj")
|
|
793
|
+
self._emit("return new")
|
|
794
|
+
self._indent -= 1
|
|
795
|
+
self._emit()
|
|
796
|
+
self._emit("def similarity(self, a, b):")
|
|
797
|
+
self._indent += 1
|
|
798
|
+
self._emit("na = _torch.linalg.norm(a)")
|
|
799
|
+
self._emit("nb = _torch.linalg.norm(b)")
|
|
800
|
+
self._emit("# eps-guarded divide — zero-norm case evaluates to 0 without branch.")
|
|
801
|
+
self._emit("return float(_torch.dot(a, b) / (na * nb + _torch.finfo(self.dtype).tiny))")
|
|
802
|
+
self._indent -= 1
|
|
803
|
+
self._emit()
|
|
804
|
+
# General-purpose tensor operations — see codegen.py for the
|
|
805
|
+
# numpy-backend equivalent and stdlib/tensor.su for the Sutra
|
|
806
|
+
# surface (`Tensor.MatrixMul` etc.).
|
|
807
|
+
self._emit("def matmul(self, a, b):")
|
|
808
|
+
self._indent += 1
|
|
809
|
+
self._emit('"""Matrix multiplication (torch matmul / `a @ b`)."""')
|
|
810
|
+
self._emit("return _torch.matmul(a, b)")
|
|
811
|
+
self._indent -= 1
|
|
812
|
+
self._emit()
|
|
813
|
+
self._emit("def tensor_product(self, a, b):")
|
|
814
|
+
self._indent += 1
|
|
815
|
+
self._emit('"""Tensor / Kronecker product."""')
|
|
816
|
+
self._emit("return _torch.kron(a, b)")
|
|
817
|
+
self._indent -= 1
|
|
818
|
+
self._emit()
|
|
819
|
+
self._emit("def outer(self, a, b):")
|
|
820
|
+
self._indent += 1
|
|
821
|
+
self._emit('"""Vector outer product → rank-2 tensor."""')
|
|
822
|
+
self._emit("return _torch.outer(a, b)")
|
|
823
|
+
self._indent -= 1
|
|
824
|
+
self._emit()
|
|
825
|
+
self._emit("def dot(self, a, b):")
|
|
826
|
+
self._indent += 1
|
|
827
|
+
self._emit('"""Inner / dot product → scalar."""')
|
|
828
|
+
self._emit("return float(_torch.dot(a, b))")
|
|
829
|
+
self._indent -= 1
|
|
830
|
+
self._emit()
|
|
831
|
+
self._emit("def transpose(self, m):")
|
|
832
|
+
self._indent += 1
|
|
833
|
+
self._emit('"""Transpose (last two dims for 2-D+; identity for 1-D)."""')
|
|
834
|
+
self._emit("if m.ndim < 2:")
|
|
835
|
+
self._indent += 1
|
|
836
|
+
self._emit("return m")
|
|
837
|
+
self._indent -= 1
|
|
838
|
+
self._emit("return _torch.transpose(m, -2, -1)")
|
|
839
|
+
self._indent -= 1
|
|
840
|
+
self._emit()
|
|
841
|
+
self._emit("def norm(self, v):")
|
|
842
|
+
self._indent += 1
|
|
843
|
+
self._emit('"""L2 norm. Scalar result."""')
|
|
844
|
+
self._emit("return float(_torch.linalg.norm(v))")
|
|
845
|
+
self._indent -= 1
|
|
846
|
+
self._emit()
|
|
847
|
+
self._emit("def normalize(self, v):")
|
|
848
|
+
self._indent += 1
|
|
849
|
+
self._emit('"""L2-normalize with an eps-guard so zero-norm input returns zero."""')
|
|
850
|
+
self._emit("n = _torch.linalg.norm(v)")
|
|
851
|
+
self._emit("return v / (n + _torch.finfo(self.dtype).tiny)")
|
|
852
|
+
self._indent -= 1
|
|
853
|
+
self._emit()
|
|
854
|
+
self._emit("def rotation_for(self, role):")
|
|
855
|
+
self._indent += 1
|
|
856
|
+
self._emit('"""Cached Haar-random orthogonal rotation matrix for the role vector."""')
|
|
857
|
+
self._emit("return self._rotation_for(role)")
|
|
858
|
+
self._indent -= 1
|
|
859
|
+
self._emit()
|
|
860
|
+
# PascalCase aliases — the preferred Sutra-side spelling.
|
|
861
|
+
self._emit("MatrixMul = matmul")
|
|
862
|
+
self._emit("TensorProduct = tensor_product")
|
|
863
|
+
self._emit("Outer = outer")
|
|
864
|
+
self._emit("Dot = dot")
|
|
865
|
+
self._emit("Transpose = transpose")
|
|
866
|
+
self._emit("Norm = norm")
|
|
867
|
+
self._emit("Normalize = normalize")
|
|
868
|
+
self._emit("RotationFor = rotation_for")
|
|
869
|
+
self._emit()
|
|
870
|
+
self._emit("# ---- Vector component accessors (debugging / teaching) ----")
|
|
871
|
+
self._emit()
|
|
872
|
+
self._emit("def component(self, v, i):")
|
|
873
|
+
self._indent += 1
|
|
874
|
+
self._emit('"""Return element i of v over the full extended state vector."""')
|
|
875
|
+
self._emit("return float(v[int(i)].item())")
|
|
876
|
+
self._indent -= 1
|
|
877
|
+
self._emit()
|
|
878
|
+
self._emit("def semantic(self, v, i):")
|
|
879
|
+
self._indent += 1
|
|
880
|
+
self._emit('"""Return element i within the semantic block."""')
|
|
881
|
+
self._emit("idx = int(i)")
|
|
882
|
+
self._emit("if idx < 0 or idx >= self.semantic_dim:")
|
|
883
|
+
self._indent += 1
|
|
884
|
+
self._emit("raise IndexError(")
|
|
885
|
+
self._indent += 1
|
|
886
|
+
self._emit('f"semantic index {idx} out of range [0, {self.semantic_dim})")')
|
|
887
|
+
self._indent -= 1
|
|
888
|
+
self._indent -= 1
|
|
889
|
+
self._emit("return float(v[idx].item())")
|
|
890
|
+
self._indent -= 1
|
|
891
|
+
self._emit()
|
|
892
|
+
self._emit("def synthetic(self, v, i):")
|
|
893
|
+
self._indent += 1
|
|
894
|
+
self._emit('"""Return element i within the synthetic block."""')
|
|
895
|
+
self._emit("idx = int(i)")
|
|
896
|
+
self._emit("if idx < 0 or idx >= self.synthetic_dim:")
|
|
897
|
+
self._indent += 1
|
|
898
|
+
self._emit("raise IndexError(")
|
|
899
|
+
self._indent += 1
|
|
900
|
+
self._emit('f"synthetic index {idx} out of range [0, {self.synthetic_dim})")')
|
|
901
|
+
self._indent -= 1
|
|
902
|
+
self._indent -= 1
|
|
903
|
+
self._emit("return float(v[self.semantic_dim + idx].item())")
|
|
904
|
+
self._indent -= 1
|
|
905
|
+
self._emit()
|
|
906
|
+
self._emit("# ---- Canonical-axis accessors (real/imag/truth) ----")
|
|
907
|
+
self._emit()
|
|
908
|
+
self._emit("def real(self, v):")
|
|
909
|
+
self._indent += 1
|
|
910
|
+
self._emit("return float(v[self.semantic_dim + self.AXIS_REAL].item())")
|
|
911
|
+
self._indent -= 1
|
|
912
|
+
self._emit()
|
|
913
|
+
self._emit("def imag(self, v):")
|
|
914
|
+
self._indent += 1
|
|
915
|
+
self._emit("return float(v[self.semantic_dim + self.AXIS_IMAG].item())")
|
|
916
|
+
self._indent -= 1
|
|
917
|
+
self._emit()
|
|
918
|
+
self._emit("def truth(self, v):")
|
|
919
|
+
self._indent += 1
|
|
920
|
+
self._emit("return float(v[self.semantic_dim + self.AXIS_TRUTH].item())")
|
|
921
|
+
self._indent -= 1
|
|
922
|
+
self._emit()
|
|
923
|
+
self._emit("def make_real(self, x):")
|
|
924
|
+
self._indent += 1
|
|
925
|
+
self._emit("v = _torch.zeros(self.dim, dtype=self.dtype, device=self.device)")
|
|
926
|
+
self._emit("v[self.semantic_dim + self.AXIS_REAL] = float(x)")
|
|
927
|
+
self._emit("return v")
|
|
928
|
+
self._indent -= 1
|
|
929
|
+
self._emit()
|
|
930
|
+
self._emit("def make_complex(self, re, im):")
|
|
931
|
+
self._indent += 1
|
|
932
|
+
self._emit("v = _torch.zeros(self.dim, dtype=self.dtype, device=self.device)")
|
|
933
|
+
self._emit("v[self.semantic_dim + self.AXIS_REAL] = float(re)")
|
|
934
|
+
self._emit("v[self.semantic_dim + self.AXIS_IMAG] = float(im)")
|
|
935
|
+
self._emit("return v")
|
|
936
|
+
self._indent -= 1
|
|
937
|
+
self._emit()
|
|
938
|
+
self._emit("def _swap_ri_matrix(self):")
|
|
939
|
+
self._indent += 1
|
|
940
|
+
self._emit("if not hasattr(self, '_swap_ri_cache') or self._swap_ri_cache is None:")
|
|
941
|
+
self._indent += 1
|
|
942
|
+
self._emit("M = _torch.zeros((self.dim, self.dim), dtype=self.dtype, device=self.device)")
|
|
943
|
+
self._emit("r = self.semantic_dim + self.AXIS_REAL")
|
|
944
|
+
self._emit("i = self.semantic_dim + self.AXIS_IMAG")
|
|
945
|
+
self._emit("M[r, i] = 1.0; M[i, r] = 1.0")
|
|
946
|
+
self._emit("self._swap_ri_cache = M")
|
|
947
|
+
self._indent -= 1
|
|
948
|
+
self._emit("return self._swap_ri_cache")
|
|
949
|
+
self._indent -= 1
|
|
950
|
+
self._emit()
|
|
951
|
+
self._emit("def _cm_real_matrix(self):")
|
|
952
|
+
self._indent += 1
|
|
953
|
+
self._emit("if not hasattr(self, '_cm_real_cache') or self._cm_real_cache is None:")
|
|
954
|
+
self._indent += 1
|
|
955
|
+
self._emit("M = _torch.zeros((self.dim, self.dim), dtype=self.dtype, device=self.device)")
|
|
956
|
+
self._emit("r = self.semantic_dim + self.AXIS_REAL")
|
|
957
|
+
self._emit("i = self.semantic_dim + self.AXIS_IMAG")
|
|
958
|
+
self._emit("M[r, r] = 1.0; M[r, i] = -1.0")
|
|
959
|
+
self._emit("self._cm_real_cache = M")
|
|
960
|
+
self._indent -= 1
|
|
961
|
+
self._emit("return self._cm_real_cache")
|
|
962
|
+
self._indent -= 1
|
|
963
|
+
self._emit()
|
|
964
|
+
self._emit("def _cm_imag_matrix(self):")
|
|
965
|
+
self._indent += 1
|
|
966
|
+
self._emit("if not hasattr(self, '_cm_imag_cache') or self._cm_imag_cache is None:")
|
|
967
|
+
self._indent += 1
|
|
968
|
+
self._emit("M = _torch.zeros((self.dim, self.dim), dtype=self.dtype, device=self.device)")
|
|
969
|
+
self._emit("r = self.semantic_dim + self.AXIS_REAL")
|
|
970
|
+
self._emit("i = self.semantic_dim + self.AXIS_IMAG")
|
|
971
|
+
self._emit("M[i, r] = 1.0; M[i, i] = 1.0")
|
|
972
|
+
self._emit("self._cm_imag_cache = M")
|
|
973
|
+
self._indent -= 1
|
|
974
|
+
self._emit("return self._cm_imag_cache")
|
|
975
|
+
self._indent -= 1
|
|
976
|
+
self._emit()
|
|
977
|
+
self._emit("def complex_mul(self, a, b):")
|
|
978
|
+
self._indent += 1
|
|
979
|
+
self._emit('"""Complex product: matrix form, no scalar extraction.')
|
|
980
|
+
self._emit('')
|
|
981
|
+
self._emit("c = _cm_real @ (a * b) + _cm_imag @ ((_swap_ri @ a) * b)")
|
|
982
|
+
self._emit('"""')
|
|
983
|
+
self._emit("av = self._as_complex_vector(a)")
|
|
984
|
+
self._emit("bv = self._as_complex_vector(b)")
|
|
985
|
+
self._emit("ab = av * bv")
|
|
986
|
+
self._emit("swapped_ab = (self._swap_ri_matrix() @ av) * bv")
|
|
987
|
+
self._emit("return self._cm_real_matrix() @ ab + self._cm_imag_matrix() @ swapped_ab")
|
|
988
|
+
self._indent -= 1
|
|
989
|
+
self._emit()
|
|
990
|
+
self._emit("def _as_complex_vector(self, x):")
|
|
991
|
+
self._indent += 1
|
|
992
|
+
self._emit('"""Coerce Python scalar / tensor to complex-plane form."""')
|
|
993
|
+
self._emit("if isinstance(x, _torch.Tensor):")
|
|
994
|
+
self._indent += 1
|
|
995
|
+
self._emit("return x")
|
|
996
|
+
self._indent -= 1
|
|
997
|
+
self._emit("if isinstance(x, bool):")
|
|
998
|
+
self._indent += 1
|
|
999
|
+
self._emit("return self.make_real(1.0 if x else 0.0)")
|
|
1000
|
+
self._indent -= 1
|
|
1001
|
+
self._emit("return self.make_real(float(x))")
|
|
1002
|
+
self._indent -= 1
|
|
1003
|
+
self._emit()
|
|
1004
|
+
self._emit("def make_truth(self, t):")
|
|
1005
|
+
self._indent += 1
|
|
1006
|
+
self._emit("v = _torch.zeros(self.dim, dtype=self.dtype, device=self.device)")
|
|
1007
|
+
self._emit("v[self.semantic_dim + self.AXIS_TRUTH] = float(t)")
|
|
1008
|
+
self._emit("return v")
|
|
1009
|
+
self._indent -= 1
|
|
1010
|
+
self._emit()
|
|
1011
|
+
self._emit("def make_char(self, codepoint):")
|
|
1012
|
+
self._indent += 1
|
|
1013
|
+
self._emit('"""Character literal: a 1-character String. Equivalent to')
|
|
1014
|
+
self._emit('make_string(chr(codepoint)). The `char` type is now a')
|
|
1015
|
+
self._emit('1-character String; AXIS_CHAR_FLAG is an alias for')
|
|
1016
|
+
self._emit('AXIS_STRING_FLAG."""')
|
|
1017
|
+
self._emit("return self.make_string(chr(int(codepoint)))")
|
|
1018
|
+
self._indent -= 1
|
|
1019
|
+
self._emit()
|
|
1020
|
+
self._emit("def is_char(self, v):")
|
|
1021
|
+
self._indent += 1
|
|
1022
|
+
self._emit('"""True iff v is a String value (kept as `is_char` for')
|
|
1023
|
+
self._emit('backward-compat with code that pre-dated the rename to')
|
|
1024
|
+
self._emit('AXIS_STRING_FLAG; new code should use is_string)."""')
|
|
1025
|
+
self._emit("return bool(v[self.semantic_dim + self.AXIS_STRING_FLAG].item() >= 0.5)")
|
|
1026
|
+
self._indent -= 1
|
|
1027
|
+
self._emit()
|
|
1028
|
+
self._emit("# ---- String runtime methods ----")
|
|
1029
|
+
self._emit("# Encoding: AXIS_STRING_FLAG marks the vector as a String.")
|
|
1030
|
+
self._emit("# Characters pack into the synthetic axes — char[0] at")
|
|
1031
|
+
self._emit("# AXIS_REAL (=synthetic[0]), char[1] at AXIS_IMAG")
|
|
1032
|
+
self._emit("# (=synthetic[1]), char[k] for k>=2 at synthetic[k+3]")
|
|
1033
|
+
self._emit("# (skipping AXIS_TRUTH/STRING_FLAG/LOOP_DONE at synthetic")
|
|
1034
|
+
self._emit("# [2..4]). Length is recovered by walking from the highest")
|
|
1035
|
+
self._emit("# possible char position down to the first non-zero. See")
|
|
1036
|
+
self._emit("# planning/sutra-spec/strings.md.")
|
|
1037
|
+
self._emit("def _string_axis(self, char_index):")
|
|
1038
|
+
self._indent += 1
|
|
1039
|
+
self._emit('"""Map a character index k into the absolute axis offset')
|
|
1040
|
+
self._emit('inside the synthetic block (relative to semantic_dim)."""')
|
|
1041
|
+
self._emit("return char_index if char_index < 2 else char_index + 3")
|
|
1042
|
+
self._indent -= 1
|
|
1043
|
+
self._emit()
|
|
1044
|
+
self._emit("def string_max_length(self):")
|
|
1045
|
+
self._indent += 1
|
|
1046
|
+
self._emit('"""Maximum string length that fits in the current')
|
|
1047
|
+
self._emit('synthetic_dim. char positions occupy synthetic[0,1] plus')
|
|
1048
|
+
self._emit('synthetic[5..synthetic_dim-1]."""')
|
|
1049
|
+
self._emit("if self.synthetic_dim < 5:")
|
|
1050
|
+
self._indent += 1
|
|
1051
|
+
self._emit("return min(self.synthetic_dim, 2)")
|
|
1052
|
+
self._indent -= 1
|
|
1053
|
+
self._emit("return 2 + (self.synthetic_dim - 5)")
|
|
1054
|
+
self._indent -= 1
|
|
1055
|
+
self._emit()
|
|
1056
|
+
self._emit("def make_string(self, s):")
|
|
1057
|
+
self._indent += 1
|
|
1058
|
+
self._emit('"""Construct a String value from a Python str."""')
|
|
1059
|
+
self._emit("if not isinstance(s, str):")
|
|
1060
|
+
self._indent += 1
|
|
1061
|
+
self._emit("s = str(s)")
|
|
1062
|
+
self._indent -= 1
|
|
1063
|
+
self._emit("max_len = self.string_max_length()")
|
|
1064
|
+
self._emit("if len(s) > max_len:")
|
|
1065
|
+
self._indent += 1
|
|
1066
|
+
self._emit("raise ValueError(")
|
|
1067
|
+
self._emit('"string %r length %d exceeds synthetic-axis budget %d; '
|
|
1068
|
+
'increase synthetic_dim" % (s, len(s), max_len))')
|
|
1069
|
+
self._indent -= 1
|
|
1070
|
+
self._emit("v = _torch.zeros(self.dim, dtype=self.dtype, device=self.device)")
|
|
1071
|
+
self._emit("v[self.semantic_dim + self.AXIS_STRING_FLAG] = 1.0")
|
|
1072
|
+
self._emit("for k, ch in enumerate(s):")
|
|
1073
|
+
self._indent += 1
|
|
1074
|
+
self._emit("axis = self._string_axis(k)")
|
|
1075
|
+
self._emit("v[self.semantic_dim + axis] = float(ord(ch))")
|
|
1076
|
+
self._indent -= 1
|
|
1077
|
+
self._emit("return v")
|
|
1078
|
+
self._indent -= 1
|
|
1079
|
+
self._emit()
|
|
1080
|
+
self._emit("def is_string(self, v):")
|
|
1081
|
+
self._indent += 1
|
|
1082
|
+
self._emit('"""True iff v has the AXIS_STRING_FLAG set."""')
|
|
1083
|
+
self._emit("return bool(v[self.semantic_dim + self.AXIS_STRING_FLAG].item() >= 0.5)")
|
|
1084
|
+
self._indent -= 1
|
|
1085
|
+
self._emit()
|
|
1086
|
+
self._emit("def string_length(self, v):")
|
|
1087
|
+
self._indent += 1
|
|
1088
|
+
self._emit('"""Return the length of String v by scanning from the')
|
|
1089
|
+
self._emit('highest possible char position down to the first non-zero')
|
|
1090
|
+
self._emit('codepoint. Trailing-zero-as-sentinel: a string with')
|
|
1091
|
+
self._emit('codepoint 0 in its tail will read shorter than written."""')
|
|
1092
|
+
self._emit("max_k = self.string_max_length()")
|
|
1093
|
+
self._emit("for k in range(max_k - 1, -1, -1):")
|
|
1094
|
+
self._indent += 1
|
|
1095
|
+
self._emit("axis = self._string_axis(k)")
|
|
1096
|
+
self._emit("if v[self.semantic_dim + axis].item() != 0.0:")
|
|
1097
|
+
self._indent += 1
|
|
1098
|
+
self._emit("return k + 1")
|
|
1099
|
+
self._indent -= 1
|
|
1100
|
+
self._indent -= 1
|
|
1101
|
+
self._emit("return 0")
|
|
1102
|
+
self._indent -= 1
|
|
1103
|
+
self._emit()
|
|
1104
|
+
self._emit("def string_char_at(self, v, i):")
|
|
1105
|
+
self._indent += 1
|
|
1106
|
+
self._emit('"""Return the codepoint at position i (as an int). Out-of-')
|
|
1107
|
+
self._emit('range positions return 0."""')
|
|
1108
|
+
self._emit("i = int(i) if not isinstance(i, int) else i")
|
|
1109
|
+
self._emit("if i < 0 or i >= self.string_max_length():")
|
|
1110
|
+
self._indent += 1
|
|
1111
|
+
self._emit("return 0")
|
|
1112
|
+
self._indent -= 1
|
|
1113
|
+
self._emit("axis = self._string_axis(i)")
|
|
1114
|
+
self._emit("return int(v[self.semantic_dim + axis].item())")
|
|
1115
|
+
self._indent -= 1
|
|
1116
|
+
self._emit()
|
|
1117
|
+
self._emit("def string_to_python(self, v):")
|
|
1118
|
+
self._indent += 1
|
|
1119
|
+
self._emit('"""Decode a String value back to a Python str. Useful for')
|
|
1120
|
+
self._emit('returning string-valued results to the host."""')
|
|
1121
|
+
self._emit("n = self.string_length(v)")
|
|
1122
|
+
self._emit("chars = []")
|
|
1123
|
+
self._emit("for i in range(n):")
|
|
1124
|
+
self._indent += 1
|
|
1125
|
+
self._emit("axis = self._string_axis(i)")
|
|
1126
|
+
self._emit("chars.append(chr(int(v[self.semantic_dim + axis].item())))")
|
|
1127
|
+
self._indent -= 1
|
|
1128
|
+
self._emit('return "".join(chars)')
|
|
1129
|
+
self._indent -= 1
|
|
1130
|
+
self._emit()
|
|
1131
|
+
self._emit("def make_trit(self, t):")
|
|
1132
|
+
self._indent += 1
|
|
1133
|
+
self._emit('"""Three-valued primitive class — aliases make_truth."""')
|
|
1134
|
+
self._emit("return self.make_truth(t)")
|
|
1135
|
+
self._indent -= 1
|
|
1136
|
+
self._emit()
|
|
1137
|
+
self._emit("def defuzzify_trit(self, v, iters=10, beta=2.0):")
|
|
1138
|
+
self._indent += 1
|
|
1139
|
+
self._emit('"""Three-way polarizer toward {-1, 0, +1} — torch version."""')
|
|
1140
|
+
self._emit("x = float(v[self.semantic_dim + self.AXIS_TRUTH].item())")
|
|
1141
|
+
self._emit("b = float(beta)")
|
|
1142
|
+
self._emit("for _ in range(int(iters)):")
|
|
1143
|
+
self._indent += 1
|
|
1144
|
+
self._emit("import math as _math")
|
|
1145
|
+
self._emit("w_neg = _math.exp(-b * (x + 1.0) ** 2)")
|
|
1146
|
+
self._emit("w_zero = _math.exp(-b * x ** 2)")
|
|
1147
|
+
self._emit("w_pos = _math.exp(-b * (x - 1.0) ** 2)")
|
|
1148
|
+
self._emit("s = w_neg + w_zero + w_pos")
|
|
1149
|
+
self._emit("x = (-w_neg + w_pos) / s")
|
|
1150
|
+
self._emit("b *= 2.0")
|
|
1151
|
+
self._indent -= 1
|
|
1152
|
+
self._emit("out = v.clone()")
|
|
1153
|
+
self._emit("out[self.semantic_dim + self.AXIS_TRUTH] = float(x)")
|
|
1154
|
+
self._emit("return out")
|
|
1155
|
+
self._indent -= 1
|
|
1156
|
+
self._emit()
|
|
1157
|
+
|
|
1158
|
+
self._emit("# ---- Logical operators — smooth polynomial form ----")
|
|
1159
|
+
self._emit("#")
|
|
1160
|
+
self._emit("# Same Lagrange-derived polynomials as the numpy backend:")
|
|
1161
|
+
self._emit("# min(a, b) = (a + b + ab - a² - b² + a²b²) / 2")
|
|
1162
|
+
self._emit("# max(a, b) = (a + b - ab + a² + b² - a²b²) / 2")
|
|
1163
|
+
self._emit("# Exact on {-1, 0, +1}², C^∞ everywhere, CUDA via torch ops.")
|
|
1164
|
+
self._emit()
|
|
1165
|
+
self._emit("def _as_truth_vector(self, x):")
|
|
1166
|
+
self._indent += 1
|
|
1167
|
+
self._emit('"""Return x as a tensor. Scalar / bool → make_truth."""')
|
|
1168
|
+
self._emit("if isinstance(x, _torch.Tensor):")
|
|
1169
|
+
self._indent += 1
|
|
1170
|
+
self._emit("return x")
|
|
1171
|
+
self._indent -= 1
|
|
1172
|
+
self._emit("if isinstance(x, bool):")
|
|
1173
|
+
self._indent += 1
|
|
1174
|
+
self._emit("return self.make_truth(1.0 if x else -1.0)")
|
|
1175
|
+
self._indent -= 1
|
|
1176
|
+
self._emit("return self.make_truth(float(x))")
|
|
1177
|
+
self._indent -= 1
|
|
1178
|
+
self._emit()
|
|
1179
|
+
# logical_and / logical_or / logical_not runtime methods
|
|
1180
|
+
# deleted in v0.3 step 4 — operator lowering + stdlib inline
|
|
1181
|
+
# replaces every caller with the inline polynomial form.
|
|
1182
|
+
|
|
1183
|
+
self._emit("# ---- Ordered comparison — pure tensor ops, no branches ----")
|
|
1184
|
+
self._emit()
|
|
1185
|
+
self._emit("def _real_projector(self):")
|
|
1186
|
+
self._indent += 1
|
|
1187
|
+
self._emit('"""Diagonal real-axis projector. Cached tensor on device."""')
|
|
1188
|
+
self._emit("if not hasattr(self, '_real_proj_cache') or self._real_proj_cache is None:")
|
|
1189
|
+
self._indent += 1
|
|
1190
|
+
self._emit("M = _torch.zeros((self.dim, self.dim), dtype=self.dtype, device=self.device)")
|
|
1191
|
+
self._emit("idx = self.semantic_dim + self.AXIS_REAL")
|
|
1192
|
+
self._emit("M[idx, idx] = 1.0")
|
|
1193
|
+
self._emit("self._real_proj_cache = M")
|
|
1194
|
+
self._indent -= 1
|
|
1195
|
+
self._emit("return self._real_proj_cache")
|
|
1196
|
+
self._indent -= 1
|
|
1197
|
+
self._emit()
|
|
1198
|
+
self._emit("def _truth_from_real(self):")
|
|
1199
|
+
self._indent += 1
|
|
1200
|
+
self._emit('"""Matrix moving the real-axis entry to the truth axis."""')
|
|
1201
|
+
self._emit("if not hasattr(self, '_t_from_r_cache') or self._t_from_r_cache is None:")
|
|
1202
|
+
self._indent += 1
|
|
1203
|
+
self._emit("M = _torch.zeros((self.dim, self.dim), dtype=self.dtype, device=self.device)")
|
|
1204
|
+
self._emit("M[self.semantic_dim + self.AXIS_TRUTH,")
|
|
1205
|
+
self._indent += 1
|
|
1206
|
+
self._emit("self.semantic_dim + self.AXIS_REAL] = 1.0")
|
|
1207
|
+
self._indent -= 1
|
|
1208
|
+
self._emit("self._t_from_r_cache = M")
|
|
1209
|
+
self._indent -= 1
|
|
1210
|
+
self._emit("return self._t_from_r_cache")
|
|
1211
|
+
self._indent -= 1
|
|
1212
|
+
self._emit()
|
|
1213
|
+
self._emit("CMP_SLOPE = 100.0")
|
|
1214
|
+
self._emit()
|
|
1215
|
+
self._emit("def gt(self, a, b):")
|
|
1216
|
+
self._indent += 1
|
|
1217
|
+
self._emit('"""a > b — differentiable tanh on real-axis difference."""')
|
|
1218
|
+
self._emit("av = self._as_complex_vector(a)")
|
|
1219
|
+
self._emit("bv = self._as_complex_vector(b)")
|
|
1220
|
+
self._emit("diff_r = self._real_projector() @ (av - bv)")
|
|
1221
|
+
self._emit("signed = _torch.tanh(self.CMP_SLOPE * diff_r)")
|
|
1222
|
+
self._emit("return self._truth_from_real() @ signed")
|
|
1223
|
+
self._indent -= 1
|
|
1224
|
+
self._emit()
|
|
1225
|
+
# lt / ge / le runtime methods deleted in v0.3 step 4.
|
|
1226
|
+
|
|
1227
|
+
self._emit("# ---- Equality — cosine similarity on tensors ----")
|
|
1228
|
+
self._emit()
|
|
1229
|
+
self._emit("def eq(self, a, b):")
|
|
1230
|
+
self._indent += 1
|
|
1231
|
+
self._emit('"""a == b — cosine similarity, eps-guarded divide, no branch."""')
|
|
1232
|
+
self._emit("av = self._as_any_vector(a)")
|
|
1233
|
+
self._emit("bv = self._as_any_vector(b)")
|
|
1234
|
+
self._emit("na = _torch.sqrt((av * av).sum())")
|
|
1235
|
+
self._emit("nb = _torch.sqrt((bv * bv).sum())")
|
|
1236
|
+
self._emit("cos = (av * bv).sum() / (na * nb + _torch.finfo(self.dtype).tiny)")
|
|
1237
|
+
self._emit("return self.make_truth(float(cos.item()))")
|
|
1238
|
+
self._indent -= 1
|
|
1239
|
+
self._emit()
|
|
1240
|
+
# neq runtime method deleted in v0.3 step 4.
|
|
1241
|
+
|
|
1242
|
+
self._emit("def _as_any_vector(self, x):")
|
|
1243
|
+
self._indent += 1
|
|
1244
|
+
self._emit('"""Coerce any runtime value to a d-dim tensor for comparison."""')
|
|
1245
|
+
self._emit("if isinstance(x, _torch.Tensor):")
|
|
1246
|
+
self._indent += 1
|
|
1247
|
+
self._emit("return x")
|
|
1248
|
+
self._indent -= 1
|
|
1249
|
+
self._emit("if isinstance(x, bool):")
|
|
1250
|
+
self._indent += 1
|
|
1251
|
+
self._emit("return self.make_truth(1.0 if x else -1.0)")
|
|
1252
|
+
self._indent -= 1
|
|
1253
|
+
self._emit("if isinstance(x, (int, float)):")
|
|
1254
|
+
self._indent += 1
|
|
1255
|
+
self._emit("return self.make_real(float(x))")
|
|
1256
|
+
self._indent -= 1
|
|
1257
|
+
self._emit("if isinstance(x, str):")
|
|
1258
|
+
self._indent += 1
|
|
1259
|
+
self._emit("return self.embed(x)")
|
|
1260
|
+
self._indent -= 1
|
|
1261
|
+
self._emit("raise TypeError(f'cannot coerce {type(x).__name__} to a tensor for comparison')")
|
|
1262
|
+
self._indent -= 1
|
|
1263
|
+
self._emit()
|
|
1264
|
+
self._emit("# ---- Defuzzification — torch version ----")
|
|
1265
|
+
self._emit()
|
|
1266
|
+
self._emit("def _truth_projector(self):")
|
|
1267
|
+
self._indent += 1
|
|
1268
|
+
self._emit('"""Diagonal dim×dim projector onto truth axis. Cached tensor."""')
|
|
1269
|
+
self._emit("if not hasattr(self, '_truth_proj_cache') or self._truth_proj_cache is None:")
|
|
1270
|
+
self._indent += 1
|
|
1271
|
+
self._emit("M = _torch.zeros((self.dim, self.dim), dtype=self.dtype, device=self.device)")
|
|
1272
|
+
self._emit("idx = self.semantic_dim + self.AXIS_TRUTH")
|
|
1273
|
+
self._emit("M[idx, idx] = 1.0")
|
|
1274
|
+
self._emit("self._truth_proj_cache = M")
|
|
1275
|
+
self._indent -= 1
|
|
1276
|
+
self._emit("return self._truth_proj_cache")
|
|
1277
|
+
self._indent -= 1
|
|
1278
|
+
self._emit()
|
|
1279
|
+
# defuzzify runtime method deleted in v0.3 step 4. The
|
|
1280
|
+
# `defuzzy(x)` source form is expanded inline by codegen.py's
|
|
1281
|
+
# `_defuzzy_expr_src` into ten nested eq calls (inherited
|
|
1282
|
+
# unchanged here).
|
|
1283
|
+
self._emit()
|
|
1284
|
+
self._emit("def make_random_rotation(self, angle, n_planes=1, seed=None):")
|
|
1285
|
+
self._indent += 1
|
|
1286
|
+
self._emit('"""Block-diagonal Haar rotation, scaled by fractional power.')
|
|
1287
|
+
self._emit('')
|
|
1288
|
+
self._emit("Seeded by numpy's RandomState for deterministic Haar-uniformity;")
|
|
1289
|
+
self._emit("the result is converted to a torch tensor on self.device. Used")
|
|
1290
|
+
self._emit("by eigenrotation loops.")
|
|
1291
|
+
self._emit('"""')
|
|
1292
|
+
self._emit("import numpy as _np_bridge")
|
|
1293
|
+
self._emit("rng = _np_bridge.random.RandomState(seed if seed is not None else self.seed)")
|
|
1294
|
+
self._emit("A = rng.randn(self.semantic_dim, self.semantic_dim)")
|
|
1295
|
+
self._emit("Q_sem_np, _ = _np_bridge.linalg.qr(A)")
|
|
1296
|
+
self._emit("w, V = _np_bridge.linalg.eig(Q_sem_np)")
|
|
1297
|
+
self._emit("phases = _np_bridge.angle(w) * (angle / _np_bridge.pi)")
|
|
1298
|
+
self._emit("R_sem_np = _np_bridge.real((V * _np_bridge.exp(1j * phases)) @ _np_bridge.linalg.inv(V))")
|
|
1299
|
+
self._emit("R_sem = _torch.as_tensor(R_sem_np, dtype=self.dtype, device=self.device)")
|
|
1300
|
+
self._emit("R = _torch.eye(self.dim, dtype=self.dtype, device=self.device)")
|
|
1301
|
+
self._emit("R[:self.semantic_dim, :self.semantic_dim] = R_sem")
|
|
1302
|
+
self._emit("return R")
|
|
1303
|
+
self._indent -= 1
|
|
1304
|
+
self._emit()
|
|
1305
|
+
self._emit("def compile_prototypes(self, prototype_vectors, frame_seed=None):")
|
|
1306
|
+
self._indent += 1
|
|
1307
|
+
self._emit("return dict(prototype_vectors)")
|
|
1308
|
+
self._indent -= 1
|
|
1309
|
+
self._emit()
|
|
1310
|
+
self._emit("def _step(self, state, R, target, halted, k, threshold, eps=1e-12):")
|
|
1311
|
+
self._indent += 1
|
|
1312
|
+
self._emit('"""RNN cell: one branchless eigenrotation step (torch tensor ops)."""')
|
|
1313
|
+
self._emit("cand = R @ state")
|
|
1314
|
+
self._emit("cand = cand / (_torch.linalg.norm(cand) + eps)")
|
|
1315
|
+
self._emit("sim = _torch.dot(cand, target) / (_torch.linalg.norm(target) + eps)")
|
|
1316
|
+
self._emit("halt = 1.0 / (1.0 + _torch.exp(-k * (sim - threshold)))")
|
|
1317
|
+
self._emit("one = _torch.tensor(1.0, dtype=self.dtype, device=self.device)")
|
|
1318
|
+
self._emit("halted = _torch.minimum(halted + halt, one)")
|
|
1319
|
+
self._emit("state = (1.0 - halted) * cand + halted * state")
|
|
1320
|
+
self._emit("return state, halted")
|
|
1321
|
+
self._indent -= 1
|
|
1322
|
+
self._emit()
|
|
1323
|
+
self._emit("def loop(self, initial_state, rotation, compiled_prototypes,")
|
|
1324
|
+
self._indent += 1
|
|
1325
|
+
self._emit("target_name=None, threshold=0.5, max_iters=50, k=20.0, frame_seed=None):")
|
|
1326
|
+
self._emit('"""Branchless RNN-style eigenrotation loop (torch backend).')
|
|
1327
|
+
self._emit('')
|
|
1328
|
+
self._emit("Same semantics as the numpy backend. T-step unroll, soft halt via")
|
|
1329
|
+
self._emit("sigmoid, output gating via AXIS_LOOP_DONE. Autograd-friendly:")
|
|
1330
|
+
self._emit("every op is differentiable with respect to state, target, threshold.")
|
|
1331
|
+
self._emit('"""')
|
|
1332
|
+
self._emit("state = initial_state.clone()")
|
|
1333
|
+
self._emit("halted = _torch.tensor(0.0, dtype=self.dtype, device=self.device)")
|
|
1334
|
+
self._emit("iters_active = _torch.tensor(0.0, dtype=self.dtype, device=self.device)")
|
|
1335
|
+
self._emit("if target_name is not None:")
|
|
1336
|
+
self._indent += 1
|
|
1337
|
+
self._emit("target = compiled_prototypes[target_name]")
|
|
1338
|
+
self._indent -= 1
|
|
1339
|
+
self._emit("else:")
|
|
1340
|
+
self._indent += 1
|
|
1341
|
+
self._emit("target = next(iter(compiled_prototypes.values()))")
|
|
1342
|
+
self._indent -= 1
|
|
1343
|
+
self._emit("for _t in range(max_iters):")
|
|
1344
|
+
self._indent += 1
|
|
1345
|
+
self._emit("iters_active = iters_active + (1.0 - halted)")
|
|
1346
|
+
self._emit("state, halted = self._step(state, rotation, target, halted, k, threshold)")
|
|
1347
|
+
self._indent -= 1
|
|
1348
|
+
self._emit("# Output gating: scale value axes by halted; mark AXIS_LOOP_DONE.")
|
|
1349
|
+
self._emit("gated = state * halted")
|
|
1350
|
+
self._emit("gated[self.semantic_dim + self.AXIS_LOOP_DONE] = halted")
|
|
1351
|
+
self._emit("return target_name, gated, iters_active")
|
|
1352
|
+
self._indent -= 1
|
|
1353
|
+
self._indent -= 1
|
|
1354
|
+
self._emit()
|
|
1355
|
+
self._emit()
|
|
1356
|
+
self._emit(
|
|
1357
|
+
f"_VSA = _TorchVSA("
|
|
1358
|
+
f"semantic_dim={self._semantic_dim}, "
|
|
1359
|
+
f"synthetic_dim={self._synthetic_dim}, "
|
|
1360
|
+
f"seed={self.runtime_seed}, "
|
|
1361
|
+
f"llm_model={self._llm_model!r})"
|
|
1362
|
+
)
|
|
1363
|
+
if self._prefetch_strings:
|
|
1364
|
+
self._emit(f"_VSA.embed_batch({self._prefetch_strings!r})")
|
|
1365
|
+
# Compile-time SutraDB population (queue item 2). Every embedded
|
|
1366
|
+
# string in the program is now in the SutraDB codebook and
|
|
1367
|
+
# decodable via _VSA.nearest_string. Strings declared but not
|
|
1368
|
+
# used in expressions are still in the prefetch list and so
|
|
1369
|
+
# still get inserted; they're available for decode even though
|
|
1370
|
+
# no expression in the program references them.
|
|
1371
|
+
self._emit("_VSA.populate_sutradb()")
|
|
1372
|
+
# Compile-time rotation pre-warm (queue item 3). Conservatively
|
|
1373
|
+
# pre-warms a rotation matrix for every codebook entry so the
|
|
1374
|
+
# runtime never pays the QR cost on the hot path. Over-warms
|
|
1375
|
+
# for fillers that aren't ever used as roles, but the cost is
|
|
1376
|
+
# one-time and proportional to the codebook size which is
|
|
1377
|
+
# small for typical programs. A targeted "scan for bind() role
|
|
1378
|
+
# args only" pass would be a future optimization.
|
|
1379
|
+
self._emit("_VSA.prewarm_rotation_cache()")
|
|
1380
|
+
self._emit()
|
|
1381
|
+
self._emit()
|
|
1382
|
+
self._emit("def _argmax_cosine(query, candidates):")
|
|
1383
|
+
self._indent += 1
|
|
1384
|
+
self._emit('"""Vectorized cosine argmax on torch tensors.')
|
|
1385
|
+
self._emit('')
|
|
1386
|
+
self._emit("Stacks candidates into (N, d), computes all N cosines as one")
|
|
1387
|
+
self._emit("matmul against the query, returns the candidate at the argmax.")
|
|
1388
|
+
self._emit("This is the GPU-shaped form: O(1) big kernel, not O(N) small ones.")
|
|
1389
|
+
self._emit("")
|
|
1390
|
+
self._emit("Note: SutraDB integration (queue item 2) does NOT route through")
|
|
1391
|
+
self._emit("here — see _VSA.nearest_string for the embedded-DB decode path.")
|
|
1392
|
+
self._emit("argmax_cosine takes a runtime candidate-vector list; SutraDB is")
|
|
1393
|
+
self._emit("the compile-time-populated string-to-embedding store.")
|
|
1394
|
+
self._emit('"""')
|
|
1395
|
+
self._emit("if not candidates:")
|
|
1396
|
+
self._indent += 1
|
|
1397
|
+
self._emit("return None")
|
|
1398
|
+
self._indent -= 1
|
|
1399
|
+
self._emit("M = _torch.stack([")
|
|
1400
|
+
self._indent += 1
|
|
1401
|
+
self._emit("_torch.as_tensor(c, dtype=_DTYPE, device=_DEVICE)")
|
|
1402
|
+
self._emit("for c in candidates")
|
|
1403
|
+
self._indent -= 1
|
|
1404
|
+
self._emit("])")
|
|
1405
|
+
self._emit("q = _torch.as_tensor(query, dtype=_DTYPE, device=_DEVICE)")
|
|
1406
|
+
self._emit("row_norms = _torch.linalg.norm(M, dim=1)")
|
|
1407
|
+
self._emit("q_norm = _torch.linalg.norm(q)")
|
|
1408
|
+
self._emit("if float(q_norm) == 0:")
|
|
1409
|
+
self._indent += 1
|
|
1410
|
+
self._emit("return candidates[0]")
|
|
1411
|
+
self._indent -= 1
|
|
1412
|
+
self._emit("safe_rn = _torch.where(row_norms > 0, row_norms, _torch.ones_like(row_norms))")
|
|
1413
|
+
self._emit("scores = (M @ q) / (safe_rn * q_norm)")
|
|
1414
|
+
self._emit("neg_inf = _torch.full_like(scores, float('-inf'))")
|
|
1415
|
+
self._emit("scores = _torch.where(row_norms > 0, scores, neg_inf)")
|
|
1416
|
+
self._emit("return candidates[int(_torch.argmax(scores).item())]")
|
|
1417
|
+
self._indent -= 1
|
|
1418
|
+
self._emit()
|
|
1419
|
+
self._emit()
|
|
1420
|
+
self._emit_select_helper()
|
|
1421
|
+
self._emit()
|
|
1422
|
+
self._emit("def _vector_map_lookup(pairs, key):")
|
|
1423
|
+
self._indent += 1
|
|
1424
|
+
self._emit('"""Identity-first lookup for vector-keyed maps, cosine fallback."""')
|
|
1425
|
+
self._emit("for k, v in pairs:")
|
|
1426
|
+
self._indent += 1
|
|
1427
|
+
self._emit("if k is key:")
|
|
1428
|
+
self._indent += 1
|
|
1429
|
+
self._emit("return v")
|
|
1430
|
+
self._indent -= 1
|
|
1431
|
+
self._indent -= 1
|
|
1432
|
+
self._emit("if not pairs:")
|
|
1433
|
+
self._indent += 1
|
|
1434
|
+
self._emit("return None")
|
|
1435
|
+
self._indent -= 1
|
|
1436
|
+
self._emit("keys = _torch.stack([")
|
|
1437
|
+
self._indent += 1
|
|
1438
|
+
self._emit("_torch.as_tensor(k, dtype=_DTYPE, device=_DEVICE)")
|
|
1439
|
+
self._emit("for k, _ in pairs")
|
|
1440
|
+
self._indent -= 1
|
|
1441
|
+
self._emit("])")
|
|
1442
|
+
self._emit("q = _torch.as_tensor(key, dtype=_DTYPE, device=_DEVICE)")
|
|
1443
|
+
self._emit("row_norms = _torch.linalg.norm(keys, dim=1)")
|
|
1444
|
+
self._emit("q_norm = _torch.linalg.norm(q)")
|
|
1445
|
+
self._emit("if float(q_norm) == 0:")
|
|
1446
|
+
self._indent += 1
|
|
1447
|
+
self._emit("return pairs[0][1]")
|
|
1448
|
+
self._indent -= 1
|
|
1449
|
+
self._emit("safe_rn = _torch.where(row_norms > 0, row_norms, _torch.ones_like(row_norms))")
|
|
1450
|
+
self._emit("scores = (keys @ q) / (safe_rn * q_norm)")
|
|
1451
|
+
self._emit("neg_inf = _torch.full_like(scores, float('-inf'))")
|
|
1452
|
+
self._emit("scores = _torch.where(row_norms > 0, scores, neg_inf)")
|
|
1453
|
+
self._emit("return pairs[int(_torch.argmax(scores).item())][1]")
|
|
1454
|
+
self._indent -= 1
|
|
1455
|
+
|
|
1456
|
+
|
|
1457
|
+
def translate_module(module: ast.Module, **kwargs) -> str:
    """Lower a parsed Sutra module to standalone PyTorch Python source.

    The module goes through the same front-end passes the numpy (CPU)
    backend uses -- stdlib inlining, algebraic simplification, and
    collection of the basis-vector strings for the batched Ollama
    pre-fetch -- so none of that infrastructure is duplicated here.

    Args:
        module: Parsed Sutra module. The return values of the inlining
            and simplify passes are discarded, so they are relied on to
            rewrite ``module`` in place.
        **kwargs: Forwarded untouched to the ``PyTorchCodegen``
            constructor.

    Returns:
        The generated Python source, as returned by
        ``PyTorchCodegen.translate``.
    """
    # Imported lazily inside the function, in this order.
    from .simplify import simplify_module, collect_basis_vector_strings
    from .inliner import inline_stdlib_calls

    # Pass order matters: stdlib calls are inlined before simplification,
    # and the strings are gathered from the already-rewritten tree.
    inline_stdlib_calls(module)
    simplify_module(module)
    prefetch_strings = collect_basis_vector_strings(module)

    # Hand the collected strings to the generator before translating.
    generator = PyTorchCodegen(**kwargs)
    generator._prefetch_strings = prefetch_strings
    return generator.translate(module)
|