curryparty 0.3.1__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: curryparty
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: Python playground to learn lambda calculus
5
5
  Author: Antonin P
6
6
  Author-email: Antonin P <antonin.peronnet@telecom-paris.fr>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "curryparty"
3
- version = "0.3.1"
3
+ version = "0.3.3"
4
4
  description = "Python playground to learn lambda calculus"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -1,3 +1,8 @@
1
+ """The `curryparty` library, a playground to learn lambda-calculus.
2
+
3
+ This library is intended to be used in an interactive environment such as IPython or Jupyter, which display terms through their `_repr_html_` method.
4
+ """
5
+
1
6
  from typing import Iterable, List, Optional, Union
2
7
 
3
8
  try:
@@ -8,9 +13,9 @@ except ImportError:
8
13
  )
9
14
  import uuid
10
15
 
11
- from svg import SVG, Rect
16
+ from svg import SVG, Length, Rect, ViewBoxSpec
12
17
 
13
- from .core import SCHEMA, beta_reduce, compose, find_redexes, find_variables, subtree
18
+ from .core import AbstractTerm
14
19
  from .display import (
15
20
  compute_height,
16
21
  compute_svg_frame_final,
@@ -33,24 +38,18 @@ def log2(n):
33
38
 
34
39
 
35
40
  class Term:
36
- def __init__(self, nodes: pl.DataFrame):
37
- assert nodes.schema == SCHEMA, (
38
- f"{nodes.schema} is different from expected {SCHEMA}"
39
- )
40
- self.nodes = nodes
41
- self.lamb = None
41
+ def __init__(self, data: AbstractTerm):
42
+ self.data = data
42
43
 
43
44
  def __call__(self, other: "Term") -> "Term":
44
- return Term(compose(self.nodes, other.nodes))
45
+ return Term((self.data)(other.data))
45
46
 
46
47
  def beta(self) -> Optional["Term"]:
47
- candidates = find_redexes(self.nodes).first().collect()
48
- if len(candidates) == 0:
48
+ candidates = self.data.find_redexes()
49
+ redex = next(candidates, None)
50
+ if redex is None:
49
51
  return None
50
- _redex, lamb, b = candidates.row(0)
51
- self.lamb = lamb
52
- self.b = b
53
- reduced = beta_reduce(self.nodes, lamb, b)
52
+ reduced = self.data.beta_reduce(redex)
54
53
  return Term(reduced)
55
54
 
56
55
  def reduce(self):
@@ -61,41 +60,44 @@ class Term:
61
60
 
62
61
  def reduction_chain(self) -> Iterable["Term"]:
63
62
  term = self
64
- while True:
63
+ while term is not None:
65
64
  yield term
66
65
  term = term.beta()
67
- if term is None:
68
- break
69
66
 
70
67
  def show_beta(self, duration=7):
71
68
  """
72
69
  Generates an HTML representation that toggles visibility between
73
70
  a static state and a SMIL animation on hover using pure CSS.
74
71
  """
75
- candidates = find_redexes(self.nodes).first().collect()
76
- if len(candidates) == 0:
72
+
73
+ candidates = self.data.find_redexes()
74
+ redex = next(candidates, None)
75
+ if redex is None:
77
76
  return self._repr_html_()
78
77
 
79
- _redex, lamb, b = candidates.row(0)
80
- new_nodes = beta_reduce(self.nodes, lamb, b)
81
- vars = find_variables(self.nodes, lamb).collect()["id"]
82
- b_subtree = subtree(self.nodes, b).collect()
83
- height = min(compute_height(self.nodes), compute_height(new_nodes)) * 2
84
- if count_variables(self.nodes) == 0:
78
+ lamb = self.data.node(redex).get_left()
79
+ assert lamb is not None
80
+ b = self.data.node(redex).get_arg()
81
+ assert b is not None
82
+ new_nodes = self.data.beta_reduce(redex)
83
+ vars = list(self.data.find_variables(lamb))
84
+ b_subtree = list(self.data.get_subtree(b))
85
+ height = min(compute_height(self.data), compute_height(new_nodes)) * 2
86
+ if count_variables(self.data) == 0:
85
87
  return "no width"
86
- raw_width = max(count_variables(self.nodes), count_variables(new_nodes))
88
+ raw_width = max(count_variables(self.data), count_variables(new_nodes))
87
89
  width = 1 << (1 + log2(raw_width))
88
90
  frame_data: list[ShapeAnimFrame] = []
89
91
  N_STEPS = 6
90
92
 
91
93
  for t in range(N_STEPS):
92
94
  if t == 0:
93
- items = compute_svg_frame_init(self.nodes, t)
95
+ items = compute_svg_frame_init(self.data, t)
94
96
  elif t == 1 or t == 2:
95
- items = compute_svg_frame_phase_a(self.nodes, lamb, b_subtree, vars, t)
97
+ items = compute_svg_frame_phase_a(self.data, redex, b_subtree, vars, t)
96
98
  elif t == 3 or t == 4:
97
99
  items = compute_svg_frame_phase_b(
98
- self.nodes, lamb, b_subtree, new_nodes, t
100
+ self.data, redex, b_subtree, new_nodes, t
99
101
  )
100
102
  else:
101
103
  items = compute_svg_frame_final(new_nodes, t)
@@ -114,10 +116,10 @@ class Term:
114
116
  anim_elements.append(
115
117
  Rect(
116
118
  id=box_id,
117
- x=f"{-width}",
118
- y="0",
119
- width="100%",
120
- height="100%",
119
+ x=-width,
120
+ y=0,
121
+ width=Length(100, "%"),
122
+ height=Length(100, "%"),
121
123
  fill="transparent",
122
124
  )
123
125
  )
@@ -126,7 +128,7 @@ class Term:
126
128
  H = height * 40
127
129
  anim_svg = SVG(
128
130
  xmlns="http://www.w3.org/2000/svg",
129
- viewBox=f"{-width} 0 {2 * width} {height}",
131
+ viewBox=ViewBoxSpec(-width, 0, 2 * width, height),
130
132
  style=f"max-height:{H}px",
131
133
  elements=anim_elements,
132
134
  ).as_str()
@@ -140,11 +142,11 @@ class Term:
140
142
  "</div>"
141
143
  )
142
144
 
143
- def _repr_html_(self, x0=-10):
144
- frame = sorted(compute_svg_frame_init(self.nodes), key=lambda x: x.zindex)
145
+ def _repr_html_(self):
146
+ frame = sorted(compute_svg_frame_init(self.data), key=lambda x: x.zindex)
145
147
 
