co-lambda 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
co_lambda/_pyast.py ADDED
@@ -0,0 +1,418 @@
1
+ """A Scott-encoded Python AST, generated by reflection on the ``ast`` module, with a decoder.
2
+
3
+ The compiler's target is a Python AST represented in the pure lambda-calculus by Scott encoding: a
4
+ value of an n-constructor data type is ``lambda h0 ... h_{n-1}. h_tag field0 ... field_k`` (it
5
+ applies the handler for its constructor to its fields). Rather than hand-write a constructor and a
6
+ decoder per node type, both are derived by reflection on ``ast``: ``SUPPORTED`` lists the node
7
+ classes, each class's ``_fields`` gives its arity and field order, and a field is encoded with a
8
+ kind tag (a child node, a list, an int, a string as a list of character codes, a bool, or none) so
9
+ the decoder can rebuild a real ``ast`` node generically with ``cls(*fields)``.
10
+
11
+ Encoding (a real ``ast`` node to a Scott value) and decoding (a Scott value, run in the interpreter,
12
+ back to a real ``ast`` node) are meta-level: the lambda-calculus does not inspect itself, but the
13
+ interpreter that runs it does, exactly as ``render`` reads the behaviour graph. The decoder reflects
14
+ a Scott value to the host by applying it to handlers whose heads are distinct free de Bruijn
15
+ variables (markers); the weak head normal form is then ``marker child0 ... childk``, a spine whose
16
+ head variable names the constructor and whose arguments are the child nodes.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import ast
22
+ import contextlib
23
+ import hashlib
24
+ import struct
25
+
26
+ from co_lambda._ast import App, Lam, Native, Node, Var, make_app, make_lam, make_var
27
+ from co_lambda._dsl import Builder, app, build, lam
28
+ from co_lambda._codec import church
29
+ from co_lambda._prelude import SCOTT_NIL
30
+ from co_lambda._sugar import cons
31
+
32
# The Python ``ast`` node classes the encoding supports. A class's position in this tuple is its
# constructor tag, so inserting a class shifts the tag of every class after it; extend by appending.
# Encoding a class missing from here raises, so a gap is loud rather than silent.
SUPPORTED: "tuple[type[ast.AST], ...]" = (
    # parse roots and the expression statement
    ast.Expression,
    ast.Module,
    ast.Expr,
    # lambdas and their parameter lists
    ast.Lambda,
    ast.arguments,
    ast.arg,
    # statements
    ast.FunctionDef,
    ast.Return,
    ast.Assign,
    ast.While,
    ast.If,
    ast.Pass,
    # calls, names, expression contexts and literals
    ast.Call,
    ast.Name,
    ast.Load,
    ast.Store,
    ast.Constant,
    # arithmetic
    ast.BinOp,
    ast.Add,
    ast.Sub,
    ast.Mult,
    # comparisons
    ast.Compare,
    ast.Lt,
    ast.LtE,
    ast.Gt,
    ast.GtE,
    ast.Eq,
    # nonlocal declarations and the ``is`` operator
    ast.Nonlocal,
    ast.Is,
    # indexing and tuples
    ast.Subscript,
    ast.Tuple,
    # conditional expressions and attribute access
    ast.IfExp,
    ast.Attribute,
    # class definitions and annotated assignment
    ast.ClassDef,
    ast.AnnAssign,
)
# Constructor tag of each supported class: its index in SUPPORTED.
_TAG: "dict[type[ast.AST], int]" = dict(zip(SUPPORTED, range(len(SUPPORTED))))
# Field count of each supported class, in tag order (one handler argument per ``_fields`` entry).
_ARITY = tuple(map(len, [node_class._fields for node_class in SUPPORTED]))

# Field kind tags: the Church-encoded first component of every <kind, payload> field pair.
_K_NODE = 0  # a child ast node
_K_LIST = 1  # a Scott list of fields
_K_INT = 2  # an int as a Church numeral
_K_STR = 3  # a str as a Scott list of Church-encoded character codes
_K_BOOL = 4  # a bool as Church 1 / 0
_K_NONE = 5  # None (payload Church 0)
_K_IDENT = 6  # an identifier given as a list of Nats (an AST path)
_K_GENSYM = 7  # a generated name derived from a payload node

# Gensym table: ``id(payload node)`` -> generated identifier. The codegen names each entity by an
# interned payload node; because its recursion is TABLED, the same payload is the SAME interned node and
# so decodes to the SAME name at definition and use, while distinct payloads get distinct names. A name
# is derived from a deterministic Merkle hash of the node's de Bruijn structure rather than Python's
# randomized ``hash()``, so it does not depend on PYTHONHASHSEED. (A payload containing a ``Native``
# leaf hashes that leaf by object identity, so only native-free payloads get process-stable names.)
_gensym_ids: "dict[int, str]" = {}
59
+
60
+
61
def _reset_gensym() -> None:
    """Start a fresh naming session: forget every gensym name, then every cached Merkle hash."""
    for table in (_gensym_ids, _merkle_cache):
        table.clear()
64
+
65
+
66
+ _merkle_cache: "dict[int, int]" = {}
67
+
68
+
69
+ def _merkle_hash(node: Node) -> int:
70
+ """A deterministic Merkle hash of the node's de Bruijn structure, independent of PYTHONHASHSEED."""
71
+ cached = _merkle_cache.get(id(node))
72
+ if cached is not None:
73
+ return cached
74
+ match node:
75
+ case Var(index=index):
76
+ data = struct.pack(">BQ", 0, index)
77
+ case Lam(body=body):
78
+ data = struct.pack(">BQ", 1, _merkle_hash(body))
79
+ case App(function=function, argument=argument):
80
+ data = struct.pack(">BQQ", 2, _merkle_hash(function), _merkle_hash(argument))
81
+ case Native():
82
+ data = struct.pack(">BQ", 3, id(node) & 0xFFFFFFFFFFFFFFFF)
83
+ case _:
84
+ raise ValueError(f"cannot hash {node!r}")
85
+ result = int.from_bytes(hashlib.sha256(data).digest()[:8], "big")
86
+ _merkle_cache[id(node)] = result
87
+ return result
88
+
89
+
90
def _gensym_name(payload: "Node") -> str:
    """Return the identifier ``vg_<16 lowercase hex digits>`` naming ``payload``.

    The name is computed from the payload's Merkle hash on first request and remembered in
    ``_gensym_ids`` under the payload's identity, so later requests for the same node return it unchanged.
    """
    key = id(payload)
    if key not in _gensym_ids:
        _gensym_ids[key] = f"vg_{_merkle_hash(payload):016x}"
    return _gensym_ids[key]
