tablambda 0.6.0.post30.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tablambda/_pyast.py ADDED
@@ -0,0 +1,416 @@
1
+ """A Scott-encoded Python AST, generated by reflection on the ``ast`` module, with a decoder.
2
+
3
+ The compiler's target is a Python AST represented in the pure lambda-calculus by Scott encoding: a
4
+ value of an n-constructor data type is ``lambda h0 ... h_{n-1}. h_tag field0 ... field_k`` (it
5
+ applies the handler for its constructor to its fields). Rather than hand-write a constructor and a
6
+ decoder per node type, both are derived by reflection on ``ast``: ``SUPPORTED`` lists the node
7
+ classes, each class's ``_fields`` gives its arity and field order, and a field is encoded with a
8
+ kind tag (a child node, a list, an int, a string as a list of character codes, a bool, or none) so
9
+ the decoder can rebuild a real ``ast`` node generically with ``cls(*fields)``.
10
+
11
+ Encoding (a real ``ast`` node to a Scott value) and decoding (a Scott value, run in the interpreter,
12
+ back to a real ``ast`` node) are meta-level: the lambda-calculus does not inspect itself, but the
13
+ interpreter that runs it does, exactly as ``render`` reads the behaviour graph. The decoder reflects
14
+ a Scott value to the host by applying it to handlers whose heads are distinct free de Bruijn
15
+ variables (markers); the weak head normal form is then ``marker child0 ... childk``, a spine whose
16
+ head variable names the constructor and whose arguments are the child nodes.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import ast
22
+ import contextlib
23
+ import hashlib
24
+ import struct
25
+
26
+ from tablambda._ast import App, Lam, Node, Var, make_app, make_lam, make_var
27
+ from tablambda._dsl import Builder, app, build, lam
28
+ from tablambda._codec import church
29
+ from tablambda._prelude import SCOTT_NIL
30
+ from tablambda._sugar import cons
31
+
32
+ # The Python AST node classes the encoding supports, in tag order. Extend as needed; encoding an
33
+ # unsupported class raises, so a gap is loud rather than silent.
34
# A class's position in this tuple is its Scott constructor tag (see ``_TAG``), so moving an entry
# changes the tag of every class at or after the moved position.
SUPPORTED: "tuple[type[ast.AST], ...]" = (
    ast.Expression,
    ast.Module,
    ast.Expr,
    ast.Lambda,
    ast.arguments,
    ast.arg,
    ast.FunctionDef,
    ast.Return,
    ast.Assign,
    ast.While,
    ast.If,
    ast.Pass,
    ast.Call,
    ast.Name,
    ast.Load,
    ast.Store,
    ast.Constant,
    ast.BinOp,
    ast.Add,
    ast.Sub,
    ast.Mult,
    ast.Compare,
    ast.Lt,
    ast.LtE,
    ast.Gt,
    ast.GtE,
    ast.Eq,
    ast.Nonlocal,
    ast.Is,
    ast.Subscript,
    ast.Tuple,
    ast.IfExp,
    ast.Attribute,
    ast.ClassDef,
    ast.AnnAssign,
)
# Node class -> constructor tag (its index in SUPPORTED).
_TAG: "dict[type[ast.AST], int]" = dict(zip(SUPPORTED, range(len(SUPPORTED))))
# Constructor tag -> number of fields, read from each class's ``_fields``.
_ARITY = tuple(len(node_class._fields) for node_class in SUPPORTED)
48
+
49
+ # Field kind tags.
50
_K_NODE = 0    # payload is a nested Scott-encoded ast node
_K_LIST = 1    # payload is a Scott list of encoded fields
_K_INT = 2     # payload is a Church numeral
_K_STR = 3     # payload is a Scott list of Church-numeral character codes
_K_BOOL = 4    # payload is the Church numeral 1 or 0
_K_NONE = 5    # payload is ignored by the decoder
_K_IDENT = 6   # payload is a Scott list of Nats rendered to one identifier
_K_GENSYM = 7  # payload node is named by ``_gensym_name``
51
+
52
+ # A gensym table mapping a payload node's identity to a content-addressable identifier. The
53
+ # codegen names each entity by an interned payload node; because the recursion is TABLED, the same
54
+ # payload is the SAME interned node, so it decodes to the SAME name (consistent across def and use),
55
+ # while distinct ones get distinct names. The name is derived from a deterministic Merkle hash of
56
+ # the node's de Bruijn structure (not Python's randomized ``hash()``), so the same structure always
57
+ # yields the same name across processes and the output is stable under local source modifications.
58
# id(payload node) -> generated identifier.
_gensym_ids: "dict[int, str]" = {}
# id(node) -> truncated structural hash computed by ``_merkle_hash``.
_merkle_cache: "dict[int, int]" = {}


def _reset_gensym() -> None:
    """Empty both identity-keyed tables (generated names and cached Merkle hashes) in place."""
    for table in (_gensym_ids, _merkle_cache):
        table.clear()
67
+
68
+
69
+ def _merkle_hash(node: Node) -> int:
70
+ """A deterministic Merkle hash of the node's de Bruijn structure, independent of PYTHONHASHSEED."""
71
+ cached = _merkle_cache.get(id(node))
72
+ if cached is not None:
73
+ return cached
74
+ match node:
75
+ case Var(index=index):
76
+ data = struct.pack(">BQ", 0, index)
77
+ case Lam(body=body):
78
+ data = struct.pack(">BQ", 1, _merkle_hash(body))
79
+ case App(function=function, argument=argument):
80
+ data = struct.pack(">BQQ", 2, _merkle_hash(function), _merkle_hash(argument))
81
+ case _:
82
+ raise ValueError(f"cannot hash {node!r}")
83
+ result = int.from_bytes(hashlib.sha256(data).digest()[:8], "big")
84
+ _merkle_cache[id(node)] = result
85
+ return result
86
+
87
+
88
def _gensym_name(payload: "Node") -> str:
    """Return the generated identifier for ``payload``: ``vg_`` plus its Merkle hash as 16 hex digits.

    The name is computed on first request and then served from ``_gensym_ids``, keyed by
    ``id(payload)``.
    """
    key = id(payload)
    if key not in _gensym_ids:
        _gensym_ids[key] = f"vg_{_merkle_hash(payload):016x}"
    return _gensym_ids[key]
