co-lambda 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,188 @@
1
+ """The defunctionalized runtime: the minimal execution substrate for compiled code.
2
+
3
+ Generated code references exactly two free names from this module: ``Thunk`` and ``interned``.
4
+ Everything else is internal implementation (``_BOTTOM``, ``fixpoint_cached_property``) or a
5
+ host import (``dataclass``). The runtime holds NO domain logic; all compilation decisions live
6
+ in the pure-lambda compiler ``_defun_codegen``.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import hashlib
12
+ import struct
13
+ import sys
14
+ import threading
15
+ from collections.abc import Callable, Iterator
16
+ from contextlib import contextmanager
17
+ from dataclasses import dataclass, fields as dataclass_fields
18
+ from enum import Enum, auto
19
+ from typing import Any, Literal, Protocol, TypeGuard, Union, TypeVar, overload
20
+
21
+ from typing_extensions import dataclass_transform
22
+
23
+ from fixpoints._core import fixpoint_cached_property, fixpoint_slotted
24
+
25
+ _T = TypeVar("_T")
26
+
27
+
28
+ class _DefunBottom(Enum):
29
+ BOTTOM = auto()
30
+
31
+
32
+ _BOTTOM = _DefunBottom.BOTTOM
33
+
34
+
35
+ def _intern(cls: type[_T], field_names: tuple[str, ...]) -> type[_T]:
36
+ """Hash-cons ``cls``'s instances by ``(cls, field-values-by-identity)``.
37
+
38
+ Two instances of the same class with identical field values (by ``is``) become the same object.
39
+ Fields are themselves interned closures or ``Thunk`` instances, so identity comparison is O(1)
40
+ structural equality, matching ``_ast._intern_node``. The hash-cons table is exposed as
41
+ ``__intern_pool__`` for introspection (e.g. counting tabled objects in a benchmark); it is the SAME
42
+ table the interner already keeps, so surfacing it adds no behaviour.
43
+
44
+ The key is computed directly from the positional constructor arguments (which correspond 1:1 to
45
+ ``field_names`` for both ``@dataclass`` classes and ``Thunk``), so a cache hit avoids allocating
46
+ a throwaway instance entirely.
47
+ """
48
+ pool: dict[tuple, object] = {}
49
+ original_init = cls.__init__
50
+
51
+ def __new__(klass, *args):
52
+ key = (klass,) + tuple(id(a) for a in args)
53
+ existing = pool.get(key)
54
+ if existing is not None:
55
+ return existing
56
+ instance = object.__new__(klass)
57
+ original_init(instance, *args)
58
+ pool[key] = instance
59
+ return instance
60
+
61
+ cls_any: Any = cls
62
+ cls_any.__new__ = __new__
63
+ cls_any.__init__ = lambda self, *args, **kwargs: None
64
+ cls_any.__intern_pool__ = pool
65
+ return cls
66
+
67
+
68
+ @overload
69
+ def interned(cls: type[_T], *, slots: bool = ...) -> type[_T]: ...
70
+
71
+
72
+ @overload
73
+ def interned(
74
+ cls: None = ..., *, slots: bool = ...
75
+ ) -> Callable[[type[_T]], type[_T]]: ...
76
+
77
+
78
+ @dataclass_transform(eq_default=False)
79
+ def interned(cls=None, *, slots=True):
80
+ """Class decorator: make ``cls`` a frozen-by-identity dataclass and hash-cons its instances.
81
+
82
+ Applies ``dataclass(eq=False, slots=slots)`` internally (so generated code needs only
83
+ ``@interned``, not a separate ``@dataclass``), then interns. ``slots=True`` (the default) makes the
84
+ closures the compiler emits slotted, which is faster and lighter; ``eq=False`` keeps identity-based
85
+ equality. Usable bare (``@interned``) or parameterised (``@interned(slots=False)``).
86
+ """
87
+ if cls is None:
88
+ return lambda klass: interned(klass, slots=slots)
89
+ cls = dataclass(eq=False, slots=slots)(cls)
90
+ field_names = tuple(f.name for f in dataclass_fields(cls))
91
+ return _intern(cls, field_names)
92
+
93
+
94
+ def _deterministic_hash(*parts: int) -> int:
95
+ """A deterministic hash from a sequence of integers, independent of ``PYTHONHASHSEED``."""
96
+ data = struct.pack(f">{len(parts)}q", *parts)
97
+ return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")
98
+
99
+
100
+ @fixpoint_slotted
101
+ class Thunk:
102
+ """A suspended application (redex). Interned so structurally equal redexes share identity,
103
+ enabling tabling: ``weak_head_normal_form`` is computed once per distinct ``Thunk``.
104
+
105
+ Slotted for speed and low memory; ``@fixpoint_slotted`` automatically adds a dedicated cache
106
+ slot for each ``fixpoint_cached_property``, avoiding a ``__dict__`` or intermediate dict.
107
+ Identity-based equality (``object.__eq__``).
108
+ """
109
+
110
+ __slots__ = ("callee", "argument")
111
+
112
+ def __init__(self, callee: Lambda, argument: Lambda) -> None:
113
+ self.callee = callee
114
+ self.argument = argument
115
+
116
+ def __call__(self, a: Lambda) -> Thunk:
117
+ return Thunk(self, a)
118
+
119
+ @fixpoint_cached_property(bottom=lambda: _BOTTOM)
120
+ def weak_head_normal_form(self) -> Lambda | Literal[_DefunBottom.BOTTOM]:
121
+ callee = self.callee
122
+ if _is_thunk(callee):
123
+ callee = callee.weak_head_normal_form
124
+ if callee is _BOTTOM:
125
+ return _BOTTOM
126
+ result = callee(self.argument)
127
+ return result.weak_head_normal_form if _is_thunk(result) else result
128
+
129
+
130
+ Thunk = _intern(Thunk, ("callee", "argument"))
131
+
132
+
133
+ def _is_thunk(x: object) -> TypeGuard[Thunk]:
134
+ return isinstance(x, Thunk)
135
+
136
+
137
+ class Lambda(Protocol):
138
+ """A lambda value: any callable that takes a Lambda and returns a Lambda or Thunk."""
139
+
140
+ def __call__(self, a: Lambda) -> Union["Lambda", "Thunk"]: ...
141
+
142
+
143
+ # --- stack helpers ---------------------------------------------------------------------------------
144
+
145
+ _COMPILE_RECURSION_LIMIT = 16_000
146
+ _RECURSION_LIMIT = 200_000
147
+ _STACK_SIZE = 1024 * 1024 * 1024 # 1 GiB
148
+
149
+
150
+ @contextmanager
151
+ def recursion_headroom() -> Iterator[None]:
152
+ previous = sys.getrecursionlimit()
153
+ sys.setrecursionlimit(max(previous, _COMPILE_RECURSION_LIMIT))
154
+ try:
155
+ yield
156
+ finally:
157
+ sys.setrecursionlimit(previous)
158
+
159
+
160
+ def _python_tag() -> str:
161
+ """A Python-version tag for generated-module filenames, e.g. ``py313``. Defunctionalized modules
162
+ are rendered with ``ast.unparse``, whose formatting can differ between Python versions, so a module
163
+ generated under one interpreter must not be reused under another; the tag keeps artifacts distinct.
164
+ """
165
+ return f"py{sys.version_info.major}{sys.version_info.minor}"
166
+
167
+
168
+ def run_in_large_stack(thunk):
169
+ """Run ``thunk`` in a thread with a 1 GiB C stack and a high recursion limit."""
170
+ result: list = []
171
+
172
+ def run() -> None:
173
+ previous_limit = sys.getrecursionlimit()
174
+ sys.setrecursionlimit(max(previous_limit, _RECURSION_LIMIT))
175
+ try:
176
+ result.append(thunk())
177
+ finally:
178
+ sys.setrecursionlimit(previous_limit)
179
+
180
+ previous_stack_size = threading.stack_size(_STACK_SIZE)
181
+ try:
182
+ worker = threading.Thread(target=run)
183
+ worker.start()
184
+ worker.join()
185
+ finally:
186
+ threading.stack_size(previous_stack_size)
187
+ (single_result,) = result
188
+ return single_result
@@ -0,0 +1,470 @@
1
+ """The defunctionalization boundary: quote, compile, decode, canonicalize, load.
2
+
3
+ Thin Python layer analogous to ``_specialize`` but for the defunctionalization target. The lambda
4
+ compiler ``DEFUN`` produces the Scott-encoded ``ast.Module``; this module quotes the input, runs
5
+ the compiler in the interpreter, decodes the Scott AST to a real ``ast.Module``, deduplicates
6
+ class definitions by node identity (``memoized_decode``), renames every class by the Merkle hash of
7
+ its COMPILED body (``_canonicalize_classes``), and unparses to source. ``load`` execs the source
8
+ with the runtime globals (``Thunk``, ``interned``) and returns the ``compiled``
9
+ value.
10
+
11
+ Content addressing happens on the compiled dataclass, not the source lambda term. Two source
12
+ closures of the same shape that capture variables at different de Bruijn depths compile to the same
13
+ dataclass (same arity, byte-identical ``__call__`` body over positional capture fields), so the
14
+ boundary collapses them to one class. This is coarser than the source's term equality and makes the
15
+ generated code smaller and more reusable.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import ast
21
+ import contextlib
22
+ import copy
23
+ import hashlib
24
+ from typing import TypeVar
25
+
26
+ from co_lambda._ast import Node
27
+ from co_lambda._codec import quote_binnat
28
+ from co_lambda._defun_codegen import DEFUN
29
+ from co_lambda._defun_runtime import Thunk, _BOTTOM, _is_thunk, interned, run_in_large_stack
30
+
31
# Type variable for ``_rename_copy``: the renamed copy has the same AST node type as its input.
_AstNode = TypeVar("_AstNode", bound=ast.AST)
32
+ from co_lambda._dsl import app, build
33
+ from co_lambda._pyast import SUPPORTED, _ARITY, _reset_gensym, decode, memoized_decode
34
+
35
+
36
+ class _RenameClasses(ast.NodeTransformer):
37
+ """Rewrite ``ast.Name`` references to class names according to ``mapping``."""
38
+
39
+ def __init__(self, mapping: "dict[str, str]") -> None:
40
+ self._mapping = mapping
41
+
42
+ def visit_Name(self, node: ast.Name) -> ast.Name:
43
+ renamed = self._mapping.get(node.id)
44
+ if renamed is not None:
45
+ return ast.copy_location(ast.Name(id=renamed, ctx=node.ctx), node)
46
+ return node
47
+
48
+
49
def _rename_copy(node: _AstNode, mapping: "dict[str, str]") -> _AstNode:
    """Return a deep copy of ``node`` with class-name references rewritten per ``mapping``.

    ``node`` itself is never modified.

    Raises:
        AssertionError: if the rename changed the root node's type (assertions enabled).
    """
    duplicate = copy.deepcopy(node)
    result = _RenameClasses(mapping).visit(duplicate)
    assert isinstance(result, type(node)), (
        f"_RenameClasses must preserve node type {type(node).__name__}, got {type(result).__name__}"
    )
    return result