97
+
98
+
99
# An optional per-decode memo keyed by interned-node identity. ``decode`` is a tree recursion, so a
# shared interned sub-graph (the interpreter hash-conses) is otherwise re-decoded once per occurrence.
# Lambda-lifted call-by-need emits the SAME interned factory node once per occurrence of a shared
# sub-term (COMPILE shares ~19x). Decoding under ``memoized_decode`` forces and decodes each DISTINCT
# node once (O(distinct) instead of O(occurrences)). ``to_anf_source`` relies on the same node-identity
# memoization via ``_extract``. The memo is off (``None``) by default, so unrelated decodes stay simple.
_decode_memo: "dict[int, ast.AST] | None" = None


@contextlib.contextmanager
def memoized_decode():
    """Within the ``with`` block, decode each distinct interned Scott node once, keyed by node identity.

    Installs an empty module-level memo on entry and removes it on exit, including when the body raises.

    Raises:
        AssertionError: on entry, when a ``memoized_decode`` block is already active (it does not nest).
            Raised explicitly rather than via ``assert`` so the guard also holds under ``python -O``.
    """
    global _decode_memo
    if _decode_memo is not None:
        # Without this check an inner block would replace the outer memo, then clear it on exit.
        raise AssertionError("memoized_decode does not nest")
    _decode_memo = {}
    try:
        yield
    finally:
        _decode_memo = None