95
+
96
+
97
+ # An optional per-decode memo keyed by interned-node identity. ``decode`` is a tree recursion, so a
98
+ # shared interned sub-graph (the interpreter hash-conses) is otherwise re-decoded once per occurrence.
99
+ # Lambda-lifted call-by-need emits the SAME interned factory node once per occurrence of a shared
100
+ # sub-term (COMPILE shares ~19x); decoding under ``memoized_decode`` forces and decodes each DISTINCT
101
+ # node once (O(distinct) instead of O(occurrences)), the same node-identity memoization ``to_anf_source``
102
+ # relies on via ``_extract``. Off by default so unrelated decodes stay simple.
103
# ``None`` while memoization is off; a dict from id(Scott node) to its decoded ast node while a
# ``memoized_decode`` block is active.
_decode_memo: "dict[int, ast.AST] | None" = None


@contextlib.contextmanager
def memoized_decode():
    """Decode each distinct interned Scott node once, keyed by node identity (does not nest).

    On entry a fresh, empty memo is installed in the module-level ``_decode_memo``; on exit,
    including exit by exception, the memo is discarded and memoization is switched off again.

    Raises:
        AssertionError: if a ``memoized_decode`` block is already active.
    """
    global _decode_memo
    # Raised explicitly rather than with ``assert``: an assert statement is removed under
    # ``python -O``, and a nested block would then silently replace the outer memo and switch
    # memoization off for the remainder of the outer block when the inner one exits.
    if _decode_memo is not None:
        raise AssertionError("memoized_decode does not nest")
    _decode_memo = {}
    try:
        yield
    finally:
        _decode_memo = None