56
+
57
+
58
+ class _RenameFields(ast.NodeTransformer):
59
+ """Rewrite a class's capture-field names (AnnAssign targets and ``self.<field>`` accesses)."""
60
+
61
+ def __init__(self, mapping: "dict[str, str]") -> None:
62
+ self._mapping = mapping
63
+
64
+ def visit_Name(self, node: ast.Name) -> ast.Name:
65
+ renamed = self._mapping.get(node.id)
66
+ if renamed is not None:
67
+ return ast.copy_location(ast.Name(id=renamed, ctx=node.ctx), node)
68
+ return node
69
+
70
+ def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
71
+ self.generic_visit(node)
72
+ renamed = self._mapping.get(node.attr)
73
+ if renamed is not None:
74
+ node.attr = renamed
75
+ return node
76
+
77
+
78
def _canonicalize_fields(classdef: ast.ClassDef) -> ast.ClassDef:
    """Return a copy of ``classdef`` whose capture fields are renamed ``cap_0``, ``cap_1``, ...

    Fields are the body's ``AnnAssign`` statements with a plain ``Name`` target. They are numbered by
    their position in definition order. The input ``classdef`` is not modified.
    """
    captured = [
        statement.target.id
        for statement in classdef.body
        if isinstance(statement, ast.AnnAssign) and isinstance(statement.target, ast.Name)
    ]
    positional: dict[str, str] = {}
    for index, field_name in enumerate(captured):
        positional[field_name] = f"cap_{index}"
    result = _RenameFields(positional).visit(copy.deepcopy(classdef))
    assert isinstance(result, ast.ClassDef)
    return result