118
+
119
# Meta markers: free de Bruijn indices in disjoint bands, far above any index a real term uses. The Scott
# values probed here are closed, so the only free variables in a probed term are these markers.
_CTOR_BASE = 1_000_000  # + constructor tag: the handler heads ``decode`` probes an ast value with
_FIELD_BASE = 2_000_000  # the handler head for a <kind, payload> field pair
_LIST_BASE = 3_000_000  # + 0 (cons) / + 1 (nil)
_CHURCH_SUCC = 4_000_000  # the successor a Church numeral is applied to
_CHURCH_ZERO = _CHURCH_SUCC + 1  # the zero a Church numeral is applied to
126
+
127
+
128
+ # --- encoding: a real ast node to a Scott Builder ---------------------------------------------
129
+
130
def _ctor(tag: int, fields: "list[Builder]") -> Builder:
    """Build the Scott constructor ``lambda h0 ... h_{n-1}. h_tag field0 ... field_k``.

    One handler binder is opened per entry of ``SUPPORTED``. Once every handler is bound, the handler at
    position ``tag`` is applied to ``fields`` in order.
    """

    def open_binder(bound: "list[Builder]") -> Builder:
        if len(bound) < len(SUPPORTED):
            # Still short of one handler per supported class: bind another one.
            return lam(lambda next_handler: open_binder([*bound, next_handler]))
        body = bound[tag]
        for argument in fields:
            body = app(body, argument)
        return body

    return open_binder([])
140
+
141
+
142
def _scott_list(elements: "list[Builder]") -> Builder:
    """Fold ``elements`` into the Scott list ``cons(e0, cons(e1, ... SCOTT_NIL))``, building from the tail."""
    tail = SCOTT_NIL
    for element in elements[::-1]:
        tail = cons(element, tail)
    return tail
147
+
148
+
149
def _kind(kind: int, payload: Builder) -> Builder:
    """Tag a field payload with its kind: the Scott pair ``lambda s. s (church kind) payload``."""

    def select(selector: Builder) -> Builder:
        tagged = app(selector, church(kind))
        return app(tagged, payload)

    return lam(select)
152
+
153
+
154
def _encode_field(value: object) -> Builder:
    """Encode one ``ast`` field value as a ``<kind, payload>`` pair built by ``_kind``.

    - ``ast.AST`` node -> ``_K_NODE``, its Scott encoding (``encode``)
    - ``list`` -> ``_K_LIST``, a Scott list of the encoded items
    - ``bool`` -> ``_K_BOOL``, Church 1 / 0
    - non-negative ``int`` -> ``_K_INT``, its Church numeral
    - ``str`` -> ``_K_STR``, a Scott list of Church-encoded character codes
    - ``None`` -> ``_K_NONE``, Church 0

    Raises:
        ValueError: for a negative ``int``, which a Church numeral cannot represent (the decoder counts
            successors, so it could only read a non-negative value back). Also raised for a value of any
            other type (float, bytes, complex, Ellipsis, ...).
    """
    if isinstance(value, ast.AST):
        return _kind(_K_NODE, encode(value))
    if isinstance(value, list):
        return _kind(_K_LIST, _scott_list([_encode_field(item) for item in value]))
    # bool is tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return _kind(_K_BOOL, church(1 if value else 0))
    if isinstance(value, int):
        if value < 0:
            raise ValueError(f"cannot encode negative int field value {value!r}: Church numerals are non-negative")
        return _kind(_K_INT, church(value))
    if isinstance(value, str):
        return _kind(_K_STR, _scott_list([church(ord(character)) for character in value]))
    if value is None:
        return _kind(_K_NONE, church(0))
    raise ValueError(f"cannot encode field value {value!r} of type {type(value).__name__}")
168
+
169
+
170
def encode(node: ast.AST) -> Builder:
    """Encode a real Python ``ast`` node into a Scott-encoded Builder.

    The node's class picks the constructor tag, and its ``_fields`` are encoded in declaration order.
    A field the node lacks is encoded as ``None``.

    Raises:
        ValueError: if the node's class is not in ``SUPPORTED``.
    """
    node_class = type(node)
    tag = _TAG.get(node_class)
    if tag is None:
        raise ValueError(f"unsupported ast node {node_class.__name__}; add it to SUPPORTED")
    encoded_fields = [_encode_field(getattr(node, field_name, None)) for field_name in node_class._fields]
    return _ctor(tag, encoded_fields)