116
+
117
+ # Disjoint free-variable bands used as meta markers (far above any real index; Scott values here
118
+ # are closed, so the only free variables in a probed term are these markers).
119
# One band per probe: ast constructors, <kind, payload> pairs, Scott lists.
_CTOR_BASE, _FIELD_BASE, _LIST_BASE = 1_000_000, 2_000_000, 3_000_000
# The two markers a Church numeral is applied to when it is read back as an int.
_CHURCH_SUCC = 4_000_000
_CHURCH_ZERO = _CHURCH_SUCC + 1
124
+
125
+
126
+ # --- encoding: a real ast node to a Scott Builder ---------------------------------------------
127
+
128
def _ctor(tag: int, fields: "list[Builder]") -> Builder:
    """Build the Scott value of constructor ``tag`` holding ``fields``.

    The result binds one handler per entry of ``SUPPORTED``; its body is the handler at position
    ``tag`` applied to every field in order.
    """

    def bind_remaining(bound: "list[Builder]") -> Builder:
        # Keep adding binders until there is exactly one handler per supported class.
        if len(bound) != len(SUPPORTED):
            return lam(lambda handler: bind_remaining([*bound, handler]))
        body = bound[tag]
        for field in fields:
            body = app(body, field)
        return body

    return bind_remaining([])
138
+
139
+
140
def _scott_list(elements: "list[Builder]") -> Builder:
    """Fold ``elements`` into a Scott list: ``cons`` cells ending in ``SCOTT_NIL``, first element outermost."""
    pending = list(elements)  # private copy, consumed from the back
    chain = SCOTT_NIL
    while pending:
        chain = cons(pending.pop(), chain)
    return chain
145
+
146
+
147
def _kind(kind: int, payload: Builder) -> Builder:
    """Tag ``payload`` with its field kind: the pair ``lambda s. s (church kind) payload``."""

    def pair(selector):
        tagged = app(selector, church(kind))
        return app(tagged, payload)

    return lam(pair)
150
+
151
+
152
+ def _encode_field(value: object) -> Builder:
153
+ if isinstance(value, ast.AST):
154
+ return _kind(_K_NODE, encode(value))
155
+ if isinstance(value, list):
156
+ return _kind(_K_LIST, _scott_list([_encode_field(item) for item in value]))
157
+ if isinstance(value, bool):
158
+ return _kind(_K_BOOL, church(1 if value else 0))
159
+ if isinstance(value, int):
160
+ return _kind(_K_INT, church(value))
161
+ if isinstance(value, str):
162
+ return _kind(_K_STR, _scott_list([church(ord(character)) for character in value]))
163
+ if value is None:
164
+ return _kind(_K_NONE, church(0))
165
+ raise ValueError(f"cannot encode field value {value!r} of type {type(value).__name__}")
166
+
167
+
168
def encode(node: ast.AST) -> Builder:
    """Encode a real Python ``ast`` node into a Scott-encoded Builder.

    Fields are taken in the order of the class's ``_fields``; an attribute that is not set on the
    node is encoded as ``None``.

    Raises:
        ValueError: if the node's class is not listed in ``SUPPORTED``.
    """
    node_class = type(node)
    tag = _TAG.get(node_class)
    if tag is None:
        raise ValueError(f"unsupported ast node {node_class.__name__}; add it to SUPPORTED")
    encoded_fields = []
    for field_name in node_class._fields:
        encoded_fields.append(_encode_field(getattr(node, field_name, None)))
    return _ctor(tag, encoded_fields)
175
+
176
+
177
+ # --- decoding: a Scott value (run in the interpreter) back to a real ast node ------------------
178
+
179
def _handler(tag: int, arity: int, base: int) -> Node:
    """Build the probe handler for constructor ``tag``.

    The handler is ``arity`` nested lambdas around a spine whose head is the free marker variable
    (index ``base + tag + arity``, i.e. shifted past the ``arity`` binders) applied to the bound
    variables with indices ``arity - 1`` down to ``0``.
    """
    spine: Node = make_var(base + tag + arity)
    for bound_index in reversed(range(arity)):
        spine = make_app(spine, make_var(bound_index))
    wrapped = spine
    binders_left = arity
    while binders_left > 0:
        wrapped = make_lam(wrapped)
        binders_left -= 1
    return wrapped