89
+
90
+
91
def _canonicalize_classes(module: ast.Module) -> ast.Module:
    """Rename every closure class by a content hash of its COMPILED dataclass and drop duplicates.

    Steps:
    1. Capture fields are renamed positionally (``cap_0``, ``cap_1``, ...).
    2. Each class's content hash is the Merkle hash of its field-canonicalized body, with two
       substitutions:
       * references to other classes are replaced by THEIR content hashes, computed bottom-up over
         the class-reference graph;
       * references to the class's own name are replaced by the fixed placeholder ``_SELF_``.
    3. Two classes with identical compiled bodies hash equal and collapse to one.

    Definitions are emitted sorted by name, followed by the module's non-class statements, with
    class references renamed. The output is therefore stable under local source edits and identical
    between the in-process and self-hosted compilers.

    Direct self-references are supported. Any longer reference cycle trips an assertion, because the
    bottom-up hash is undefined for it.

    Args:
        module: decoded module; its ``body`` is replaced in place.

    Returns:
        ``module`` itself, with the rewritten body.
    """
    classdefs: "dict[str, ast.ClassDef]" = {}
    others: "list[ast.stmt]" = []
    for statement in module.body:
        if not isinstance(statement, ast.ClassDef):
            others.append(statement)
            continue
        field_canonical = _canonicalize_fields(statement)
        kept = classdefs.get(field_canonical.name)
        if kept is not None:
            # A repeated provisional name must carry an identical definition; keep the first.
            assert ast.dump(kept) == ast.dump(field_canonical), (
                f"provisional class {field_canonical.name!r} has two non-identical definitions"
            )
            continue
        classdefs[field_canonical.name] = field_canonical
    provisional = set(classdefs)

    def referenced(classdef: ast.ClassDef) -> "set[str]":
        """Provisional class names mentioned (as ``ast.Name``) anywhere inside ``classdef``."""
        return {n.id for n in ast.walk(classdef) if isinstance(n, ast.Name) and n.id in provisional}

    canonical: "dict[str, str]" = {}
    in_progress: "set[str]" = set()

    def canonical_name(name: str) -> str:
        """Content-hash name for provisional class ``name``; dependencies are hashed first."""
        cached = canonical.get(name)
        if cached is not None:
            return cached
        assert name not in in_progress, f"class reference cycle through {name!r}"
        in_progress.add(name)
        classdef = classdefs[name]
        # A class's references to ITSELF are not dependencies: they become the fixed ``_SELF_``
        # placeholder below. Recursing on them would trip the cycle assertion above.
        mapping = {
            reference: canonical_name(reference)
            for reference in referenced(classdef)
            if reference != name
        }
        mapping[name] = "_SELF_"
        key_node = _rename_copy(classdef, mapping)
        assert isinstance(key_node, ast.ClassDef)
        key_node.name = "_SELF_"
        digest = hashlib.sha256(ast.dump(key_node).encode()).digest()[:8]
        result = "vg_" + digest.hex()
        in_progress.discard(name)
        canonical[name] = result
        return result

    for name in classdefs:
        canonical_name(name)

    global_mapping = {name: canonical[name] for name in provisional}
    deduped: "dict[str, ast.ClassDef]" = {}
    for name, classdef in classdefs.items():
        renamed = _rename_copy(classdef, global_mapping)
        assert isinstance(renamed, ast.ClassDef)
        renamed.name = canonical[name]
        # Classes with equal content hashes are identical after renaming; one copy is kept.
        deduped[renamed.name] = renamed

    sorted_defs: "list[ast.stmt]" = [deduped[key] for key in sorted(deduped)]
    new_others = [_rename_copy(statement, global_mapping) for statement in others]
    module.body = sorted_defs + new_others
    return module
