sutra-dev 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sutra_compiler/__init__.py +49 -0
- sutra_compiler/__main__.py +514 -0
- sutra_compiler/ast_nodes.py +553 -0
- sutra_compiler/codegen.py +1811 -0
- sutra_compiler/codegen_base.py +2436 -0
- sutra_compiler/codegen_pytorch.py +1472 -0
- sutra_compiler/diagnostics.py +145 -0
- sutra_compiler/inliner.py +581 -0
- sutra_compiler/lexer.py +821 -0
- sutra_compiler/parser.py +2112 -0
- sutra_compiler/review.py +322 -0
- sutra_compiler/simplify.py +1046 -0
- sutra_compiler/simplify_egglog.py +674 -0
- sutra_compiler/stdlib/axons.su +53 -0
- sutra_compiler/stdlib/embed.su +48 -0
- sutra_compiler/stdlib/javascript_object.su +18 -0
- sutra_compiler/stdlib/logic.su +202 -0
- sutra_compiler/stdlib/math.su +12 -0
- sutra_compiler/stdlib/memory.su +82 -0
- sutra_compiler/stdlib/numbers.su +99 -0
- sutra_compiler/stdlib/rotation.su +83 -0
- sutra_compiler/stdlib/similarity.su +97 -0
- sutra_compiler/stdlib/strings.su +56 -0
- sutra_compiler/stdlib/tensor.su +82 -0
- sutra_compiler/stdlib/vectors.su +119 -0
- sutra_compiler/stdlib_loader.py +219 -0
- sutra_compiler/sutradb_embedded.py +273 -0
- sutra_compiler/trace.py +135 -0
- sutra_compiler/validator.py +552 -0
- sutra_compiler/workspace.py +655 -0
- sutra_dev-0.2.0.dist-info/METADATA +80 -0
- sutra_dev-0.2.0.dist-info/RECORD +36 -0
- sutra_dev-0.2.0.dist-info/WHEEL +5 -0
- sutra_dev-0.2.0.dist-info/entry_points.txt +2 -0
- sutra_dev-0.2.0.dist-info/licenses/LICENSE +201 -0
- sutra_dev-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Sutra language compiler / validator.

This package implements the Sutra SDK toolchain for `.su` source files.
It began (v0.1) as a lexer, parser, and syntactic validator; the 0.2
release adds a PyTorch code-generation backend and the supporting
passes and standard library listed below.

Scope (v0.1):
- Full tokenization of Sutra source (all comment forms, string
  interpolation, numeric literals, identifiers, operators).
- Recursive-descent parser that recognizes the declaration and
  statement grammar described in planning/sutra-spec/. The
  historical predecessor — sutra-syntax-decisions.md, the
  rolling decisions log that preceded the formal spec — lives
  under planning/sutra-spec-deprecated/ as read-only reference.
- Structural validation: balanced brackets, semicolons where the
  grammar requires them, well-formed declarations and control flow.
- A small set of rule checks that the syntax-decisions doc makes
  explicit (e.g. `var TYPE x` is forbidden, `if (...)` requires
  parentheses, a bare identifier cannot be used as a condition).

Out of scope for v0.1 (kept for history; see "Added in 0.2"):
- Type checking
- Name resolution across files
- Code generation / runtime lowering
- Cross-file workspace analysis

Added in 0.2:
- Code generation to self-contained PyTorch Python
  (`codegen_pytorch.translate_module`), driven by the `sutrac`
  CLI's `--emit`, `--run`, and `--run-viz` modes.
- Inlining and simplification passes (the `inliner` and `simplify*`
  modules); `sutrac --review` prints the inlined and simplified AST
  step by step.
- A bundled standard library of `.su` modules under `stdlib/`.
- A cross-file class-name casing check (`sutrac --consistency`).

Only the lexer, parser, diagnostics, and validator are re-exported at
package level (see `__all__`); import the other modules directly, e.g.
`from sutra_compiler.codegen_pytorch import translate_module`.