186
+
187
+
188
+ def _spine(node: Node) -> "tuple[int, list[Node]]":
189
+ arguments: "list[Node]" = []
190
+ current = node
191
+ while True:
192
+ whnf = current.weak_head_normal_form
193
+ match whnf:
194
+ case Var(index=index):
195
+ arguments.reverse()
196
+ return index, arguments
197
+ case App(function=function, argument=argument):
198
+ arguments.append(argument)
199
+ current = function
200
+ case _:
201
+ raise ValueError(f"expected a variable-headed spine, got {whnf!r}")
202
+
203
+
204
def _extract(node: Node, arities: "tuple[int, ...]", base: int) -> "tuple[int, list[Node]]":
    """Probe a Scott value: apply it to one marker handler per constructor and read the spine.

    ``arities[i]`` is the field count of constructor ``i``. Returns the head marker's index minus
    ``base`` together with the spine's arguments.
    """
    probe = node
    for constructor, field_count in enumerate(arities):
        marker_handler = _handler(constructor, field_count, base)
        probe = make_app(probe, marker_handler)
    marker, arguments = _spine(probe)
    return marker - base, arguments
210
+
211
+
212
+ def _church_to_int(node: Node) -> int:
213
+ current = make_app(make_app(node, make_var(_CHURCH_SUCC)), make_var(_CHURCH_ZERO))
214
+ count = 0
215
+ while True:
216
+ whnf = current.weak_head_normal_form
217
+ match whnf:
218
+ case Var(index=index):
219
+ if index != _CHURCH_ZERO:
220
+ raise ValueError(f"church spine ended at variable {index}")
221
+ return count
222
+ case App(argument=argument):
223
+ count += 1
224
+ current = argument
225
+ case _:
226
+ raise ValueError(f"church spine hit {whnf!r}")
227
+
228
+
229
def _decode_scott_list(node: Node) -> "list[Node]":
    """Return the element nodes of a Scott list, in order, without decoding the elements."""
    shape = (2, 0)  # handler 0: cons (head, tail); handler 1: nil
    collected: "list[Node]" = []
    tag, parts = _extract(node, shape, _LIST_BASE)
    while tag != 1:  # anything but nil is read as a cons cell
        collected.append(parts[0])
        tag, parts = _extract(parts[1], shape, _LIST_BASE)
    return collected
238
+
239
+
240
def _decode_field(node: Node) -> object:
    """Decode one ``<kind, payload>`` field to the Python value it stands for.

    Raises:
        ValueError: if the kind tag is not one of the eight known kinds.
    """
    _, parts = _extract(node, (2,), _FIELD_BASE)  # the <kind, payload> pair
    kind = _church_to_int(parts[0])
    payload = parts[1]
    if kind == _K_NODE:
        return decode(payload)
    if kind == _K_LIST:
        return [_decode_field(item) for item in _decode_scott_list(payload)]
    if kind == _K_INT:
        return _church_to_int(payload)
    if kind == _K_STR:
        return "".join(chr(_church_to_int(code)) for code in _decode_scott_list(payload))
    if kind == _K_BOOL:
        return bool(_church_to_int(payload))
    if kind == _K_NONE:
        return None
    if kind == _K_IDENT:
        # A list of Nats (an AST path) rendered to one underscore-joined identifier.
        return _path_to_identifier(payload)
    if kind == _K_GENSYM:
        # A generated identifier per payload node, looked up by node identity.
        return _gensym_name(payload)
    raise ValueError(f"unknown field kind {kind}")
263
+
264
+
265
+ # The single identifier decoder shared by every runtime: a variable's identifier is the list of Nats
266
+ # naming its AST path, rendered ``v`` then ``_<segment>`` per segment (underscore-joined decimal
267
+ # integers, an alphabetic prefix). Distinct paths give distinct names, so uniqueness is by construction;
268
+ # the lambda compiler emits only the path (a list of Nats), never the rendered string.
269
def _path_to_identifier(payload: Node) -> str:
    """Render a Scott list of Church numerals as ``v`` followed by ``_<n>`` for each numeral."""
    pieces = ["v"]
    for segment in _decode_scott_list(payload):
        pieces.append(str(_church_to_int(segment)))
    return "_".join(pieces)