156
+
157
+
158
+ # --- Direct defun decoder: decode Scott-encoded AST from defunctionalized values -----------------
159
+ # Mirrors ``_pyast.decode`` but operates on defun values (Thunk + closures) instead of interpreter
160
+ # Nodes, eliminating the expensive ``reify`` NbE round-trip in the self-hosted compilation path.
161
+
162
+
163
+ class _TagMarker:
164
+ """Callable marker for Scott constructor extraction. When the Scott value selects this handler,
165
+ it calls ``__call__`` once per field, accumulating the field values."""
166
+
167
+ __slots__ = ("tag", "fields")
168
+
169
+ def __init__(self, tag: int) -> None:
170
+ self.tag = tag
171
+ self.fields: list[object] = []
172
+
173
+ def __call__(self, argument: object) -> "_TagMarker":
174
+ self.fields.append(argument)
175
+ return self
176
+
177
+
178
+ class _ChurchApp:
179
+ """Marker node in a Church numeral spine: successor applied to predecessor."""
180
+
181
+ __slots__ = ("argument",)
182
+
183
+ def __init__(self, argument: object) -> None:
184
+ self.argument = argument
185
+
186
+
187
+ class _ChurchSucc:
188
+ """Callable marker for Church numeral successor."""
189
+
190
+ __slots__ = ()
191
+
192
+ def __call__(self, argument: object) -> _ChurchApp:
193
+ return _ChurchApp(argument)
194
+
195
+
196
+ _CHURCH_SUCC_DEFUN = _ChurchSucc()
197
+ _CHURCH_ZERO_DEFUN = object()
198
+
199
+ _church_int_cache: "dict[int, int]" = {}
200
+ _defun_gensym_ids: "dict[int, str]" = {}
201
+ _defun_gensym_counter: int = 0
202
+ _defun_decode_memo: "dict[int, ast.AST] | None" = None
203
+
204
+
205
def _reset_defun_gensym() -> None:
    """Clear the defun decoder's per-run state: Church-int cache, gensym table and gensym counter.

    The cache dicts are cleared in place (not rebound), so existing references to them stay valid.
    """
    global _defun_gensym_counter
    _defun_gensym_counter = 0
    _defun_gensym_ids.clear()
    _church_int_cache.clear()