146
- width = (1 << (1 + log2(count_variables(self.nodes)))) + 4
147
- height = compute_height(self.nodes) + 1
148
+ width = (1 << (1 + log2(count_variables(self.data)))) + 4
149
+ height = compute_height(self.data) + 1
148
150
 
149
151
  elements = [ShapeAnim.from_single_frame(x) for x in frame]
150
152
 
@@ -152,13 +154,12 @@ class Term:
152
154
  H = height * 40
153
155
  W = width * 40
154
156
 
155
- svg_element = SVG(
157
+ return SVG(
156
158
  xmlns="http://www.w3.org/2000/svg",
157
- viewBox=f"{-1} 0 {width} {height}", # type: ignore
159
+ viewBox=ViewBoxSpec(-1, 0, width, height),
158
160
  elements=elements,
159
161
  style=f"max-height:{H}px; margin-left: clamp(0px, calc(100% - {W}px), 100px)",
160
162
  ).as_str()
161
- return f"<div>{svg_element}</div>"
162
163
 
163
164
 
164
165
  def offset_var(x: Union[int, str], offset: int) -> Union[int, str]:
@@ -190,11 +191,12 @@ class L:
190
191
  self.args.append((offset + i, offset + x))
191
192
  self.n += t.n
192
193
  elif isinstance(t, Term):
193
- for i, x in t.nodes.select("id", "ref").drop_nulls().iter_rows():
194
+ # fixme: encapsulate
195
+ for i, x in t.data.nodes.select("id", "ref").drop_nulls().iter_rows():
194
196
  self.refs.append((offset + i, offset + x))
195
- for i, x in t.nodes.select("id", "arg").drop_nulls().iter_rows():
197
+ for i, x in t.data.nodes.select("id", "arg").drop_nulls().iter_rows():
196
198
  self.args.append((offset + i, offset + x))
197
- self.n += len(t.nodes)
199
+ self.n += len(t.data.nodes)
198
200
  else:
199
201
  assert isinstance(t, str)
200
202
  self.refs.append((self.n, t))
@@ -242,8 +244,8 @@ class L:
242
244
  .to_frame()
243
245
  .join(ref, on="id", how="left")
244
246
  .join(arg, on="id", how="left")
245
- ).with_columns(bid=pl.struct(major="id", minor="id"))
246
- return Term(data)
247
+ ).with_columns(prev=None)
248
+ return Term(AbstractTerm(data))
247
249
 
248
250
 
249
251
  def V(name: str) -> L:
@@ -0,0 +1,326 @@
1
+ """Core engine of lambda-calculus.
2
+
3
+ This file defines 2 main types:
4
+ - `AbstractTerm`, a Lambda-calculus term
5
+ - `NodeId`, a node in the term.
6
+
7
+ Because this is the most complex module of the library, it is useful to define some terms.
8
+ I will refer to these terms in the comments below for concision.
9
+
10
+ # 1. Term
11
+
12
+ A `Term` is a complete Lambda-calculus expression.
13
+
14
+ # 2. Node
15
+
16
+ A `Node` is what constitutes a `Term`. There are 3 types of Nodes.
17
+
18
+ 2.1 There is a `Lambda` Node. It has a single child.
19
+ 2.2 There is a `Variable` Node. It has no children and refers to a `Lambda` Node.
20
+ We say the variable is `bound` to the corresponding Lambda.
21
+ 2.3 There is an `Application` Node. It has 2 children.
22
+ The first child is the function and the second child is the argument.
23
+
24
+ # 3. Subtrees
25
+
26
+ Each `Node` in the `Term` has a subtree. It is made up of the node itself, its children, its grandchildren, and so on.
27
+ The **strict** subtree of a node is its subtree but without the node itself.
28
+
29
+ # 4. Redex
30
+
31
+ A "Redex Node" is a `Node` in the `Term` that we can reduce. It must have the following properties:
32
+
33
+ 4.1 A "Redex Node" is a `Node` of type "Application". The subtree of the "Redex node" is called simply the "Redex".
34
+ 4.2 The first child of the Redex node must be a lambda. The trunk node is the child of the lambda node.
35
+ The "Trunk" is defined as the strict subtree of the lambda node (i.e the subtree of the trunk).
36
+ 4.3 The variables that are bound to the lambda node of the Redex are called the "Stumps".
37
+ 4.4 The second child of the Redex node can be anything. We call it the **Substitute node** of the Redex.
38
+ The "Substitute" is simply the subtree of the "substitute node" of the Redex.
39
+
40
+ # 5. Beta-reduction
41
+
42
+ The operation of "beta-reduction" transforms a Redex inside the Term, leaving the rest of the term unchanged.
43
+ The beta-reduction of a Term on a Redex R consists of the following steps:
44
+
45
+ 5.1 Beta-reduction replaces all stumps by a copy of the substitute
46
+ 5.1.1 If an application had a stump node as one of its children, the stump node is replaced by a copy of the substitute.
47
+ 5.1.2 All variables in the substitute are now bound to the corresponding duplicated lambda in their duplicated substitute.
48
+ But if a variable was bound to a lambda higher in the tree (not in the substitute), it stays bound to that same lambda.
49
+ 5.2 Beta-reduction removes the redex node and the lambda node
50
+ 5.3 Beta-reduction removes the substitute (it has already been duplicated for each stump)
51
+ 5.4 All we have left is the trunk, and it is placed where the original redex was.
52
+
53
+ """
54
+
55
+ from __future__ import annotations
56
+
57
+ from dataclasses import dataclass
58
+ from typing import Generator, List, NewType, Optional
59
+
60
+ import polars as pl
61
+ from polars import Schema, UInt32
62
+
63
__all__ = ["AbstractTerm", "NodeId"]

# Fields of the `prev` struct column: where a node was before the last
# beta-reduction (see `Node.previous_local` / `Node.previous_stump`).
PREV_ID_SCHEMA = dict(stump=UInt32, local=UInt32)

# Node table used by `AbstractTerm`, one row per node: its id, the lambda it is
# bound to (variables only), its argument child (applications only) and its
# previous location.
SCHEMA = Schema(
    dict(id=UInt32, ref=UInt32, arg=UInt32, prev=pl.Struct(PREV_ID_SCHEMA)),
)
71
+
72
+
73
def _shift(nodes: pl.DataFrame, offset: int):
    """Add ``offset`` to the ``id``, ``ref`` and ``arg`` columns of ``nodes``.

    The ``prev`` column is reset to null.
    """
    shifted_ids = (pl.col(column) + offset for column in ("id", "ref", "arg"))
    return nodes.with_columns(*shifted_ids, prev=None)