272
+
273
+
274
def decode(node: Node) -> ast.AST:
    """Decode a Scott-encoded ast value (run in the interpreter) back to a real ``ast`` node.

    The value is probed for its constructor tag and fields, each field is decoded, and the node is
    rebuilt as ``SUPPORTED[tag](*fields)``. While a ``memoized_decode`` block is active, results
    are cached by ``id(node)`` and the cached ``ast`` object is returned for later requests.
    """
    # NOTE(review): the memo key is id(node) and the memo does not hold the node itself; this
    # presumably relies on interned nodes outliving the memoized block -- confirm.
    key = id(node)
    if _decode_memo is not None and (hit := _decode_memo.get(key)) is not None:
        return hit
    tag, fields = _extract(node, _ARITY, _CTOR_BASE)
    node_class = SUPPORTED[tag]
    arguments = [_decode_field(field) for field in fields]
    rebuilt = node_class(*arguments)
    if _decode_memo is not None:
        _decode_memo[key] = rebuilt
    return rebuilt
290
+
291
+
292
def to_python_source(node: Node) -> str:
    """Decode a Scott-encoded Python AST, fill in missing source locations, and unparse it."""
    tree = decode(node)
    located = ast.fix_missing_locations(tree)
    return ast.unparse(located)
295
+
296
+
297
def _field_payload(field_node: Node) -> Node:
    """Return the payload half of a ``<kind, payload>`` field, leaving the payload undecoded."""
    pair = _extract(field_node, (2,), _FIELD_BASE)[1]
    return pair[1]
301
+
302
+
303
def to_anf_source(node: Node, binding_name: str) -> str:
    """Serialize a Scott-encoded Python AST to A-normal-form source, sharing sub-expressions by identity.

    A generic, graph-preserving serialization (not program-specific): each distinct ``ast.Call`` node
    in the Scott value becomes ONE assignment to a fresh temporary ``_k<N>``, so a sub-graph that is
    the same node object wherever it occurs is emitted once rather than unfolded. Non-``Call`` nodes
    (``Name``/``Constant``/``Lambda``/...) are decoded and inlined whole.

    An ``ast.FunctionDef`` node met in the function position of a ``Call``, or in a value position,
    is HOISTED: it is decoded whole, emitted once (deduplicated by node identity) into a preamble of
    top-level ``def``s ahead of the assignments, and replaced in place by a ``Name`` referencing the
    def.

    Only the ``func`` and ``args`` fields of a ``Call`` are read; every emitted call has
    ``keywords=[]``.

    Args:
        node: the Scott-encoded root expression.
        binding_name: the name the root expression is assigned to in the last statement.

    Returns:
        The source of a module: the hoisted ``def``s, then the temporary assignments, ending with
        ``<binding_name> = <root>``.
    """
    memo: "dict[int, str]" = {}
    # ``memo`` is keyed by id(), and an id is only unique among objects that are alive at the same
    # time. Every keyed node is therefore also kept in ``pinned`` so it cannot be collected, and
    # its id handed to a different node, while this serialization is still running.
    pinned: "list[Node]" = []
    hoisted: "list[ast.stmt]" = []
    statements: "list[ast.stmt]" = []
    counter = 0

    def remember(scott_node: Node, name: str) -> None:
        """Record ``name`` for ``scott_node`` and keep the node alive for as long as ``memo`` is."""
        pinned.append(scott_node)
        memo[id(scott_node)] = name

    def hoist_def(scott_node: Node) -> ast.expr:
        """Hoist an ``ast.FunctionDef`` node to the top-level preamble (once, by node identity) and
        reference it by its own name. The def is decoded whole; its body is not flattened further."""
        cached = memo.get(id(scott_node))
        if cached is not None:
            return ast.Name(id=cached, ctx=ast.Load())
        function_def = decode(scott_node)
        assert isinstance(function_def, ast.FunctionDef), (
            f"hoist_def expected an ast.FunctionDef, got {type(function_def).__name__}"
        )
        hoisted.append(function_def)
        remember(scott_node, function_def.name)
        return ast.Name(id=function_def.name, ctx=ast.Load())

    def emit_callee(callee_node: Node) -> ast.expr:
        """The function position of a Call: a ``FunctionDef`` becomes a top-level def reference;
        anything else (a ``Name``, a ``Lambda``) is inlined whole -- NOT routed through ``emit``,
        so a curried application ``f(a)(b)`` keeps its nested-call shape rather than ANF-splitting."""
        tag, _ = _extract(callee_node, _ARITY, _CTOR_BASE)
        if SUPPORTED[tag] is ast.FunctionDef:
            return hoist_def(callee_node)
        inlined = decode(callee_node)
        assert isinstance(inlined, ast.expr), (
            f"emit_callee expected an ast.expr, got {type(inlined).__name__}"
        )
        return inlined

    def emit(scott_node: Node) -> ast.expr:
        """Return the expression standing for ``scott_node``, appending any assignments it needs."""
        nonlocal counter
        key = id(scott_node)
        if key in memo:
            return ast.Name(id=memo[key], ctx=ast.Load())
        tag, fields = _extract(scott_node, _ARITY, _CTOR_BASE)
        cls = SUPPORTED[tag]
        if cls is ast.FunctionDef:
            return hoist_def(scott_node)  # a helper def reached in a value position: hoist it
        if cls is not ast.Call:
            inlined = decode(scott_node)  # a leaf or opaque expression (Name/Constant/Lambda/...): inline it
            assert isinstance(inlined, ast.expr), (
                f"emit expected an ast.expr, got {type(inlined).__name__}"
            )
            return inlined
        # fields[0] is the Call's ``func`` and fields[1] its ``args``; the callee is emitted before
        # the arguments, and the arguments left to right.
        function = emit_callee(_field_payload(fields[0]))
        arguments = [emit(_field_payload(argument)) for argument in _decode_scott_list(_field_payload(fields[1]))]
        call = ast.Call(func=function, args=arguments, keywords=[])
        name = f"_k{counter}"
        counter += 1
        statements.append(ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=call))
        remember(scott_node, name)
        return ast.Name(id=name, ctx=ast.Load())

    root = emit(node)
    statements.append(ast.Assign(targets=[ast.Name(id=binding_name, ctx=ast.Store())], value=root))
    return ast.unparse(ast.fix_missing_locations(ast.Module(body=hoisted + statements, type_ignores=[])))