210
+
211
+
212
@contextlib.contextmanager
def _memoized_decode_defun():
    """Enable identity-keyed memoization in ``decode_defun`` for the duration of the block.

    A fresh memo dict is installed on entry and removed (set back to ``None``) on exit, even on error.

    Raises:
        AssertionError: if already active; the context does not nest (assertions enabled).
    """
    global _defun_decode_memo
    assert _defun_decode_memo is None, "memoized decode_defun does not nest"
    _defun_decode_memo = dict()
    try:
        yield
    finally:
        _defun_decode_memo = None
221
+
222
+
223
def _force_defun(value: object) -> object:
    """Force a ``Thunk`` to weak head normal form; return any other value unchanged.

    Raises:
        AssertionError: if the thunk evaluates to bottom (assertions enabled).
    """
    if not _is_thunk(value):
        return value
    forced = value.weak_head_normal_form
    assert forced is not _BOTTOM, "hit bottom while forcing defun value"
    return forced
229
+
230
+
231
def _extract_defun(value: object, arities: "tuple[int, ...]") -> "tuple[int, list[object]]":
    """Deconstruct a Scott-encoded constructor value into ``(tag, fields)``.

    A Scott value over ``k`` constructors takes ``k`` handlers. Here handler ``i`` is a fresh
    ``_TagMarker(i)``; the handler the value selects collects that constructor's fields. Only
    ``len(arities)`` is used; the per-constructor arities themselves are not checked.

    Raises:
        AssertionError: if the value does not behave like a Scott constructor (assertions enabled).
    """
    scrutinee = _force_defun(value)
    for tag, _ in enumerate(arities):
        assert callable(scrutinee), f"expected callable during extraction, got {type(scrutinee).__name__}"
        scrutinee = _force_defun(scrutinee(_TagMarker(tag)))
    assert isinstance(scrutinee, _TagMarker), (
        f"expected _TagMarker after extraction, got {type(scrutinee).__name__}"
    )
    return scrutinee.tag, scrutinee.fields