80
+
81
+
82
NodeId = NewType("NodeId", int)


@dataclass
class Node:
    """
    Represents a node in the lambda calculus abstract syntax tree.

    Attributes:
        ref:
            The lambda this variable is bound to, or None if this node is not a variable.
        children:
            List containing either:

            - a single id: the child of a lambda node
            - two ids: the function and the argument of an application node
            - no id: no children (in this case, this is a variable)

        previous_local:
            If this node comes from a beta-reduction, the id of the node in the
            term before the reduction. If the node has been created by the
            reduction, it is the id of the copied node in the "Substitute"
            (see 5.1). None when no previous location is recorded.

        previous_stump:
            If the node has been created by a beta-reduction, the `Stump` it
            originates from (see 4.3); None otherwise.
    """

    ref: Optional[NodeId]
    children: List[NodeId]
    # Both may be None: `AbstractTerm.node` passes None when `prev` is null.
    previous_local: Optional[NodeId]
    previous_stump: Optional[NodeId]

    def get_arg(self) -> Optional[NodeId]:
        """Return the argument (second child) of an application, else None."""
        if len(self.children) == 2:
            return self.children[1]
        return None

    def get_left(self) -> Optional[NodeId]:
        """Return the first child (lambda body or applied function), or None for a variable."""
        if len(self.children) == 0:
            return None
        return self.children[0]

    def previous(self):
        """Return the pair ``(previous_local, previous_stump)``."""
        return (self.previous_local, self.previous_stump)
125
+
126
+
127
class AbstractTerm:
    """A lambda-calculus term stored as a table of nodes (see the module docstring).

    ``nodes`` follows ``SCHEMA``, with one row per node:

    - ``id``: the node id. The code assumes it equals the row index (``node``
      and ``_get_subtree`` index rows by id) and that nodes are stored in
      pre-order: the first child of a lambda or an application is always
      ``id + 1`` and every subtree is a contiguous range of ids.
    - ``ref``: set only for variables, the id of the lambda they are bound to.
    - ``arg``: set only for applications, the id of their second child.
    - ``prev``: where the node came from before the last beta-reduction
      (see ``Node``); null when unknown.
    """

    nodes: pl.DataFrame

    def __init__(self, nodes: pl.DataFrame | pl.LazyFrame):
        """Wrap ``nodes`` after coercing it to ``SCHEMA`` (a LazyFrame is collected)."""
        nodes = nodes.match_to_schema(SCHEMA)
        if isinstance(nodes, pl.LazyFrame):
            self.nodes = nodes.collect()
        else:
            self.nodes = nodes

    def root(self) -> NodeId:
        """Return the id of the root node, which is always 0."""
        return NodeId(0)

    def node(self, node_id: NodeId) -> Node:
        """
        Get a node in the expression tree by id.
        """
        row_id, ref, arg, prev = self.nodes.row(node_id)
        children = []
        if ref is None:
            # Lambdas and applications: the first child is the next row.
            children.append(row_id + 1)
        if arg is not None:
            # Applications: the argument is the second child.
            children.append(arg)

        return Node(
            ref=ref,
            children=children,
            previous_local=prev["local"] if prev else None,
            previous_stump=prev["stump"] if prev else None,
        )

    def find_variables(self, lamb: NodeId) -> Generator[NodeId]:
        """
        Find all variables bound to a specific lambda.

        Args:
            lamb: the id of the lambda to consider
        """
        return iter(self.nodes.filter(pl.col("ref") == lamb)["id"])

    def find_redexes(self) -> Generator[NodeId]:
        """
        Find all redex nodes (see 4.), in id order.

        A redex is an application (``arg`` set) whose first child, the next
        row, is a lambda (neither ``arg`` nor ``ref`` set). Rows are in
        pre-order, so the first redex yielded is the leftmost-outermost one.
        """
        return iter(
            self.nodes.filter(
                pl.col("arg").is_not_null(),
                pl.col("arg").shift(-1).is_null(),
                pl.col("ref").shift(-1).is_null(),
            )["id"]
        )

    def _get_subtree(self, root: NodeId) -> range:
        """
        Get the ids of the subtree of ``root`` (see 3.) as a range.

        Nodes are stored in pre-order (prefix, left -> right), so the subtree
        is the contiguous range from ``root`` to its rightmost descendant. That
        descendant is found by repeatedly following the last child (``arg`` for
        an application, ``id + 1`` for a lambda) until a variable is reached.

        Args:
            root: the ancestor to consider
        """
        refs = self.nodes["ref"]
        args = self.nodes["arg"]
        rightmost = root
        while True:
            ref = refs[rightmost]
            if ref is not None:
                return range(NodeId(root), NodeId(rightmost + 1))

            arg = args[rightmost]
            rightmost = arg if arg is not None else rightmost + 1

    def get_subtree(self, root: NodeId) -> Generator[NodeId]:
        """Iterate over the ids of the subtree of ``root``, in pre-order."""
        return (NodeId(x) for x in self._get_subtree(root))

    def __call__(self, other: AbstractTerm) -> AbstractTerm:
        r"""
        Compose this term with another one with an application.
        ```
           application
             /    \
            /      \
          self    other
        ```
        The application is node 0, ``self`` follows at offset 1 and ``other``
        starts right after ``self``.
        """
        n = len(self.nodes)
        return AbstractTerm(
            pl.concat(
                [
                    pl.DataFrame(
                        [{"id": 0, "arg": n + 1}],
                        schema=SCHEMA,
                    ),
                    _shift(self.nodes, 1),
                    _shift(other.nodes, n + 1),
                ],
            )
        )

    def beta_reduce(self, redex: NodeId) -> AbstractTerm:
        """
        Compute the beta-reduction of this term on this redex (see 5.).

        Args:
            redex: id of a redex node, e.g. as yielded by ``find_redexes``.

        Returns:
            A new term. The ``prev`` column of each node records its location
            in this term (and, for copies of the substitute, its stump).

        Raises:
            AssertionError: if ``redex`` is not an application.
            ValueError: if the first child of ``redex`` is not a lambda.
        """
        lamb = redex + 1
        a = redex + 2
        id_subst = self.node(redex).get_arg()
        assert id_subst is not None
        lamb_node = self.node(NodeId(lamb))
        if lamb_node.ref is not None or lamb_node.get_arg() is not None:
            raise ValueError(
                f"node {redex} is not a redex: its first child is not a lambda"
            )

        # see 4.3
        stumps = self.nodes.lazy().filter(pl.col("ref") == lamb).select("id")

        # see 4.4
        subst_range = self._get_subtree(id_subst)
        subst = self.nodes.lazy().filter(
            pl.col("id").is_between(subst_range.start, subst_range.stop, closed="left")
        )

        # One copy of the substitute per stump. Each node of a copy is
        # identified by (stump, local id in the substitute), and so are the
        # nodes its `arg` and `ref` point to.
        new_trunks = subst.join(stumps, how="cross", suffix="_stump").select(
            "id",
            "id_stump",
            prev=pl.struct(stump="id_stump", local="id"),
            prev_arg=pl.struct(stump="id_stump", local="arg"),
            prev_ref=pl.struct(stump="id_stump", local="ref"),
            # A copied variable whose lambda lies outside the substitute must
            # keep pointing at that original (not duplicated) lambda, which is
            # identified by (no stump, its id). After renumbering, `ref` found
            # through `prev_ref` (a lambda inside the same copy) takes
            # precedence over this fallback (see the `coalesce` below).
            prev_ref_unbound=pl.struct(stump=None, local="ref", schema=PREV_ID_SCHEMA),
        )

        new_nodes = (
            self.nodes.lazy()
            # remove the information about the previous iteration
            .select(pl.exclude("prev"))
            # See 5.2
            .filter(
                ~(pl.col("id").eq(redex) | pl.col("id").eq(lamb)),
            )
            # See 5.3
            .join(subst, left_on="id", right_on="id", how="anti")
            # See 5.4
            .with_columns(arg=pl.col("arg").replace(redex, a))
            .join(
                # 5.1.1 we check if the argument is a stump
                stumps.select("id", arg_is_stump=True),
                left_on="arg",
                right_on="id",
                how="left",
                maintain_order="left",
            )
            # See 5.1: each stump row expands into the rows of its copy.
            .join(
                new_trunks,
                left_on="id",
                right_on="id_stump",
                how="left",
                maintain_order="left",
            )
            .select(
                # Nodes that are not copies keep their old id as identity.
                pl.col("prev").fill_null(pl.struct(stump=None, local="id")),
                pl.col("prev_ref").fill_null(pl.struct(stump=None, local="ref")),
                pl.col("prev_ref_unbound"),
                # 5.1.1 an argument that was a stump now points at the root
                # of the copy of the substitute made for that stump.
                pl.col("prev_arg").fill_null(
                    pl.when(pl.col("arg_is_stump"))
                    .then(pl.struct(stump="arg", local=id_subst))
                    .otherwise(pl.struct(stump=None, local="arg"))
                ),
            )
            # The new id of each node is its row index in the new table.
            .with_row_index("id")
        ).cache()

        # renumbering step: we replace each "previous location" with the new id
        out = (
            new_nodes.join(
                new_nodes.select(prev_ref="prev", ref="id"),
                on="prev_ref",
                how="left",
                maintain_order="left",
                nulls_equal=True,
            )
            .join(
                new_nodes.select(prev_ref_unbound="prev", ref_unbound="id"),
                on="prev_ref_unbound",
                how="left",
                maintain_order="left",
                nulls_equal=True,
            )
            .join(
                new_nodes.select(prev_arg="prev", arg="id"),
                on="prev_arg",
                how="left",
                maintain_order="left",
                nulls_equal=True,
            )
            .select(
                "id",
                ref=pl.coalesce("ref", "ref_unbound"),  # 5.1.2
                arg="arg",
                prev="prev",
            )
        ).collect()
        return AbstractTerm(out)