385
+
386
+
387
def roundtrip(source: str, *, mode: str = "eval") -> str:
    """Parse ``source``, encode it to a Scott value, decode it, and unparse: a faithfulness check.

    ``mode`` is passed through to ``ast.parse``.
    """
    tree = ast.parse(source, mode=mode)
    scott_value = build(encode(tree))
    return to_python_source(scott_value)
390
+
391
+
392
+ # --- BinNat readouts (the decode side of the _binnat encoding) -------------------------------------
393
+
394
+ # A free-variable band used as a meta marker when probing a Scott boolean (a bit). Disjoint from the
395
+ # bands above, so a probed bit's only free variables are these markers.
396
_BIT_BASE = 8_500_000


def _bit_value(node: "Node") -> int:
    """Read a Scott boolean as a bit: 1 when it selects handler 0 (TRUE), otherwise 0."""
    # Both handlers are nullary, so only the selected tag matters.
    selected = _extract(node, (0, 0), _BIT_BASE)[0]
    if selected == 0:
        return 1
    return 0
403
+
404
+
405
def binnat_to_int(node: "Node") -> int:
    """Decode a BinNat (an LSB-first Scott list of bits) to a non-negative int."""
    bits = _decode_scott_list(node)
    # The bit at list position p contributes 2**p.
    return sum(_bit_value(bit) << weight for weight, bit in enumerate(bits))
411
+
412
+
413
def binnat_list_to_identifier(node: "Node", prefix: str = "v") -> str:
    """Decode a Scott list of BinNats to an underscore-joined identifier, e.g. ``v_12_3_567``.

    ``prefix`` is the leading component; each BinNat follows as a decimal integer.
    """
    pieces = [prefix]
    for segment in _decode_scott_list(node):
        pieces.append(str(binnat_to_int(segment)))
    return "_".join(pieces)