241
+
242
+
243
def _church_to_int_defun(value: object) -> int:
    """Decode a defunctionalized Church numeral into a Python ``int``.

    The numeral is applied to ``_CHURCH_SUCC_DEFUN`` and then to ``_CHURCH_ZERO_DEFUN``. The result is
    a chain of ``_ChurchApp`` markers, and the length of that chain is returned. Results are memoized
    in ``_church_int_cache`` by ``id(value)``.

    Raises:
        AssertionError: if the value is not shaped like a Church numeral (assertions enabled).
    """
    key = id(value)
    memoized = _church_int_cache.get(key)
    if memoized is not None:
        return memoized
    spine = _force_defun(value)
    for probe, role in (
        (_CHURCH_SUCC_DEFUN, "church spine head"),
        (_CHURCH_ZERO_DEFUN, "church spine successor result"),
    ):
        assert callable(spine), f"{role} must be callable, got {type(spine).__name__}"
        spine = _force_defun(spine(probe))
    length = 0
    while isinstance(spine, _ChurchApp):
        length += 1
        spine = _force_defun(spine.argument)
    assert spine is _CHURCH_ZERO_DEFUN, "church spine did not end at zero marker"
    _church_int_cache[key] = length
    return length
262
+
263
+
264
def _decode_scott_list_defun(value: object) -> "list[object]":
    """Decode a Scott-encoded list (cons = tag 0 with head/tail, nil = tag 1) into a Python list.

    Elements are returned as-is (still defun values), in list order.
    """
    collected: "list[object]" = []
    tag, fields = _extract_defun(value, (2, 0))
    while tag != 1:
        assert tag == 0, f"expected cons (0) or nil (1), got {tag}"
        collected.append(fields[0])
        tag, fields = _extract_defun(fields[1], (2, 0))
    return collected
+
275
+
276
def _gensym_name_defun(payload: object) -> str:
    """Return a stable generated name (``vg_`` + 16 hex digits) for ``payload``, keyed by its ``id``.

    The first request for a payload allocates the next counter value; later requests for the same
    object return the same name.
    """
    global _defun_gensym_counter
    key = id(payload)
    name = _defun_gensym_ids.get(key)
    if name is None:
        name = f"vg_{_defun_gensym_counter:016x}"
        _defun_gensym_ids[key] = name
        _defun_gensym_counter += 1
    return name
286
+
287
+
288
def _decode_field_defun(value: object) -> object:
    """Decode one tagged AST field ``(kind, payload)`` into its Python value.

    Supported kinds are:
    * 0: a nested AST node (via ``decode_defun``);
    * 1: a list of fields;
    * 2: an int;
    * 3: a str (a list of code points);
    * 5: ``None``;
    * 7: a generated name.

    Raises:
        ValueError: for any other kind.
    """
    _, (kind_value, payload) = _extract_defun(value, (2,))
    kind = _church_to_int_defun(kind_value)
    if kind == 0:
        return decode_defun(payload)
    if kind == 1:
        return [_decode_field_defun(item) for item in _decode_scott_list_defun(payload)]
    if kind == 2:
        return _church_to_int_defun(payload)
    if kind == 3:
        code_points = _decode_scott_list_defun(payload)
        return "".join(chr(_church_to_int_defun(code)) for code in code_points)
    if kind == 5:
        return None
    if kind == 7:
        return _gensym_name_defun(payload)
    raise ValueError(f"defun decode: unsupported field kind {kind}")