@@ -0,0 +1,305 @@
1
+ from typing import Iterable, List, Optional, Tuple, TypeAlias
2
+
3
+ import svg
4
+
5
+ from .core import AbstractTerm, NodeId
6
+ from .utils import Interval, ShapeAnimFrame
7
+
8
+
9
def compute_height(term: AbstractTerm):
    """Return the number of rows needed to draw ``term``.

    This is one more than the largest upper bound among the (truthy) row
    intervals computed by ``compute_layout``.
    """
    _, rows = compute_layout(term)
    bottom = max(interval[1] for interval in rows.values() if interval)
    return bottom + 1
12
+
13
+
14
def count_variables(term: AbstractTerm):
    """Count the variable nodes (nodes with a ``ref``) in the tree of ``term``."""
    all_nodes = term.get_subtree(term.root())
    return sum(term.node(node_id).ref is not None for node_id in all_nodes)
16
+
17
+
18
Loc: TypeAlias = Tuple[NodeId, Optional[NodeId]]


def compute_layout(
    term: AbstractTerm, lamb: Optional[NodeId] = None, replaced_var_width=1
) -> tuple[dict[Loc, Interval], dict[Loc, Interval]]:
    """Compute the grid position of every node of ``term``.

    All keys are ``(node_id, None)``.

    Args:
        term: the term to lay out.
        lamb: variables bound to this lambda are widened.
        replaced_var_width: number of columns taken by each of those variables
            (every other variable takes one column).

    Returns:
        ``(x, y)``: the column interval and the row interval of each node.
        Variables are placed from right to left starting at column
        ``count_variables(term) - 1``; a lambda spans its variables and body.
    """
    order = list(term.get_subtree(term.root()))
    by_id = {node_id: term.node(node_id) for node_id in order}

    # Top-down pass (parents come before children in `order`): rows.
    y: dict[Loc, Interval] = {(term.root(), None): Interval((0, 0))}
    for node_id in order:
        current = by_id[node_id]
        if current.ref is not None:
            continue  # variables have no children
        child = current.get_left()
        assert child is not None
        arg = current.get_arg()
        parent_rows = y[node_id, None]
        if arg is None:
            # Lambda: its body starts one row below.
            y[child, None] = parent_rows.shift(1)
        else:
            # Application: a nested application on the left goes one row below.
            y[child, None] = parent_rows.shift(0 if by_id[child].get_arg() is None else 1)
            y[arg, None] = parent_rows.shift(0)

    # Bottom-up pass (children before parents): columns, and row extents.
    x: dict[Loc, Interval] = {}
    next_var_x = count_variables(term) - 1
    for node_id in reversed(order):
        current = by_id[node_id]
        ref = current.ref
        if ref is None:
            child = current.get_left()
            assert child is not None
            x[node_id, None] = x[child, None] | x.get((node_id, None), Interval(None))
            y[node_id, None] = y[child, None] | y[node_id, None]
            continue
        width = replaced_var_width if ref == lamb else 1
        x[node_id, None] = Interval((next_var_x - width + 1, next_var_x))
        next_var_x -= width
        # A lambda spans the columns of all of its variables.
        x[ref, None] = x[node_id, None] | x.get((ref, None), Interval(None))

    return x, y