177
+
178
+
179
+ # --- decoding: a Scott value (run in the interpreter) back to a real ast node ------------------
180
+
181
def _handler(tag: int, arity: int, base: int) -> Node:
    """Build the probe handler ``lambda x_1 ... x_arity. marker x_1 ... x_arity`` for constructor ``tag``.

    ``marker`` is the free variable ``base + tag``. Under the ``arity`` binders it is written with index
    ``base + tag + arity``. Applied to a constructor's fields, the handler reduces to a spine headed by
    the marker whose arguments are those fields, in order.
    """
    body: Node = make_var(base + tag + arity)
    for binder in reversed(range(arity)):
        # Index arity - 1 is the outermost binder, i.e. the first field, so it is applied first.
        body = make_app(body, make_var(binder))
    wrapped = 0
    while wrapped < arity:
        body = make_lam(body)
        wrapped += 1
    return body
188
+
189
+
190
+ def _spine(node: Node) -> "tuple[int, list[Node]]":
191
+ arguments: "list[Node]" = []
192
+ current = node
193
+ while True:
194
+ whnf = current.weak_head_normal_form
195
+ match whnf:
196
+ case Var(index=index):
197
+ arguments.reverse()
198
+ return index, arguments
199
+ case App(function=function, argument=argument):
200
+ arguments.append(argument)
201
+ current = function
202
+ case _:
203
+ raise ValueError(f"expected a variable-headed spine, got {whnf!r}")
204
+
205
+
206
def _extract(node: Node, arities: "tuple[int, ...]", base: int) -> "tuple[int, list[Node]]":
    """Probe a Scott value: apply it to one marker handler per constructor and read which one it chose.

    ``arities[tag]`` is the field count of constructor ``tag``. The handler for ``tag`` heads its result
    with the free marker variable ``base + tag``.

    Returns:
        ``(tag, fields)``: the chosen constructor and its field nodes, in order.

    Raises:
        ValueError: if the reduced spine is not headed by one of this probe's markers, or carries a number
            of fields other than that constructor's arity. Either means ``node`` is not a value of the
            probed type. The check matters because callers index with the tag (``SUPPORTED[tag]``), and a
            negative tag would otherwise silently select the wrong class.
    """
    applied = node
    for tag, arity in enumerate(arities):
        applied = make_app(applied, _handler(tag, arity, base))
    head, fields = _spine(applied)
    tag = head - base
    if not 0 <= tag < len(arities):
        raise ValueError(f"probe expected a head marker in [{base}, {base + len(arities)}), got variable {head}")
    if len(fields) != arities[tag]:
        raise ValueError(f"constructor {tag} expects {arities[tag]} fields, got {len(fields)}")
    return tag, fields
212
+
213
+
214
def _church_to_int(node: Node) -> int:
    """Read a Church numeral back as a Python int.

    The numeral is applied to the successor marker and then to the zero marker. Its weak head normal form
    is then ``SUCC (SUCC (... ZERO))``, and the result is the number of applications walked before
    reaching the zero marker. The function position of each application is not checked against the
    successor marker. The walk is a loop, so large numerals do not hit the recursion limit.

    Raises:
        ValueError: if the chain ends at a variable other than the zero marker, or reaches a node that is
            neither a ``Var`` nor an ``App``.
    """
    probe = make_app(make_app(node, make_var(_CHURCH_SUCC)), make_var(_CHURCH_ZERO))
    successors = 0
    while True:
        reduced = probe.weak_head_normal_form
        if isinstance(reduced, Var):
            if reduced.index == _CHURCH_ZERO:
                return successors
            raise ValueError(f"church spine ended at variable {reduced.index}")
        if not isinstance(reduced, App):
            raise ValueError(f"church spine hit {reduced!r}")
        successors += 1
        probe = reduced.argument
229
+
230
+
231
def _decode_scott_list(node: Node) -> "list[Node]":
    """Unroll a Scott list into a Python list of its element nodes (the elements are not decoded).

    Each cell is probed with a cons handler (2 fields: head, tail) and a nil handler (no fields).
    Constructor 1 (nil) ends the walk; anything else is read as a cons cell.
    """
    collected: "list[Node]" = []
    cell = node
    while True:
        constructor, parts = _extract(cell, (2, 0), _LIST_BASE)
        if constructor == 1:
            break
        collected.append(parts[0])
        cell = parts[1]
    return collected