The compiler is intentionally liberal where the spec is still open
(anonymous functions, pipe operator, etc.) — it accepts the documented
forms and flags the clearly-forbidden ones.
"""

__version__ = "0.2.0"
|
|
32
|
+
|
|
33
|
+
from .diagnostics import Diagnostic, DiagnosticLevel, DiagnosticBag
|
|
34
|
+
from .lexer import Lexer, Token, TokenKind
|
|
35
|
+
from .parser import Parser
|
|
36
|
+
from .validator import validate_source, validate_file
|
|
37
|
+
|
|
38
|
+
# Package-level public API, grouped by the module each name comes from.
__all__ = [
    # diagnostics
    "Diagnostic", "DiagnosticLevel", "DiagnosticBag",
    # lexer
    "Lexer", "Token", "TokenKind",
    # parser
    "Parser",
    # validator
    "validate_source", "validate_file",
    # package metadata
    "__version__",
]
|
|
@@ -0,0 +1,514 @@
|
|
|
1
|
+
"""Command-line entry point for the Sutra compiler/validator.

Usage:

    python -m sutra_compiler FILE [FILE ...]
    python -m sutra_compiler --json FILE
    python -m sutra_compiler --summary DIR_OR_FILE [...]
    python -m sutra_compiler --consistency DIR_OR_FILE [...]
    python -m sutra_compiler --emit FILE
    python -m sutra_compiler --run FILE
    python -m sutra_compiler --run-viz FILE
    python -m sutra_compiler --review FILE

By default the CLI lexes, parses, and validates each `.su` file (directories
are walked recursively) and prints any diagnostics in
`file:line:col: level: message` form — the same shape every major compiler
and every editor knows how to parse.

`--emit`, `--run`, and `--run-viz` each take exactly one file and compile it
to self-contained PyTorch Python: `--emit` prints the generated source,
`--run` executes it, and `--run-viz` executes it with tracing and writes a
3D visualization HTML plus a trace JSON next to the source. `--review`
walks a single file through each compiler stage for debugging.

