egglog 6.1.0__cp311-none-win_amd64.whl → 7.0.0__cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/pretty.py ADDED
@@ -0,0 +1,418 @@
1
"""
Pretty printing for declarations.
"""
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections import Counter, defaultdict
8
+ from dataclasses import dataclass, field
9
+ from typing import TYPE_CHECKING, TypeAlias
10
+
11
+ import black
12
+ from typing_extensions import assert_never
13
+
14
+ from .declarations import *
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import Mapping
18
+
19
__all__ = ["pretty_decl", "pretty_callable_ref", "BINARY_METHODS", "UNARY_METHODS"]

# Thresholds for the heuristic that decides when a shared sub-expression is hoisted into a variable
MAX_LINE_LENGTH = 110
LINE_DIFFERENCE = 10
# Formatting mode handed to black when re-formatting a pretty printed program
BLACK_MODE = black.Mode(line_length=180)
28
+
29
# Placeholder printed in place of the args. If the args are inlined in the viz,
# they replace this character.
ARG_STR = "·"

# Special methods which we might want to use as functions, each mapped to the
# operator it represents when pretty printing it.
# https://docs.python.org/3/reference/datamodel.html
BINARY_METHODS = dict(
    # Comparison
    __lt__="<",
    __le__="<=",
    __eq__="==",
    __ne__="!=",
    __gt__=">",
    __ge__=">=",
    # Numeric
    __add__="+",
    __sub__="-",
    __mul__="*",
    __matmul__="@",
    __truediv__="/",
    __floordiv__="//",
    __mod__="%",
    # TODO: Support divmod, with tuple return value
    # __divmod__="divmod",
    # TODO: Three arg power
    __pow__="**",
    __lshift__="<<",
    __rshift__=">>",
    __and__="&",
    __xor__="^",
    __or__="|",
)


# Prefix operators, printed directly before their single operand
UNARY_METHODS = dict(
    __pos__="+",
    __neg__="-",
    __invert__="~",
)
68
+
69
# Union of every declaration kind accepted by the traverse and pretty printing contexts below
AllDecls: TypeAlias = (
    RulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
)
70
+
71
+
72
def pretty_decl(
    decls: Declarations, decl: AllDecls, *, wrapping_fn: str | None = None, ruleset_name: str | None = None
) -> str:
    """
    Pretty print a declaration.

    The result is re-formatted with black, with the expression on the last line, preceded by the statements
    (temporary variable assignments and mutating calls) that it refers to.

    :param decls: Declarations used to look up the functions and rulesets referenced by ``decl``.
    :param decl: The declaration to print.
    :param wrapping_fn: If given, the final expression is wrapped in a call to this function, as ``wrapping_fn(expr)``.
    :param ruleset_name: If ``decl`` is a ruleset, print it by this name instead of listing its rules.
    :returns: The black-formatted program, or the unformatted program if black cannot parse it.
    """
    # First pass: count how many parents every sub-declaration has, so shared ones can be named
    traverse = TraverseContext()
    traverse(decl, toplevel=True)
    # Second pass: print, collecting any hoisted assignments into ``pretty.statements``
    pretty = traverse.pretty(decls)
    expr = pretty(decl, ruleset_name=ruleset_name)
    if wrapping_fn:
        expr = f"{wrapping_fn}({expr})"
    program = "\n".join([*pretty.statements, expr])
    try:
        # TODO: Try replacing with ruff for speed
        # https://github.com/amyreese/ruff-api
        return black.format_str(program, mode=BLACK_MODE).strip()
    except black.parsing.InvalidInput:
        # Formatting is best-effort: fall back to the raw program if black rejects it
        return program
93
+
94
+
95
def pretty_callable_ref(
    decls: Declarations,
    ref: CallableRef,
    first_arg: ExprDecl | None = None,
    bound_tp_params: tuple[JustTypeRef, ...] | None = None,
) -> str:
    """
    Pretty print a callable reference, using a placeholder (``ARG_STR``) for
    the args if the function is not in the form `f(x, ...)`.

    To be used in the visualization, where the placeholders may be replaced by the inlined args.

    :param decls: Declarations used to look up the callable.
    :param ref: The callable to print.
    :param first_arg: If given, used as the first argument (before the placeholders), e.g. the receiver of a method.
    :param bound_tp_params: Type parameters printed for class method references.
    """
    # Pass in three placeholder args, which are the max used for any operation that
    # is not a generic function call. They are variables (not literals) so they print as the bare ``ARG_STR``
    # instead of a wrapped / quoted string literal such as ``String('·')``, which the viz could not substitute.
    args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
    if first_arg is not None:
        args.insert(0, first_arg)
    # No parent counts are known here: every lookup yields 0, so no sub-expression is hoisted into a variable
    res = PrettyContext(decls, defaultdict(int))._call_inner(
        ref, args, bound_tp_params=bound_tp_params, parens=False
    )
    # Either returns a string, or a function with its args. If args are returned they would just be called
    # on the function, so drop them, because they are placeholders
    return res[0] if isinstance(res, tuple) else res