240
+
241
+
242
def _decode_field(node: Node) -> object:
    """Decode one ``<kind, payload>`` field pair (as written by ``_encode_field``) to a host value.

    Dispatches on the Church-encoded kind tag using the same ``_K_*`` constants the encoder writes, so the
    two sides cannot drift apart:

    - ``_K_NODE``: the payload decoded as an ``ast`` node (``decode``)
    - ``_K_LIST``: a list of decoded fields
    - ``_K_INT``: an int; ``_K_STR``: a str rebuilt from character codes; ``_K_BOOL``: a bool
    - ``_K_NONE``: ``None``
    - ``_K_IDENT``: a list of Nats (an AST path) rendered to one underscore-joined identifier
    - ``_K_GENSYM``: a fresh identifier per distinct (interned) payload node, by node identity

    Raises:
        ValueError: for an unknown kind tag (and whatever the payload decoders raise).
    """
    _, fields = _extract(node, (2,), _FIELD_BASE)  # the <kind, payload> pair
    kind = _church_to_int(fields[0])
    payload = fields[1]
    # An ``if`` chain rather than ``match``: a bare module constant in a ``case`` would be a capture
    # pattern, not a value pattern.
    if kind == _K_NODE:
        return decode(payload)
    if kind == _K_LIST:
        return [_decode_field(item) for item in _decode_scott_list(payload)]
    if kind == _K_INT:
        return _church_to_int(payload)
    if kind == _K_STR:
        return "".join(chr(_church_to_int(code)) for code in _decode_scott_list(payload))
    if kind == _K_BOOL:
        return bool(_church_to_int(payload))
    if kind == _K_NONE:
        return None
    if kind == _K_IDENT:
        return _path_to_identifier(payload)
    if kind == _K_GENSYM:
        return _gensym_name(payload)
    raise ValueError(f"unknown field kind {kind}")
265
+
266
+
267
# The one identifier decoder every runtime shares. A variable's identifier is the list of Nats naming its
# AST path, rendered as ``v`` followed by ``_<segment>`` for each segment (decimal integers after an
# alphabetic prefix). Decimal segments never contain ``_``, so distinct paths give distinct names and
# uniqueness holds by construction. The lambda compiler emits only the path (a list of Nats), never the
# rendered string.
def _path_to_identifier(payload: Node) -> str:
    """Render a Scott list of Church numerals (an AST path) as ``v_<n1>_<n2>...``; an empty path is ``v``."""
    numbers = [_church_to_int(item) for item in _decode_scott_list(payload)]
    return "v" + "".join(f"_{number}" for number in numbers)
274
+
275
+
276
def decode(node: Node) -> ast.AST:
    """Decode a Scott-encoded ast value (run in the interpreter) back to a real ``ast`` node.

    The value is probed for its constructor tag. The matching ``SUPPORTED`` class is then built from the
    decoded fields, passed positionally in ``_fields`` order. Under ``memoized_decode``, each distinct
    interned node is decoded once and the resulting ``ast`` object is shared across its occurrences.
    """
    key = id(node)
    if _decode_memo is not None:
        hit = _decode_memo.get(key)
        if hit is not None:
            return hit
    tag, field_nodes = _extract(node, _ARITY, _CTOR_BASE)
    node_class = SUPPORTED[tag]
    arguments = [_decode_field(field_node) for field_node in field_nodes]
    result = node_class(*arguments)
    if _decode_memo is not None:
        _decode_memo[key] = result
    return result
292
+
293
+
294
def to_python_source(node: Node) -> str:
    """Decode a Scott-encoded Python AST, fill in missing source locations, and unparse it to source."""
    tree = decode(node)
    ast.fix_missing_locations(tree)  # mutates ``tree`` in place
    return ast.unparse(tree)
297
+
298
+
299
def _field_payload(field_node: Node) -> Node:
    """Return the payload node (second component) of a ``<kind, payload>`` field, leaving it undecoded."""
    return _extract(field_node, (2,), _FIELD_BASE)[1][1]