62
+
63
+
64
def draw(
    x: dict[Loc, Interval],
    y: dict[Loc, Interval],
    i_node: Loc,
    ref: Optional[Loc],
    arg: Optional[Loc],
    key: Loc,
    idx: int,
    replaced=False,
    removed=False,
    hide_arg=False,
) -> Iterable[ShapeAnimFrame]:
    """Yield the animation frames that draw one node.

    Args:
        x: column interval of every location.
        y: row interval of every location.
        i_node: location of the node to draw.
        ref: location of the binding lambda when the node is a variable, else None.
        arg: location of the argument when the node is an application, else None.
        key: identifies the node's shapes across animation frames.
        idx: index of the animation frame.
        replaced: draw as a replaced variable (green, one row lower).
        removed: draw the box transparent and omit the application connector.
        hide_arg: do not draw the orange outline of an application.

    Yields (depending on the node) the box ``("r", key)``, the application
    outline ``("a", key)``, the line from a variable up to its lambda
    ``("l", key)`` and the connector to the argument ``("b", key)`` /
    ``("c", key)``.
    """
    x_node = x[i_node]
    y_node = y[i_node]
    # `removed` is fully handled by the first test, so it is not re-tested.
    if arg is not None or removed:
        color = "transparent"
    elif replaced:
        color = "green"
    elif ref is not None:
        color = "red"
    else:
        color = "blue"
    # Replaced variables are drawn one row lower.
    row_offset = 1 if replaced else 0

    box = svg.Rect(
        height=0.8,
        stroke_width=0.05,
        stroke="gray",
    )
    yield ShapeAnimFrame(
        element=box,
        key=("r", key),
        idx=idx,
        attrs={
            "x": 0.1 + x_node[0],
            "y": 0.1 + y_node[0] + row_offset,
            "width": 0.8 + x_node[1] - x_node[0],
            "fill_opacity": 1 if arg is None else 0,
            "fill": color,
        },
        zindex=0,
    )

    if arg is not None and not hide_arg:
        outline = svg.Rect(
            height=0.8,
            stroke_width=0.1,
            stroke="orange",
        )
        yield ShapeAnimFrame(
            element=outline,
            key=("a", key),
            idx=idx,
            attrs={
                "x": 0.1 + x_node[0],
                "y": 0.1 + y_node[0],
                "width": 0.8 + x_node[1] - x_node[0],
                # `arg` is not None here: the outline is never filled.
                "fill_opacity": 0,
                "fill": color,
            },
            zindex=1,
        )

    if ref is not None:
        y_ref = y[ref]
        link = svg.Line(
            stroke_width=0.2,
            stroke="gray",
        )
        yield ShapeAnimFrame(
            element=link,
            key=("l", key),
            idx=idx,
            attrs={
                "x1": x_node[0] + 0.5,
                "y1": y_ref[0] + 0.9,
                "x2": x_node[0] + 0.5,
                "y2": y_node[0] + 0.1 + row_offset,
                "stroke": "green" if replaced else "gray",
            },
            zindex=2,
        )

    if arg is not None:
        x_arg = x[arg]
        if not removed:
            connector = svg.Line(
                stroke="black",
                stroke_width=0.05,
            )
            dot = svg.Circle(fill="black", r=0.1)
            yield ShapeAnimFrame(
                element=connector,
                key=("b", key),
                idx=idx,
                attrs={
                    "x1": 0.5 + x_node[1],
                    "y1": 0.5 + y_node[0],
                    "x2": 0.5 + x_arg[0],
                    "y2": 0.5 + y_node[0],
                },
                zindex=3,
            )
            yield ShapeAnimFrame(
                element=dot,
                key=("c", key),
                idx=idx,
                attrs={
                    "cx": 0.5 + x_node[1],
                    "cy": 0.5 + y_node[0],
                },
                zindex=3,
            )
177
+
178
+
179
def compute_svg_frame_init(
    term: AbstractTerm, idx: int = 0
) -> Iterable[ShapeAnimFrame]:
    """Yield the frames drawing ``term`` with its own layout.

    Every node is keyed by its location ``(node_id, None)``.
    """
    x, y = compute_layout(term)
    for node_id in term.get_subtree(term.root()):
        current = term.node(node_id)
        location = (node_id, None)
        ref_location = None if current.ref is None else (current.ref, None)
        arg = current.get_arg()
        arg_location = None if arg is None else (arg, None)
        yield from draw(x, y, location, ref_location, arg_location, key=location, idx=idx)
196
+
197
+
198
def compute_svg_frame_phase_a(
    term: AbstractTerm,
    redex: NodeId,
    b_subtree: List[NodeId],
    vars: List[NodeId],
    idx: int,
) -> Iterable[ShapeAnimFrame]:
    """Yield the frames of the first phase of a beta-reduction animation.

    ``term`` is laid out with every variable bound to the redex's lambda taking
    as many columns as the substitute has variables. Nodes outside the
    substitute are drawn in place: variables bound to the lambda as
    "replaced", the redex and its lambda as "removed". Then, for each stump,
    every node of the substitute is drawn at its own position with key
    ``(local_id, stump)``.

    Args:
        term: the term being reduced.
        redex: id of the redex node.
        b_subtree: ids of the nodes of the substitute.
        vars: ids of the variables bound to the redex's lambda (the stumps).
        idx: index of the animation frame.
    """
    lamb = term.node(redex).get_left()
    assert lamb is not None
    # Tested for every node of the term: a set avoids O(n * m) list scans.
    substitute = set(b_subtree)
    b_width = sum(1 for node_id in b_subtree if term.node(node_id).ref is not None)
    x, y = compute_layout(term, lamb=lamb, replaced_var_width=b_width)
    for node_id in term.get_subtree(term.root()):
        if node_id in substitute:
            continue
        node = term.node(node_id)
        ref = node.ref
        arg = node.get_arg()
        yield from draw(
            x,
            y,
            (node_id, None),
            (ref, None) if ref is not None else None,
            (arg, None) if arg is not None else None,
            key=(node_id, None),
            idx=idx,
            replaced=ref == lamb,
            removed=(node_id == lamb or node_id == redex),
        )

    for stump in vars:
        for local_id in b_subtree:
            local_node = term.node(local_id)
            ref = local_node.ref
            arg = local_node.get_arg()
            yield from draw(
                x,
                y,
                (local_id, None),
                (ref, None) if ref is not None else None,
                (arg, None) if arg is not None else None,
                # The stump tells the copies of the substitute apart.
                key=(local_id, stump),
                idx=idx,
            )