Exit codes:
    0  no errors reported
    1  errors reported (with `--consistency`: casing drift found)
    2  usage error (e.g. a single-file mode given several files, or
       arguments argparse rejects)
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import argparse
|
|
19
|
+
import json
|
|
20
|
+
import os
|
|
21
|
+
import sys
|
|
22
|
+
from typing import List
|
|
23
|
+
|
|
24
|
+
from . import __version__
|
|
25
|
+
from . import ast_nodes as ast
|
|
26
|
+
from .codegen_pytorch import translate_module as translate_pytorch
|
|
27
|
+
from .diagnostics import Diagnostic, DiagnosticLevel
|
|
28
|
+
from .lexer import Lexer
|
|
29
|
+
from .parser import Parser
|
|
30
|
+
from .validator import validate_file, _Walker, _check_pipe_forward
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _iter_akasha_files(paths: List[str]) -> List[str]:
|
|
34
|
+
"""Expand a list of files/directories into a flat list of `.su`
|
|
35
|
+
files. Non-existent paths are left to the caller to report."""
|
|
36
|
+
out: List[str] = []
|
|
37
|
+
for p in paths:
|
|
38
|
+
if os.path.isdir(p):
|
|
39
|
+
for root, _, files in os.walk(p):
|
|
40
|
+
for f in sorted(files):
|
|
41
|
+
if f.endswith(".su"):
|
|
42
|
+
out.append(os.path.join(root, f))
|
|
43
|
+
else:
|
|
44
|
+
out.append(p)
|
|
45
|
+
return out
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _diag_to_dict(d: Diagnostic) -> dict:
|
|
49
|
+
return {
|
|
50
|
+
"file": d.file,
|
|
51
|
+
"line": d.span.start.line,
|
|
52
|
+
"column": d.span.start.column,
|
|
53
|
+
"end_line": d.span.end.line,
|
|
54
|
+
"end_column": d.span.end.column,
|
|
55
|
+
"level": d.level.value,
|
|
56
|
+
"code": d.code,
|
|
57
|
+
"message": d.message,
|
|
58
|
+
"hint": d.hint,
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _run_text(paths: List[str], *, summary: bool) -> int:
    """Validate `paths` and report the results as human-readable text.

    Without `summary`, every diagnostic is printed via `Diagnostic.format()`
    followed by a one-line tally. With `summary`, a per-file table of
    error/warning counts plus a totals row is printed instead.

    A missing file is reported on stderr and counts as one error; it also
    gets a row in the summary table so the rows add up to the totals row.

    Returns 1 if any error was reported, else 0.
    """
    files = _iter_akasha_files(paths)
    total_errors = 0
    total_warnings = 0
    per_file = []  # (path, n_errors, n_warnings) rows for the summary table
    for f in files:
        if not os.path.exists(f):
            print(f"{f}: error: file not found", file=sys.stderr)
            total_errors += 1
            # Keep the table consistent with the totals row.
            per_file.append((f, 1, 0))
            continue
        bag = validate_file(f)
        n_err = len(bag.errors)
        n_warn = len(bag.warnings)
        total_errors += n_err
        total_warnings += n_warn
        per_file.append((f, n_err, n_warn))
        if not summary:
            for d in bag:
                print(d.format())

    if summary:
        # Wide enough for every path and for the "file"/"total" labels, so
        # the header and totals rows stay aligned with the data rows.
        width = max([len("total"), *(len(f) for f, _, _ in per_file)])
        print(f"{'file'.ljust(width)} errors warnings")
        print("-" * (width + 20))
        for f, e, w in per_file:
            print(f"{f.ljust(width)} {e:6d} {w:8d}")
        print("-" * (width + 20))
        print(f"{'total'.ljust(width)} {total_errors:6d} {total_warnings:8d}")
    else:
        if total_errors == 0 and total_warnings == 0:
            print(f"ok: {len(files)} file(s) validated, 0 diagnostics")
        else:
            print(
                f"done: {len(files)} file(s) validated, "
                f"{total_errors} error(s), {total_warnings} warning(s)"
            )
    return 1 if total_errors else 0
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _run_json(paths: List[str]) -> int:
    """Validate `paths` and print a single JSON document to stdout.

    Shape: {"files": [{"file": ..., "diagnostics": [...]}, ...],
    "version": <package version>}, indented by 2 and newline-terminated.
    A missing file is represented by one synthetic SUT9999 error at 1:1.

    Returns 1 if any error was reported, else 0.
    """
    results = []
    error_count = 0
    for path in _iter_akasha_files(paths):
        if os.path.exists(path):
            bag = validate_file(path)
            diags = [_diag_to_dict(d) for d in bag]
            error_count += len(bag.errors)
        else:
            diags = [{
                "file": path,
                "line": 1,
                "column": 1,
                "end_line": 1,
                "end_column": 1,
                "level": "error",
                "code": "SUT9999",
                "message": "file not found",
                "hint": None,
            }]
            error_count += 1
        results.append({"file": path, "diagnostics": diags})
    json.dump({"files": results, "version": __version__}, sys.stdout, indent=2)
    sys.stdout.write("\n")
    return 1 if error_count else 0
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _run_consistency(paths: List[str]) -> int:
    """Cross-file class-name casing check.

    Lexes and parses every `.su` file in `paths` (directories are walked),
    visits each top-level item with the validator's `_Walker`, and groups
    the class-name usages it records case-insensitively. Every name that
    appears in two or more distinct casings across the file set (drift like
    `animal` vs `Animal`) is printed with the files each casing appears in.

    Returns 1 if any casing drift was found or any input file was missing,
    else 0. Missing files are reported on stderr and count as failures,
    as they do in the text and JSON modes.
    """
    files = _iter_akasha_files(paths)
    # name_lower -> { casing -> set of files }
    usages: dict = {}
    missing = 0
    for f in files:
        if not os.path.exists(f):
            print(f"{f}: error: file not found", file=sys.stderr)
            missing += 1
            continue
        with open(f, encoding="utf-8") as fp:
            src = fp.read()
        lexer = Lexer(src, file=f)
        tokens = lexer.tokenize()
        parser = Parser(tokens, file=f, diagnostics=lexer.diagnostics)
        module = parser.parse_module()
        walker = _Walker(lexer.diagnostics)
        # Walk just the declarations to collect type-name usages.
        for item in module.items:
            walker.visit(item)
        for name in walker._class_name_usages:
            entry = usages.setdefault(name.lower(), {})
            entry.setdefault(name, set()).add(f)

    drift_count = 0
    print("Cross-file class-name casing check")
    print("=" * 60)
    for lower_name, casings in sorted(usages.items()):
        if len(casings) < 2:
            continue
        drift_count += 1
        print(f"\n  DRIFT: {lower_name} appears in {len(casings)} casings")
        for casing in sorted(casings.keys()):
            file_list = sorted(casings[casing])
            print(f"    `{casing}`")
            for f in file_list:
                print(f"      {f}")
    if drift_count == 0:
        print("\n  no cross-file casing drift detected")
    else:
        print(f"\n{drift_count} class name(s) with casing drift across the file set")
    # A missing input is a failure even when no drift was found.
    return 1 if (drift_count or missing) else 0
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _read_atman_loop_T(source_path: str) -> int | None:
|
|
182
|
+
"""Walk up from the .su source file looking for an atman.toml that
|
|
183
|
+
declares `[project.compile] loop_max_iterations = N`. Returns N if
|
|
184
|
+
found, else None.
|
|
185
|
+
"""
|
|
186
|
+
try:
|
|
187
|
+
import tomllib # py3.11+
|
|
188
|
+
except ImportError:
|
|
189
|
+
try:
|
|
190
|
+
import tomli as tomllib # type: ignore
|
|
191
|
+
except ImportError:
|
|
192
|
+
return None
|
|
193
|
+
cur = os.path.dirname(os.path.abspath(source_path))
|
|
194
|
+
while True:
|
|
195
|
+
candidate = os.path.join(cur, "atman.toml")
|
|
196
|
+
if os.path.isfile(candidate):
|
|
197
|
+
try:
|
|
198
|
+
with open(candidate, "rb") as fp:
|
|
199
|
+
data = tomllib.load(fp)
|
|
200
|
+
except Exception:
|
|
201
|
+
return None
|
|
202
|
+
v = (data.get("project", {})
|
|
203
|
+
.get("compile", {})
|
|
204
|
+
.get("loop_max_iterations"))
|
|
205
|
+
if isinstance(v, int) and v > 0:
|
|
206
|
+
return v
|
|
207
|
+
return None
|
|
208
|
+
parent = os.path.dirname(cur)
|
|
209
|
+
if parent == cur:
|
|
210
|
+
return None
|
|
211
|
+
cur = parent
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _compile_to_python(path: str, *, runtime_dim: int,
                       runtime_seed: int,
                       loop_T: int | None = None) -> str | None:
    """Validate, parse, and lower one `.su` file to PyTorch Python source.

    Returns the generated source text, or None when the file does not exist
    or validation reports at least one error. In either failure case the
    reason has already been written to stderr (when there are errors, every
    diagnostic in the bag is printed, warnings included).

    `loop_T` resolution: an explicit argument wins; otherwise the nearest
    `atman.toml`'s `[project.compile] loop_max_iterations` is used;
    otherwise 50.
    """
    if not os.path.exists(path):
        print(f"{path}: error: file not found", file=sys.stderr)
        return None

    diagnostics = validate_file(path)
    if diagnostics.errors:
        for diag in diagnostics:
            print(diag.format(), file=sys.stderr)
        return None

    # Validation passed; re-read and re-parse to get the AST for codegen.
    with open(path, encoding="utf-8") as handle:
        source_text = handle.read()
    lex = Lexer(source_text, file=path)
    module = Parser(
        lex.tokenize(), file=path, diagnostics=lex.diagnostics,
    ).parse_module()

    if loop_T is not None:
        effective_T = loop_T
    else:
        effective_T = _read_atman_loop_T(path) or 50

    return translate_pytorch(
        module,
        runtime_dim=runtime_dim,
        runtime_seed=runtime_seed,
        loop_max_iterations=effective_T,
    )
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _run_execute(path: str, *, runtime_dim: int, runtime_seed: int,
                 loop_T: int | None = None) -> int:
    """Compile a .su file with the PyTorch codegen and execute the result.

    The generated source is exec'd as a fresh module named `_sutra_run`.
    Anything the module prints at top level goes straight to stdout; if it
    defines a callable `main`, that is called and its return value is
    printed unless it is None. Requires `torch` to be importable at runtime.

    Returns 1 if compilation failed, else 0.
    """
    import types

    generated = _compile_to_python(
        path, runtime_dim=runtime_dim, runtime_seed=runtime_seed,
        loop_T=loop_T,
    )
    if generated is None:
        return 1

    module = types.ModuleType("_sutra_run")
    module.__file__ = f"<generated from {path}>"
    code = compile(generated, module.__file__, "exec")
    exec(code, module.__dict__)

    entry = getattr(module, "main", None)
    if callable(entry):
        outcome = entry()
        if outcome is not None:
            print(outcome)
    return 0
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
# Python source spliced into the generated module right after its
# `_VSA = _TorchVSA(` line. It wraps the runtime's embed/bind/unbind/bundle
# so every call is recorded on `_tracer`, which must already be present in
# the module namespace when the module executes.
_TRACING_SHIM = '''
# ── Tracing shim (injected by --run-viz) ──
_orig_embed = _VSA.embed
_orig_bind = _VSA.bind
_orig_unbind = _VSA.unbind
_orig_bundle = _VSA.bundle

def _traced_embed(name):
    v = _orig_embed(name)
    _tracer.record_vector(name, v, "basis")
    return v

def _traced_bind(a, b):
    result = _orig_bind(a, b)
    _tracer.record_op("bind", [a, b], result)
    return result

def _traced_unbind(role, bound):
    result = _orig_unbind(role, bound)
    _tracer.record_op("unbind", [role, bound], result)
    return result

def _traced_bundle(*vectors):
    result = _orig_bundle(*vectors)
    _tracer.record_op("bundle", list(vectors), result)
    return result

_VSA.embed = _traced_embed
_VSA.bind = _traced_bind
_VSA.unbind = _traced_unbind
_VSA.bundle = _traced_bundle
# ── End tracing shim ──
'''


def _inject_tracing_shim(py_src: str) -> str:
    """Return `py_src` with `_TRACING_SHIM` inserted after the `_VSA = _TorchVSA(` line.

    If no such line exists, a warning is printed to stderr and the shim is
    appended at the end instead; VSA calls made before that point (e.g.
    during module-level init) are then not traced.
    """
    lines = py_src.split('\n')
    inject_idx = next(
        (i + 1 for i, line in enumerate(lines)
         if line.strip().startswith('_VSA = _TorchVSA(')),
        None,
    )
    if inject_idx is None:
        print("warning: could not find _VSA init line for tracing", file=sys.stderr)
        inject_idx = len(lines)
    lines.insert(inject_idx, _TRACING_SHIM)
    return '\n'.join(lines)


def _run_viz(path: str, *, runtime_dim: int, runtime_seed: int,
             loop_T: int | None = None,
             output_html: str | None = None) -> int:
    """Compile, execute with tracing, and output a 3D visualization HTML.

    Strategy: inject a tracing shim into the generated Python source that
    wraps the `_VSA` runtime's embed/bind/unbind/bundle, so tracing is
    active from the first embed() call during module-level init.

    The traced program runs as a real module (like `--run`), with the
    tracer pre-seeded as `_tracer`; a callable `main` is then called and a
    non-None return value printed. Afterwards the tracer's HTML is written
    to `output_html` (default `<source stem>_viz.html`) and its JSON to
    `<source stem>_trace.json`.

    Returns 1 if compilation failed, else 0.
    """
    import types
    from .trace import SutraTracer

    py_src = _compile_to_python(
        path, runtime_dim=runtime_dim, runtime_seed=runtime_seed,
        loop_T=loop_T,
    )
    if py_src is None:
        return 1

    tracer = SutraTracer(os.path.basename(path))
    traced_src = _inject_tracing_shim(py_src)

    # Execute as a proper module (consistent with --run: real __name__ and
    # __file__), with the tracer available to the shim as `_tracer`.
    mod = types.ModuleType("_sutra_run")
    mod.__file__ = f"<traced {path}>"
    mod._tracer = tracer
    exec(compile(traced_src, mod.__file__, "exec"), mod.__dict__)

    main_fn = getattr(mod, "main", None)
    if callable(main_fn):
        result = main_fn()
        if result is not None:
            print(result)

    stem = os.path.splitext(path)[0]
    if output_html is None:
        output_html = stem + "_viz.html"

    # Render before opening so a rendering failure doesn't leave an empty file.
    html = tracer.to_html()
    with open(output_html, "w", encoding="utf-8") as f:
        f.write(html)
    print(f"\n3D visualization written to: {output_html}", file=sys.stderr)

    # Also write trace JSON for the VS Code extension.
    trace_json = stem + "_trace.json"
    with open(trace_json, "w", encoding="utf-8") as f:
        f.write(tracer.to_json())

    return 0
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def _run_emit(path: str, *, runtime_dim: int, runtime_seed: int,
              loop_T: int | None = None) -> int:
    """Compile `path` and write the generated PyTorch Python to stdout.

    Returns 0 on success and 1 when compilation failed (the reason has
    already been printed to stderr).
    """
    generated = _compile_to_python(
        path,
        runtime_dim=runtime_dim,
        runtime_seed=runtime_seed,
        loop_T=loop_T,
    )
    if generated is not None:
        sys.stdout.write(generated)
        return 0
    return 1
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def _positive_int(text: str) -> int:
    """argparse `type=` converter that accepts only integers >= 1.

    Used for --runtime-dim and --loop-T. `_read_atman_loop_T` already
    ignores non-positive values from atman.toml, so the command line is
    held to the same rule instead of passing 0 or a negative value through
    to codegen.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def main(argv: List[str] | None = None) -> int:
    """Parse command-line arguments and dispatch to one CLI mode.

    Mode precedence: --review; then --run-viz, --run, --emit (in that
    order; all three take exactly one file); then --json; then
    --consistency; otherwise plain validation, with --summary selecting the
    table output.

    Returns the process exit code: 0 on success, 1 when errors were
    reported, 2 for a single-file mode given the wrong number of paths.
    --review returns whatever `review_file` returns. argparse itself exits
    with status 2 on malformed arguments, including a non-positive
    --runtime-dim or --loop-T.
    """
    parser = argparse.ArgumentParser(
        prog="sutrac",
        description="Validate, compile, and run Sutra (.su) source files.",
    )
    parser.add_argument(
        "paths",
        nargs="+",
        help="Files or directories to validate. Directories are walked recursively.",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Emit machine-readable diagnostics as JSON. For editors and language servers.",
    )
    parser.add_argument(
        "--summary",
        action="store_true",
        help="Print a per-file summary table instead of individual diagnostics.",
    )
    parser.add_argument(
        "--consistency",
        action="store_true",
        help="Cross-file check: report class names that appear in multiple casings across the file set.",
    )
    parser.add_argument(
        "--emit",
        action="store_true",
        help=(
            "Compile the input file (exactly one) to self-contained torch "
            "Python and print it to stdout. Picks CUDA at module init if "
            "available; falls back to CPU otherwise. This is the one main "
            "codegen target — PyTorch is the runtime and the tensor-op "
            "library Sutra compiles against."
        ),
    )
    parser.add_argument(
        "--run",
        action="store_true",
        help=(
            "Compile and execute the input file (exactly one; PyTorch "
            "backend) in one step. The generated module's own output goes "
            "straight to stdout, and the return value of its main(), if it "
            "defines one and the value is not None, is printed. Requires "
            "torch to be importable."
        ),
    )
    parser.add_argument(
        "--run-viz",
        action="store_true",
        help=(
            "Compile and execute with tracing, then generate a standalone "
            "Three.js 3D visualization HTML alongside the program output."
        ),
    )
    parser.add_argument(
        "--review",
        action="store_true",
        help=(
            "Step-by-step review mode: show source, parsed AST, "
            "inlined AST, every simplification rewrite that fires "
            "(before/after), final simplified AST, and emitted Python. "
            "For debugging and teaching."
        ),
    )
    parser.add_argument(
        "--runtime-dim", type=_positive_int, default=50,
        help="Hypervector dimension for the emitted runtime (default 50).",
    )
    parser.add_argument(
        "--runtime-seed", type=int, default=42,
        help="Random seed for the emitted runtime (default 42).",
    )
    parser.add_argument(
        "--loop-T", type=_positive_int, default=None,
        help=(
            "Maximum compile-time loop unroll depth (T) for "
            "tail-recursive loop functions and the soft-halt RNN cell. "
            "If unset, the compiler reads the value from the nearest "
            "[project.compile] loop_max_iterations field in atman.toml, "
            "and falls back to 50 if no manifest declares it. The "
            "soft-halt cell freezes state once halt-cum saturates, so "
            "larger T costs only a longer emitted graph, not extra "
            "runtime work."
        ),
    )
    parser.add_argument(
        "--version",
        action="version",
        version=f"sutrac {__version__}",
    )
    args = parser.parse_args(argv)
    if args.review:
        if len(args.paths) != 1:
            print(
                "--review takes exactly one .su source file",
                file=sys.stderr,
            )
            return 2
        from .review import review_file
        return review_file(args.paths[0])
    if args.emit or args.run or args.run_viz:
        if len(args.paths) != 1:
            print(
                "--emit/--run/--run-viz takes exactly one .su source file",
                file=sys.stderr,
            )
            return 2
        compile_opts = dict(
            runtime_dim=args.runtime_dim,
            runtime_seed=args.runtime_seed,
            loop_T=args.loop_T,
        )
        if args.run_viz:
            return _run_viz(args.paths[0], **compile_opts)
        if args.run:
            return _run_execute(args.paths[0], **compile_opts)
        return _run_emit(args.paths[0], **compile_opts)
    if args.json:
        return _run_json(args.paths)
    if args.consistency:
        return _run_consistency(args.paths)
    return _run_text(args.paths, summary=args.summary)
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
if __name__ == "__main__":
    # main() returns the exit status; SystemExit propagates it to the shell.
    raise SystemExit(main())
|