307
+
308
+
309
def decode_defun(value: object) -> ast.AST:
    """Decode a Scott-encoded AST directly from defunctionalized values (Thunks + closures).

    This skips the ``reify`` NbE round-trip that would first convert defun values to interpreter
    Nodes. Under ``_memoized_decode_defun``, each distinct value (by identity) is decoded once, so
    shared sub-values yield the same ``ast`` node object.
    """
    memo_key = id(value)
    if _defun_decode_memo is not None:
        hit = _defun_decode_memo.get(memo_key)
        if hit is not None:
            return hit
    tag, raw_fields = _extract_defun(value, _ARITY)
    node_class = SUPPORTED[tag]
    decoded_fields = [_decode_field_defun(raw) for raw in raw_fields]
    decoded = node_class(*decoded_fields)
    if _defun_decode_memo is not None:
        _defun_decode_memo[memo_key] = decoded
    return decoded
325
+
326
+
327
def defunctionalize(node: Node) -> str:
    """Compile a lambda term to defunctionalized Python source (a module of closure classes).

    The work runs in a large-stack thread. The interpreter's substitution recursion can be as deep as
    the term, which overflows the C stack on Python 3.12+ (Python 3.12+ caps C recursion regardless of
    ``setrecursionlimit``). ``run_in_large_stack`` provides a 1 GiB stack and a high recursion limit.
    """

    def work() -> str:
        quoted = quote_binnat(node)
        scott_module = build(app(DEFUN, quoted))
        _reset_gensym()
        with memoized_decode():
            decoded = decode(scott_module)
        assert isinstance(decoded, ast.Module)
        canonical_module = _canonicalize_classes(decoded)
        return ast.unparse(ast.fix_missing_locations(canonical_module))

    return run_in_large_stack(work)
344
+
345
+
346
def _defun_globals() -> dict:
    """Return a fresh globals dict for exec-ing generated code, binding ``Thunk`` and ``interned``."""
    return dict(Thunk=Thunk, interned=interned)
351
+
352
+
353
def load_namespace(source: str) -> dict:
    """Execute defunctionalized source and return the whole module namespace.

    The namespace holds every generated closure class (each carrying its ``__intern_pool__``) and the
    ``compiled`` value. A caller can therefore both run the program and inspect its tabled objects.
    """
    namespace = _defun_globals()
    code_object = compile(source, "<defun>", "exec")
    # SECURITY: ``source`` is executed as Python; only pass trusted, compiler-generated source.
    exec(code_object, namespace)  # noqa: S102
    return namespace
362
+
363
+
364
def load(source: str) -> object:
    """Execute defunctionalized source and return the value it binds to ``compiled``.

    Raises:
        KeyError: if the source does not define ``compiled``.
    """
    namespace = load_namespace(source)
    return namespace["compiled"]
367
+
368
+
369
def defunctionalize_and_load(node: Node) -> object:
    """Compile a lambda term to defunctionalized source, execute it, and return ``compiled``."""
    source = defunctionalize(node)
    return load(source)
372
+
373
+
374
# Import header that makes a generated defunctionalized module runnable on its own. It imports the
# runtime names the generated code references (``Thunk``, ``interned``) and, additionally,
# ``Lambda``. ``interned`` applies ``dataclass(eq=False)`` itself, so generated classes carry only
# ``@interned``.
_DEFUN_MODULE_HEADER = "".join(
    (
        "# Generated, self-contained module: the import header is added at serialization time (see\n",
        "# co_lambda._defunctionalize.runnable_defun_module); the body is emitted by the DEFUN lambda\n",
        "# term and content-addressed by compiled dataclass shape.\n",
        "from co_lambda._defun_runtime import Lambda, Thunk, interned\n",
    )
)
383
+
384
+
385
def runnable_defun_module(source: str) -> str:
    """Return ``source`` preceded by the runtime import header and a blank line.

    The result is a defunctionalized module that can be imported on its own.
    """
    return "\n".join((_DEFUN_MODULE_HEADER, source))