243
+
244
+
245
def compute_svg_frame_phase_b(
    term: AbstractTerm,
    redex: NodeId,
    b_subtree: List[NodeId],
    reduced: AbstractTerm,
    idx: int,
) -> Iterable[ShapeAnimFrame]:
    """Yield the frames of the second phase of a beta-reduction animation.

    The nodes of ``reduced`` are drawn with the layout of ``term`` (computed as
    in phase a) and keyed by their ``previous()`` location. Each node created
    by the reduction (it has a ``previous_stump``) is placed by translating the
    substitute's layout so that the substitute's root lands one row below its
    stump.

    Args:
        term: the term before the reduction.
        redex: id of the reduced redex in ``term``.
        b_subtree: ids of the nodes of the substitute in ``term``.
        reduced: the result of ``term.beta_reduce(redex)``.
        idx: index of the animation frame.
    """
    lamb = term.node(redex).get_left()
    assert lamb is not None
    b_width = sum(1 for node_id in b_subtree if term.node(node_id).ref is not None)
    b = term.node(redex).get_arg()
    assert b is not None
    x, y = compute_layout(term, lamb=lamb, replaced_var_width=b_width)
    b_x = x[b, None][0]
    b_y = y[b, None][0]
    for node_id in reduced.get_subtree(reduced.root()):
        node = reduced.node(node_id)
        local = node.previous_local
        stump = node.previous_stump
        if stump is not None:
            delta_x = x[stump, None][0] - b_x
            delta_y = y[stump, None][0] - b_y + 1
            x[local, stump] = x[local, None].shift(delta_x)
            y[local, stump] = y[local, None].shift(delta_y)

    for node_id in reduced.get_subtree(reduced.root()):
        node = reduced.node(node_id)
        new_ref = node.ref
        new_arg = node.get_arg()

        key = node.previous()

        yield from draw(
            x,
            y,
            key,
            # Compare with None: id 0 (a root lambda) is a valid binder, and
            # a truthiness test would drop the line of its variables.
            ref=reduced.node(new_ref).previous() if new_ref is not None else None,
            arg=reduced.node(new_arg).previous() if new_arg is not None else None,
            key=key,
            idx=idx,
        )
286
+
287
+
288
def compute_svg_frame_final(
    reduced: AbstractTerm, idx: int
) -> Iterable[ShapeAnimFrame]:
    """Yield the frames drawing the reduced term with its own layout.

    Shapes are keyed by each node's ``previous()`` location, the same keys
    that phase b uses.
    """
    x, y = compute_layout(reduced)
    for node_id in reduced.get_subtree(reduced.root()):
        current = reduced.node(node_id)
        ref, arg = current.ref, current.get_arg()
        yield from draw(
            x,
            y,
            (node_id, None),
            None if ref is None else (ref, None),
            None if arg is None else (arg, None),
            current.previous(),
            idx=idx,
        )
@@ -1,136 +0,0 @@
1
- import polars as pl
2
- from polars import Schema, UInt32
3
-
4
# One row per node; `bid` is a (major, minor) identifier used while renumbering.
SCHEMA = Schema(
    dict(
        id=UInt32,
        ref=UInt32,
        arg=UInt32,
        bid=pl.Struct(dict(major=UInt32, minor=UInt32)),
    ),
)
12
-
13
-
14
def _shift(nodes: pl.DataFrame, offset: int):
    """Add ``offset`` to the ``id``, ``ref`` and ``arg`` columns and null out ``bid``."""
    shifted_ids = (pl.col(column) + offset for column in ("id", "ref", "arg"))
    return nodes.with_columns(*shifted_ids, bid=None)
21
-
22
-
23
def compose(f: pl.DataFrame, x: pl.DataFrame) -> pl.DataFrame:
    """Build the node table of the application ``f x``.

    Row 0 is the new application; ``f`` follows at offset 1 and ``x`` starts
    right after ``f``.
    """
    x_start = len(f) + 1
    application = pl.DataFrame([{"id": 0, "arg": x_start}], schema=SCHEMA)
    return pl.concat([application, _shift(f, 1), _shift(x, x_start)])
35
-
36
-
37
def find_redexes(nodes: pl.DataFrame) -> pl.LazyFrame:
    """Lazily select the applications whose next row (first child) is a lambda.

    Each result row holds the redex ``id``, its lambda ``lamb`` (``id + 1``)
    and its argument ``arg``.
    """
    is_application = pl.col("arg").is_not_null()
    next_has_no_arg = pl.col("arg").shift(-1).is_null()
    next_has_no_ref = pl.col("ref").shift(-1).is_null()
    candidates = nodes.lazy().filter(is_application, next_has_no_arg, next_has_no_ref)
    return candidates.select("id", lamb=pl.col("id") + 1, arg="arg")
47
-
48
-
49
def find_variables(nodes: pl.DataFrame, lamb: int) -> pl.LazyFrame:
    """Lazily select the ids of the variables bound to ``lamb``, with ``replaced=True``."""
    bound_to_lamb = pl.col("ref") == lamb
    return nodes.lazy().filter(bound_to_lamb).select("id", replaced=True)
51
-
52
-
53
def subtree(nodes: pl.DataFrame, root: int) -> pl.LazyFrame:
    """Lazily select the rows of the subtree of ``root``.

    The subtree is the id range from ``root`` to its rightmost descendant,
    reached by following ``arg`` (applications) or ``id + 1`` (lambdas) until
    a variable is found.
    """
    refs, args = nodes["ref"], nodes["arg"]
    last = root
    while refs[last] is None:
        next_arg = args[last]
        last = last + 1 if next_arg is None else next_arg
    return nodes.lazy().filter(pl.col("id").is_between(root, last))