303
+
304
+
305
def to_anf_source(node: Node, binding_name: str) -> str:
    """Serialize a Scott-encoded Python AST to A-normal-form source, sharing sub-expressions by identity.

    A generic, graph-preserving serialization (not program-specific). Each distinct ``ast.Call`` node in
    the Scott value becomes ONE assignment to a fresh temporary ``_k<n>``. Temporaries are numbered in the
    order their calls finish emitting, i.e. after their callee and arguments (post-order). Because the
    interpreter hash-conses, a repeated sub-term is the SAME node, so a sub-graph shared across the term is
    emitted once rather than unfolded. A deeply nested reconstruction also stays under CPython's parser
    nesting cap. Each distinct ``Call`` (and hoisted ``FunctionDef``) node is forced once, memoised by node
    identity. Non-``Call`` nodes (``Name``/``Constant``/``Lambda``/...) are inlined whole via ``decode``
    at every occurrence, so they are shared only when this runs under ``memoized_decode``.

    A compiler-emitted helper definition rides along as an ``ast.FunctionDef`` node embedded in the graph,
    in the function position of the ``ast.Call`` that invokes it. Such a node is HOISTED and replaced in
    place by a ``Name`` referencing the def. It is emitted once, deduplicated by node identity exactly like
    the ``_k`` call-sharing, into a preamble of top-level ``def``s ahead of the assignments. This is the
    general facility that lets the compiler emit its own top-level helpers (e.g. the analysis fixpoint)
    instead of importing them from the runtime. The def is a program-independent constant, so it interns
    to one node and is hoisted exactly once.

    Only a ``Call``'s ``func`` and ``args`` fields are read. Each temporary is built with ``keywords=[]``;
    ``ast.keyword`` is not in ``SUPPORTED``, so a well-formed keywords field is empty.

    Args:
        node: the root Scott value.
        binding_name: the name the root expression is finally assigned to.

    Returns:
        The source of a module: the hoisted ``def``s, then the temp assignments, ending
        ``<binding_name> = <root>``. A non-``Call`` root just yields that final bare-expression assignment.

    NOTE(review): hoisted ``def`` names are not checked for collisions with each other, with the
    ``_k<n>`` temporaries, or with ``binding_name``. A ``FunctionDef`` nested inside an inlined callee
    (e.g. the inner call of a curried ``helper(a)(b)``) is decoded whole by ``emit_callee`` rather than
    hoisted. Confirm the compiler never emits either shape.
    """
    # id(Scott node) -> the name now bound to it: a ``_k<n>`` temporary for a Call, or a hoisted def's name.
    memo: "dict[int, str]" = {}
    # Decoded helper ``FunctionDef``s, emitted ahead of every assignment.
    hoisted: "list[ast.stmt]" = []
    # The ``_k<n> = <call>`` assignments, in emission order.
    statements: "list[ast.stmt]" = []
    # Suffix of the next temporary.
    counter = 0

    def hoist_def(scott_node: Node) -> ast.expr:
        """Hoist an ``ast.FunctionDef`` node to the top-level preamble (once, by node identity) and
        reference it by its own name. Decoded whole: a helper def is small and fixed, so its body needs
        no further A-normal-form flattening."""
        key = id(scott_node)
        cached = memo.get(key)
        if cached is not None:
            # Already hoisted: reference the existing def again.
            return ast.Name(id=cached, ctx=ast.Load())
        function_def = decode(scott_node)
        assert isinstance(function_def, ast.FunctionDef), (
            f"hoist_def expected an ast.FunctionDef, got {type(function_def).__name__}"
        )
        hoisted.append(function_def)
        memo[key] = function_def.name
        return ast.Name(id=function_def.name, ctx=ast.Load())

    def emit_callee(callee_node: Node) -> ast.expr:
        """The function position of a Call: a hoisted ``FunctionDef`` becomes a top-level def reference;
        anything else (a ``Name``, a ``Lambda``) is inlined whole, as before -- NOT routed through
        ``emit``, so a curried application ``f(a)(b)`` keeps its nested-call shape rather than ANF-splitting."""
        # Only the constructor tag is needed to decide; the probed fields are discarded.
        tag, _ = _extract(callee_node, _ARITY, _CTOR_BASE)
        if SUPPORTED[tag] is ast.FunctionDef:
            return hoist_def(callee_node)
        inlined = decode(callee_node)
        assert isinstance(inlined, ast.expr), (
            f"emit_callee expected an ast.expr, got {type(inlined).__name__}"
        )
        return inlined

    def emit(scott_node: Node) -> ast.expr:
        """Return an expression standing for ``scott_node``. A Call becomes a reference to its temporary,
        and its assignment is appended on the first visit. A FunctionDef becomes a reference to the hoisted
        def. Anything else is returned as the decoded node itself."""
        nonlocal counter
        key = id(scott_node)
        if key in memo:
            # This Call (or hoisted def) was already emitted: reuse its name rather than re-emitting it.
            return ast.Name(id=memo[key], ctx=ast.Load())
        tag, fields = _extract(scott_node, _ARITY, _CTOR_BASE)
        cls = SUPPORTED[tag]
        if cls is ast.FunctionDef:
            return hoist_def(scott_node)  # a helper def reached in a value position: hoist it
        if cls is not ast.Call:
            inlined = decode(scott_node)  # a leaf or opaque expression (Name/Constant/Lambda/...): inline it
            assert isinstance(inlined, ast.expr), (
                f"emit expected an ast.expr, got {type(inlined).__name__}"
            )
            return inlined
        # ``fields`` follow ``ast.Call._fields``: fields[0] is ``func`` and fields[1] is the ``args`` list.
        # Each list element is itself a <kind, payload> field. fields[2] (``keywords``) is not read.
        function = emit_callee(_field_payload(fields[0]))
        arguments = [emit(_field_payload(argument)) for argument in _decode_scott_list(_field_payload(fields[1]))]
        call = ast.Call(func=function, args=arguments, keywords=[])
        # Named only after its callee and arguments are emitted, so temporaries are numbered in post-order.
        name = f"_k{counter}"
        counter += 1
        statements.append(ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=call))
        memo[key] = name
        return ast.Name(id=name, ctx=ast.Load())

    root = emit(node)
    statements.append(ast.Assign(targets=[ast.Name(id=binding_name, ctx=ast.Store())], value=root))
    # Hoisted defs come first so every temporary that references one sees it already defined.
    return ast.unparse(ast.fix_missing_locations(ast.Module(body=hoisted + statements, type_ignores=[])))