388
+
389
+
390
def defun_compiler_source() -> str:
    """Self-compile the defunctionalization compiler ``DEFUN`` into a runnable dataclass module.

    The result is the dataclass-form "compiled compiler": ``DEFUN`` defunctionalized by itself.
    Importing it binds ``compiled`` to the defunctionalized ``DEFUN`` value. Applying that value
    (through a ``Thunk``) to a quoted program yields the program's compiled Scott ``ast.Module`` as a
    defunctionalized value.
    """
    # NOTE(review): ``DEFUN`` is also imported at module level; this local import looks redundant.
    from co_lambda._defun_codegen import DEFUN

    compiler_term = build(DEFUN)
    return runnable_defun_module(defunctionalize(compiler_term))
401
+
402
+
403
def compile_with_defun(engine: object, node: Node) -> str:
    """Compile ``node`` by RUNNING a defunctionalized ``DEFUN`` engine (the dataclass compiled compiler).

    ``engine`` is the ``compiled`` value of a ``defun_compiler_source`` module. ``node`` is quoted and
    itself defunctionalized, so the engine receives a defunctionalized Scott source value.

    The engine's output, a defunctionalized Scott ``ast.Module``, is then processed in three steps:
    * decoded directly by ``decode_defun`` (no ``reify`` round-trip);
    * canonicalized;
    * unparsed.

    The result is meant to match what the in-process ``defunctionalize`` produces, by self-hosting.

    Raises:
        ValueError: if applying the engine evaluates to bottom.
    """
    source_value = defunctionalize_and_load(build(quote_binnat(node)))

    def work() -> str:
        module_value = Thunk(engine, source_value).weak_head_normal_form
        if module_value is _BOTTOM:
            raise ValueError("the defunctionalized compiler did not produce a module")
        _reset_defun_gensym()
        with _memoized_decode_defun():
            decoded = decode_defun(module_value)
        assert isinstance(decoded, ast.Module)
        canonical_module = _canonicalize_classes(decoded)
        return ast.unparse(ast.fix_missing_locations(canonical_module))

    return run_in_large_stack(work)
425
+
426
+
427
def reify(value: object, depth: int = 0) -> Node:
    """Read a defunctionalized value back to an interpreter ``Node``.

    Forces ``Thunk.weak_head_normal_form`` to reach a closure (a defunctionalized dataclass with
    ``__call__``) or a neutral term (``Node``). Closures are then probed under a fresh neutral binder
    to read their body.

    Args:
        value: a ``Thunk``, an interpreter ``Node``, or a callable closure.
        depth: number of binders already entered; the next probe variable is ``make_var(depth)``.

    Returns:
        The read-back ``Node``.

    Raises:
        ValueError: if forcing reaches bottom, or the value is none of the readable kinds.
    """
    from co_lambda._ast import Node as AstNode, make_lam, make_var

    if _is_thunk(value):
        whnf = value.weak_head_normal_form
        if whnf is _BOTTOM:
            raise ValueError("reify: hit bottom (unproductive cycle)")
        return reify(whnf, depth)

    if isinstance(value, AstNode):
        return _reify_node(value, depth)

    if callable(value):
        # Apply the closure to a neutral variable at the current level, then read back its body
        # under one more binder.
        probe = make_var(depth)
        result = value(probe)
        return make_lam(reify(result, depth + 1))

    raise ValueError(f"reify: cannot read back {value!r}")
451
+
452
+
453
+ def _reify_node(node: Node, depth: int) -> Node:
454
+ """Read back an interpreter Node that appears as a neutral term in defunctionalized output.
455
+
456
+ Probe variables are created as ``make_var(level)`` where ``level`` is the depth at probe time.
457
+ When quoting back, a variable at level ``l`` under ``depth`` binders has de Bruijn index
458
+ ``depth - l - 1``. Sub-terms of neutral applications may be defunctionalized values (when a
459
+ closure was probed with a neutral variable); these are handed back to ``reify``.
460
+ """
461
+ from co_lambda._ast import App, Var, make_app, make_lam, make_var
462
+
463
+ whnf = node.weak_head_normal_form
464
+ match whnf:
465
+ case Var(index=level):
466
+ return make_var(depth - level - 1)
467
+ case App(function=function, argument=argument):
468
+ return make_app(reify(function, depth), reify(argument, depth))
469
+ case _:
470
+ raise ValueError(f"reify: unexpected weak head normal form in neutral term: {whnf!r}")