64
-
65
-
66
def _generate_bi_identifier(
    major_name: str, minor_name: str, minor_replacement=pl.lit(None)
):
    """Build a ``(major, minor)`` struct expression.

    ``major`` is the ``major_name`` column and ``minor`` is ``minor_replacement``;
    each falls back to the ``minor_name`` column where it is null.
    """
    fallback = pl.col(minor_name)
    major = pl.col(major_name).fill_null(fallback)
    minor = minor_replacement.fill_null(fallback)
    return pl.struct(major=major, minor=minor)
73
-
74
-
75
def beta_reduce(nodes: pl.DataFrame, lamb: int, b: int) -> pl.DataFrame:
    """Beta-reduce the redex whose lambda is ``lamb`` and whose argument is ``b``.

    The redex is ``lamb - 1`` and the lambda body starts at ``lamb + 1``. Each
    variable bound to ``lamb`` is replaced by a copy of the subtree of ``b``.
    Nodes are identified by a ``bid`` struct, sorted, renumbered by row index,
    and their ``ref``/``arg`` are resolved through the ``bid`` of their target.
    """
    redex, a = lamb - 1, lamb + 1
    bound_vars = find_variables(nodes, lamb)
    b_nodes = subtree(nodes, b)

    # One copy of the subtree of b per bound variable (`id_major`).
    copies = b_nodes.join(bound_vars, how="cross", suffix="_major")
    others = (
        nodes.lazy()
        .join(b_nodes, on="id", how="anti")
        .with_columns(arg=pl.col("arg").replace(redex, a))
    )
    is_redex_or_lamb = pl.col("id").eq(redex) | pl.col("id").eq(lamb)
    arg_bid = _generate_bi_identifier(
        "id_major",
        "arg",
        # An argument that was a bound variable now points at the root (b)
        # of the copy made for that variable.
        minor_replacement=pl.when("replaced_arg").then(b),
    )

    new_nodes = (
        pl.concat([copies, others], how="diagonal")
        .join(bound_vars, left_on="id", right_on="id", how="anti")
        .join(bound_vars, left_on="arg", right_on="id", how="left", suffix="_arg")
        .join(bound_vars, left_on="ref", right_on="id", how="left", suffix="_ref")
        .filter(~is_redex_or_lamb)
        .select(
            bid=_generate_bi_identifier("id_major", "id"),
            bid_ref=_generate_bi_identifier("id_major", "ref"),
            bid_ref_fallback=pl.struct(major="ref", minor="ref"),
            bid_arg=arg_bid,
        )
        .sort("bid")
        .with_row_index("id")
    ).cache()

    # Resolve each bid-based pointer into the new row index.
    renumbered = new_nodes
    for bid_column, id_column in (
        ("bid_ref", "ref"),
        ("bid_ref_fallback", "ref_fallback"),
        ("bid_arg", "arg"),
    ):
        renumbered = renumbered.join(
            new_nodes.select(**{bid_column: "bid", id_column: "id"}),
            on=bid_column,
            how="left",
            maintain_order="left",
        )
    return renumbered.select(
        "id", ref=pl.coalesce("ref", "ref_fallback"), arg="arg", bid="bid"
    ).collect()