387
+
388
+
389
def roundtrip(source: str, *, mode: str = "eval") -> str:
    """Faithfulness check: parse ``source``, encode it to a Scott value, decode it back, and unparse.

    ``mode`` is passed to ``ast.parse`` (``"eval"`` for an expression, ``"exec"`` for a module).
    """
    tree = ast.parse(source, mode=mode)
    scott_value = build(encode(tree))
    return to_python_source(scott_value)
392
+
393
+
394
# --- BinNat readouts (the decode side of the _binnat encoding) -------------------------------------

# Marker band for probing a Scott boolean (one bit of a BinNat). It lies far above the constructor,
# field, list and Church bands, so when a bit is probed its only free variables are these markers.
_BIT_BASE = 8_500_000


def _bit_value(node: "Node") -> int:
    """Read a Scott boolean bit as 1 or 0.

    The bit is probed with two nullary handlers. Choosing handler 0 (TRUE) reads as 1; any other choice
    reads as 0.
    """
    selected, _fields = _extract(node, (0, 0), _BIT_BASE)
    return int(selected == 0)
405
+
406
+
407
def binnat_to_int(node: "Node") -> int:
    """Decode a BinNat (an LSB-first Scott list of bits) to a non-negative int; an empty list is 0."""
    return sum(_bit_value(bit) << weight for weight, bit in enumerate(_decode_scott_list(node)))
413
+
414
+
415
def binnat_list_to_identifier(node: "Node", prefix: str = "v") -> str:
    """Decode a Scott list of BinNats to ``prefix`` plus ``_<n>`` per segment, e.g. ``v_12_3_567``.

    An empty list yields just ``prefix``.
    """
    parts = [prefix]
    for segment in _decode_scott_list(node):
        parts.append(str(binnat_to_int(segment)))
    return "_".join(parts)