118
+
119
+
120
+ @dataclass
121
+ class TraverseContext:
122
+ """
123
+ State for traversing expressions (or declerations that contain expressions), so we can know how many parents each
124
+ expression has.
125
+ """
126
+
127
+ # All expressions we have seen (incremented the parent counts of all children)
128
+ _seen: set[AllDecls] = field(default_factory=set)
129
+ # The number of parents for each expressions
130
+ parents: Counter[AllDecls] = field(default_factory=Counter)
131
+
132
+ def pretty(self, decls: Declarations) -> PrettyContext:
133
+ """
134
+ Create a pretty context from the state of this traverse context.
135
+ """
136
+ return PrettyContext(decls, self.parents)
137
+
138
+ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901
139
+ if not toplevel:
140
+ self.parents[decl] += 1
141
+ if decl in self._seen:
142
+ return
143
+ match decl:
144
+ case RewriteDecl(_, lhs, rhs, conditions) | BiRewriteDecl(_, lhs, rhs, conditions):
145
+ self(lhs)
146
+ self(rhs)
147
+ for cond in conditions:
148
+ self(cond)
149
+ case RuleDecl(head, body, _):
150
+ for action in head:
151
+ self(action)
152
+ for fact in body:
153
+ self(fact)
154
+ case SetDecl(_, lhs, rhs) | UnionDecl(_, lhs, rhs):
155
+ self(lhs)
156
+ self(rhs)
157
+ case LetDecl(_, d) | ExprActionDecl(d) | ExprFactDecl(d):
158
+ self(d.expr)
159
+ case ChangeDecl(_, d, _) | SaturateDecl(d) | RepeatDecl(d, _) | ActionCommandDecl(d):
160
+ self(d)
161
+ case PanicDecl(_) | VarDecl(_) | LitDecl(_) | PyObjectDecl(_):
162
+ pass
163
+ case EqDecl(_, decls) | SequenceDecl(decls) | RulesetDecl(decls):
164
+ for de in decls:
165
+ self(de)
166
+ case CallDecl(_, exprs, _):
167
+ for e in exprs:
168
+ self(e.expr)
169
+ case RunDecl(_, until):
170
+ if until:
171
+ for f in until:
172
+ self(f)
173
+ case _:
174
+ assert_never(decl)
175
+
176
+ self._seen.add(decl)
177
+
178
+
179
+ @dataclass
180
+ class PrettyContext:
181
+ """
182
+
183
+ We need to build up a list of all the expressions we are pretty printing, so that we can see who has parents and who is mutated
184
+ and create temp variables for them.
185
+
186
+ """
187
+
188
+ decls: Declarations
189
+ parents: Mapping[AllDecls, int]
190
+
191
+ # All the expressions we have saved as names
192
+ names: dict[AllDecls, str] = field(default_factory=dict)
193
+ # A list of statements assigning variables or calling destructive ops
194
+ statements: list[str] = field(default_factory=list)
195
+ # Mapping of type to the number of times we have generated a name for that type, used to generate unique names
196
+ _gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
197
+
198
+ def __call__(
199
+ self, decl: AllDecls, *, unwrap_lit: bool = False, parens: bool = False, ruleset_name: str | None = None
200
+ ) -> str:
201
+ if decl in self.names:
202
+ return self.names[decl]
203
+ expr, tp_name = self.uncached(decl, unwrap_lit=unwrap_lit, parens=parens, ruleset_name=ruleset_name)
204
+ # We use a heuristic to decide whether to name this sub-expression as a variable
205
+ # The rough goal is to reduce the number of newlines, given our line length of ~180
206
+ # We determine it's worth making a new line for this expression if the total characters
207
+ # it would take up is > than some constant (~ line length).
208
+ line_diff: int = len(expr) - LINE_DIFFERENCE
209
+ n_parents = self.parents[decl]
210
+ if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH:
211
+ self.names[decl] = expr_name = self._name_expr(tp_name, expr, copy_identifier=False)
212
+ return expr_name
213
+ return expr
214
+
215
+ def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: PLR0911
216
+ match decl:
217
+ case LitDecl(value):
218
+ match value:
219
+ case None:
220
+ return "Unit()", "Unit"
221
+ case bool(b):
222
+ return str(b) if unwrap_lit else f"Bool({b})", "Bool"
223
+ case int(i):
224
+ return str(i) if unwrap_lit else f"i64({i})", "i64"
225
+ case float(f):
226
+ return str(f) if unwrap_lit else f"f64({f})", "f64"
227
+ case str(s):
228
+ return repr(s) if unwrap_lit else f"String({s!r})", "String"
229
+ assert_never(value)
230
+ case VarDecl(name):
231
+ return name, name
232
+ case CallDecl(_, _, _):
233
+ return self._call(decl, parens)
234
+ case PyObjectDecl(value):
235
+ return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
236
+ case ActionCommandDecl(action):
237
+ return self(action), "action"
238
+ case RewriteDecl(_, lhs, rhs, conditions) | BiRewriteDecl(_, lhs, rhs, conditions):
239
+ args = ", ".join(map(self, (rhs, *conditions)))
240
+ fn = "rewrite" if isinstance(decl, RewriteDecl) else "birewrite"
241
+ return f"{fn}({self(lhs)}).to({args})", "rewrite"
242
+ case RuleDecl(head, body, name):
243
+ l = ", ".join(map(self, body))
244
+ if name:
245
+ l += f", name={name}"
246
+ r = ", ".join(map(self, head))
247
+ return f"rule({l}).then({r})", "rule"
248
+ case SetDecl(_, lhs, rhs):
249
+ return f"set_({self(lhs)}).to({self(rhs)})", "action"
250
+ case UnionDecl(_, lhs, rhs):
251
+ return f"union({self(lhs)}).with_({self(rhs)})", "action"
252
+ case LetDecl(name, expr):
253
+ return f"let({name!r}, {self(expr.expr)})", "action"
254
+ case ExprActionDecl(expr):
255
+ return self(expr.expr), "action"
256
+ case ExprFactDecl(expr):
257
+ return self(expr.expr), "fact"
258
+ case ChangeDecl(_, expr, change):
259
+ return f"{change}({self(expr)})", "action"
260
+ case PanicDecl(s):
261
+ return f"panic({s!r})", "action"
262
+ case EqDecl(_, exprs):
263
+ first, *rest = exprs
264
+ return f"eq({self(first)}).to({', '.join(map(self, rest))})", "fact"
265
+ case RulesetDecl(rules):
266
+ if ruleset_name:
267
+ return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
268
+ args = ", ".join(map(self, rules))
269
+ return f"ruleset({args})", "ruleset"
270
+ case SaturateDecl(schedule):
271
+ return f"{self(schedule, parens=True)}.saturate()", "schedule"
272
+ case RepeatDecl(schedule, times):
273
+ return f"{self(schedule, parens=True)} * {times}", "schedule"
274
+ case SequenceDecl(schedules):
275
+ if len(schedules) == 2:
276
+ return f"{self(schedules[0], parens=True)} + {self(schedules[1], parens=True)}", "schedule"
277
+ args = ", ".join(map(self, schedules))
278
+ return f"seq({args})", "schedule"
279
+ case RunDecl(ruleset_name, until):
280
+ ruleset = self.decls._rulesets[ruleset_name]
281
+ ruleset_str = self(ruleset, ruleset_name=ruleset_name)
282
+ if not until:
283
+ return ruleset_str, "schedule"
284
+ args = ", ".join(map(self, until))
285
+ return f"run({ruleset_str}, {args})", "schedule"
286
+ assert_never(decl)
287
+
288
+ def _call(
289
+ self,
290
+ decl: CallDecl,
291
+ parens: bool,
292
+ ) -> tuple[str, str]:
293
+ """
294
+ Pretty print the call. Also returns if it was saved as a name.
295
+
296
+ :param parens: If true, wrap the call in parens if it is a binary method call.
297
+ """
298
+ args = [a.expr for a in decl.args]
299
+ ref = decl.callable
300
+ # Special case !=
301
+ if decl.callable == FunctionRef("!="):
302
+ l, r = self(args[0]), self(args[1])
303
+ return f"ne({l}).to({r})", "Unit"
304
+ function_decl = self.decls.get_callable_decl(ref).to_function_decl()
305
+ # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
306
+ n_defaults = 0
307
+ for arg, default in zip(
308
+ reversed(args), reversed(function_decl.arg_defaults), strict=not function_decl.var_arg_type
309
+ ):
310
+ if arg != default:
311
+ break
312
+ n_defaults += 1
313
+ if n_defaults:
314
+ args = args[:-n_defaults]
315
+
316
+ tp_name = function_decl.semantic_return_type.name
317
+ if function_decl.mutates:
318
+ first_arg = args[0]
319
+ expr_str = self(first_arg)
320
+ # copy an identifier expression iff it has multiple parents (b/c then we can't mutate it directly)
321
+ has_multiple_parents = self.parents[first_arg] > 1
322
+ self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
323
+ # Set the first arg to be the name of the mutated arg and return the name
324
+ args[0] = VarDecl(expr_name)
325
+ else:
326
+ expr_name = None
327
+ res = self._call_inner(ref, args, decl.bound_tp_params, parens)
328
+ expr = (
329
+ f"{res[0]}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})"
330
+ if isinstance(res, tuple)
331
+ else res
332
+ )
333
+ # If we have a name, then we mutated
334
+ if expr_name:
335
+ self.statements.append(expr)
336
+ return expr_name, tp_name
337
+ return expr, tp_name
338
+
339
+ def _call_inner( # noqa: PLR0911
340
+ self, ref: CallableRef, args: list[ExprDecl], bound_tp_params: tuple[JustTypeRef, ...] | None, parens: bool
341
+ ) -> tuple[str, list[ExprDecl]] | str:
342
+ """
343
+ Pretty print the call, returning either the full function call or a tuple of the function and the args.
344
+ """
345
+ match ref:
346
+ case FunctionRef(name):
347
+ return name, args
348
+ case ClassMethodRef(class_name, method_name):
349
+ fn_str = str(JustTypeRef(class_name, bound_tp_params or ()))
350
+ if method_name != "__init__":
351
+ fn_str += f".{method_name}"
352
+ return fn_str, args
353
+ case MethodRef(_class_name, method_name):
354
+ slf, *args = args
355
+ slf = self(slf, parens=True)
356
+ match method_name:
357
+ case _ if method_name in UNARY_METHODS:
358
+ expr = f"{UNARY_METHODS[method_name]}{slf}"
359
+ return f"({expr})" if parens else expr
360
+ case _ if method_name in BINARY_METHODS:
361
+ expr = f"{slf} {BINARY_METHODS[method_name]} {self(args[0], parens=True, unwrap_lit=True)}"
362
+ return f"({expr})" if parens else expr
363
+ case "__getitem__":
364
+ return f"{slf}[{self(args[0], unwrap_lit=True)}]"
365
+ case "__call__":
366
+ return slf, args
367
+ case "__delitem__":
368
+ return f"del {slf}[{self(args[0], unwrap_lit=True)}]"
369
+ case "__setitem__":
370
+ return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}"
371
+ case _:
372
+ return f"{slf}.{method_name}", args
373
+ case ConstantRef(name):
374
+ return name
375
+ case ClassVariableRef(class_name, variable_name):
376
+ return f"{class_name}.{variable_name}"
377
+ case PropertyRef(_class_name, property_name):
378
+ return f"{self(args[0], parens=True)}.{property_name}"
379
+ assert_never(ref)
380
+
381
+ def _generate_name(self, typ: str) -> str:
382
+ self._gen_name_types[typ] += 1
383
+ return f"_{typ}_{self._gen_name_types[typ]}"
384
+
385
+ def _name_expr(self, tp_name: str, expr_str: str, copy_identifier: bool) -> str:
386
+ # tp_name =
387
+ # If the thing we are naming is already a variable, we don't need to name it
388
+ if expr_str.isidentifier():
389
+ if copy_identifier:
390
+ name = self._generate_name(tp_name)
391
+ self.statements.append(f"{name} = copy({expr_str})")
392
+ else:
393
+ name = expr_str
394
+ else:
395
+ name = self._generate_name(tp_name)
396
+ self.statements.append(f"{name} = {expr_str}")
397
+ return name
398
+
399
+
400
def _plot_line_length(expr: object):  # pragma: no cover
    """
    Plots the number of lines in ``str(expr)`` for different values of the naming heuristic's
    ``MAX_LINE_LENGTH`` and ``LINE_DIFFERENCE`` settings.

    The module-level settings are overridden while sampling and restored afterwards (even on error),
    so calling this does not change how later expressions are printed. Requires ``altair`` and ``pandas``.
    """
    global MAX_LINE_LENGTH, LINE_DIFFERENCE
    import altair as alt
    import pandas as pd

    original_settings = (MAX_LINE_LENGTH, LINE_DIFFERENCE)
    sizes = []
    try:
        for line_length in range(40, 180, 10):
            MAX_LINE_LENGTH = line_length
            for diff in range(0, 40, 5):
                LINE_DIFFERENCE = diff
                # Count lines (not whitespace-separated words), since the heuristic aims to reduce newlines
                n_lines = len(str(expr).splitlines())
                sizes.append((line_length, diff, n_lines))
    finally:
        MAX_LINE_LENGTH, LINE_DIFFERENCE = original_settings

    df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"])  # noqa: PD901

    return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")