@@ -1,271 +0,0 @@
1
- from typing import Any, Iterable, Optional, Union
2
-
3
- import polars as pl
4
- import svg
5
-
6
- from .utils import Interval, ShapeAnimFrame
7
-
8
-
9
- def compute_height(nodes: pl.DataFrame):
10
- _, y = compute_layout(nodes)
11
- return max(interval[1] for interval in y.values() if interval) + 1
12
-
13
-
14
- def count_variables(nodes: pl.DataFrame):
15
- return nodes["ref"].count()
16
-
17
-
18
- def compute_layout(
19
- nodes: pl.DataFrame, lamb=None, replaced_var_width=1
20
- ) -> tuple[dict[int, int], dict[int, int]]:
21
- y = {0: Interval((0, 0))}
22
- x = {}
23
- for node, ref, arg in nodes.select("id", "ref", "arg").iter_rows():
24
- if ref is not None:
25
- continue
26
- child = node + 1
27
- if arg is not None:
28
- y[child] = y[node].shift(0 if nodes["arg"][child] is None else 1)
29
- y[arg] = y[node].shift(0)
30
- else:
31
- y[child] = y[node].shift(1)
32
-
33
- next_var_x = count_variables(nodes) - 1
34
-
35
- for node, ref, arg in (
36
- nodes.sort("id", descending=True).select("id", "ref", "arg").iter_rows()
37
- ):
38
- if ref is not None:
39
- width = replaced_var_width if ref == lamb else 1
40
- x[node] = Interval((next_var_x - width + 1, next_var_x))
41
- next_var_x -= width
42
- x[ref] = x[node] | x.get(ref, Interval(None))
43
-
44
- else:
45
- child = node + 1
46
- x[node] = x[child] | x.get(node, Interval(None))
47
- y[node] = y[child] | y[node]
48
- return x, y
49
-
50
-
51
- def draw(
52
- x: dict[Union[int, tuple[int, int]], Interval],
53
- y: dict[Union[int, tuple[int, int]], Interval],
54
- i_node: Union[int, tuple[int, int]],
55
- ref: Optional[int],
56
- arg: Optional[int],
57
- key: Any,
58
- idx: int,
59
- replaced=False,
60
- removed=False,
61
- hide_arg=False,
62
- ) -> Iterable[ShapeAnimFrame]:
63
- x_node = x[i_node]
64
- y_node = y[i_node]
65
- if arg is not None or removed:
66
- color = "transparent"
67
- elif replaced or removed:
68
- color = "green"
69
- elif ref is not None:
70
- color = "red"
71
- else:
72
- color = "blue"
73
-
74
- r = svg.Rect(
75
- height=0.8,
76
- stroke_width=0.05,
77
- stroke="gray",
78
- )
79
-
80
- yield ShapeAnimFrame(
81
- element=r,
82
- key=("r", key),
83
- idx=idx,
84
- attrs={
85
- "x": 0.1 + x_node[0],
86
- "y": 0.1 + y_node[0] + (1 if replaced else 0),
87
- "width": 0.8 + x_node[1] - x_node[0],
88
- "fill_opacity": 1 if arg is None else 0,
89
- "fill": color,
90
- },
91
- zindex=0,
92
- )
93
- if arg is not None and not hide_arg:
94
- r = svg.Rect(
95
- height=0.8,
96
- stroke_width=0.1,
97
- stroke="orange",
98
- )
99
-
100
- yield ShapeAnimFrame(
101
- element=r,
102
- key=("a", key),
103
- idx=idx,
104
- attrs={
105
- "x": 0.1 + x_node[0],
106
- "y": 0.1 + y_node[0],
107
- "width": 0.8 + x_node[1] - x_node[0],
108
- "fill_opacity": 1 if arg is None else 0,
109
- "fill": color,
110
- },
111
- zindex=1,
112
- )
113
-
114
- if ref is not None:
115
- y_ref = y[ref]
116
- e = svg.Line(
117
- stroke_width=0.2,
118
- stroke="gray",
119
- )
120
- yield ShapeAnimFrame(
121
- element=e,
122
- key=("l", key),
123
- idx=idx,
124
- attrs={
125
- "x1": x_node[0] + 0.5,
126
- "y1": y_ref[0] + 0.9,
127
- "x2": x_node[0] + 0.5,
128
- "y2": y_node[0] + 0.1 + (1 if replaced else 0),
129
- "stroke": "green" if replaced else "gray",
130
- },
131
- zindex=2,
132
- )
133
-
134
- if arg is not None:
135
- x_arg = x[arg]
136
- e1 = svg.Line(
137
- stroke="black",
138
- stroke_width=0.05,
139
- )
140
- e2 = svg.Circle(fill="black", r=0.1)
141
- if not removed:
142
- yield ShapeAnimFrame(
143
- element=e1,
144
- key=("b", key),
145
- idx=idx,
146
- attrs={
147
- "x1": 0.5 + x_node[1],
148
- "y1": 0.5 + y_node[0],
149
- "x2": 0.5 + x_arg[0],
150
- "y2": 0.5 + y_node[0],
151
- },
152
- zindex=3,
153
- )
154
- yield ShapeAnimFrame(
155
- element=e2,
156
- key=("c", key),
157
- idx=idx,
158
- attrs={
159
- "cx": 0.5 + x_node[1],
160
- "cy": 0.5 + y_node[0],
161
- },
162
- zindex=3,
163
- )
164
-
165
-
166
- def compute_svg_frame_init(
167
- nodes: pl.DataFrame, idx: int = 0
168
- ) -> Iterable[ShapeAnimFrame]:
169
- x, y = compute_layout(nodes)
170
- for target_id, ref, arg in (
171
- nodes.select("id", "ref", "arg").sort("id", descending=True).iter_rows()
172
- ):
173
- yield from draw(x, y, target_id, ref, arg, key=target_id, idx=idx)
174
-
175
-
176
- def compute_svg_frame_phase_a(
177
- nodes: pl.DataFrame, lamb: int, b_subtree: pl.DataFrame, vars: pl.Series, idx: int
178
- ) -> Iterable[ShapeAnimFrame]:
179
- redex = lamb - 1 if lamb is not None else None
180
- b_width = b_subtree.count()["ref"].item()
181
- x, y = compute_layout(nodes, lamb=lamb, replaced_var_width=b_width)
182
- for target_id, ref, arg in (
183
- nodes.select("id", "ref", "arg").sort("id", descending=True).iter_rows()
184
- ):
185
- replaced = ref is not None and ref == lamb
186
- yield from draw(
187
- x,
188
- y,
189
- target_id,
190
- ref,
191
- arg,
192
- key=target_id,
193
- idx=idx,
194
- replaced=replaced,
195
- removed=(target_id == lamb or target_id == redex),
196
- )
197
-
198
- for v in vars:
199
- for minor, ref, arg in (
200
- b_subtree.select("id", "ref", "arg").sort("id", descending=True).iter_rows()
201
- ):
202
- yield from draw(x, y, minor, ref, arg, key=(v, minor), idx=idx)
203
-
204
-
205
- def compute_svg_frame_phase_b(
206
- nodes: pl.DataFrame,
207
- lamb: int,
208
- b_subtree: pl.DataFrame,
209
- new_nodes: pl.DataFrame,
210
- idx: int,
211
- ) -> Iterable[ShapeAnimFrame]:
212
- b_width = b_subtree.count()["ref"].item()
213
- b = b_subtree["id"][0]
214
- x, y = compute_layout(nodes, lamb=lamb, replaced_var_width=b_width)
215
- b_x = x[b][0]
216
- b_y = y[b][0]
217
- for bid, arg in new_nodes.select("bid", "arg").iter_rows():
218
- if bid["minor"] != bid["major"]:
219
- v = bid["major"]
220
- minor = bid["minor"]
221
- delta_x = x[v][0] - b_x
222
- delta_y = y[v][0] - b_y + 1
223
- x[(v, minor)] = x[minor].shift(delta_x)
224
- y[(v, minor)] = y[minor].shift(delta_y)
225
-
226
- for bid, new_ref, new_arg in new_nodes.select("bid", "ref", "arg").iter_rows():
227
- v = bid["major"]
228
- minor = bid["minor"]
229
- if new_ref is None:
230
- ref = None
231
- else:
232
- bid_ref = new_nodes["bid"][new_ref]
233
- ref = (
234
- (bid_ref["major"], bid_ref["minor"])
235
- if bid_ref["major"] != bid_ref["minor"]
236
- else bid_ref["minor"]
237
- )
238
- if new_arg is None:
239
- arg = None
240
- else:
241
- bid_arg = new_nodes["bid"][new_arg]
242
- arg = (
243
- (bid_arg["major"], bid_arg["minor"])
244
- if bid_arg["major"] != bid_arg["minor"]
245
- else bid_arg["minor"]
246
- )
247
- key = (v, minor) if minor != v else minor
248
- yield from draw(
249
- x,
250
- y,
251
- key,
252
- ref,
253
- arg,
254
- key=key,
255
- idx=idx,
256
- )
257
-
258
-
259
- def compute_svg_frame_final(
260
- reduced: pl.DataFrame, idx: int
261
- ) -> Iterable[ShapeAnimFrame]:
262
- x, y = compute_layout(reduced)
263
- for target_id, bid, ref, arg in (
264
- reduced.select("id", "bid", "ref", "arg")
265
- .sort("id", descending=True)
266
- .iter_rows()
267
- ):
268
- minor = bid["minor"]
269
- major = bid["major"]
270
- key = (major, minor) if minor != major else minor
271
- yield from draw(x, y, target_id, ref, arg, key, idx=idx)
File without changes