egglog 11.2.0__cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Files changed (46)
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-314-x86_64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +734 -0
  4. egglog/builtins.py +1133 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +286 -0
  7. egglog/declarations.py +912 -0
  8. egglog/deconstruct.py +173 -0
  9. egglog/egraph.py +1875 -0
  10. egglog/egraph_state.py +680 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +67 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/__init__.py +3 -0
  26. egglog/exp/array_api.py +2019 -0
  27. egglog/exp/array_api_jit.py +51 -0
  28. egglog/exp/array_api_loopnest.py +74 -0
  29. egglog/exp/array_api_numba.py +69 -0
  30. egglog/exp/array_api_program_gen.py +510 -0
  31. egglog/exp/program_gen.py +425 -0
  32. egglog/exp/siu_examples.py +32 -0
  33. egglog/ipython_magic.py +41 -0
  34. egglog/pretty.py +509 -0
  35. egglog/py.typed +0 -0
  36. egglog/runtime.py +712 -0
  37. egglog/thunk.py +97 -0
  38. egglog/type_constraint_solver.py +113 -0
  39. egglog/version_compat.py +87 -0
  40. egglog/visualizer.css +1 -0
  41. egglog/visualizer.js +35777 -0
  42. egglog/visualizer_widget.py +39 -0
  43. egglog-11.2.0.dist-info/METADATA +74 -0
  44. egglog-11.2.0.dist-info/RECORD +46 -0
  45. egglog-11.2.0.dist-info/WHEEL +4 -0
  46. egglog-11.2.0.dist-info/licenses/LICENSE +21 -0
egglog/pretty.py ADDED
@@ -0,0 +1,509 @@
1
+ """
2
+ Pretty printing for declarations.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections import Counter, defaultdict
8
+ from dataclasses import dataclass, field
9
+ from typing import TYPE_CHECKING, TypeAlias
10
+
11
+ import black
12
+ from typing_extensions import assert_never
13
+
14
+ from .declarations import *
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import Mapping
18
+
19
+
20
__all__ = [
    "BINARY_METHODS",
    "UNARY_METHODS",
    "pretty_callable_ref",
    "pretty_decl",
]

# Tuning knobs for the heuristic in ``PrettyContext.__call__`` that decides whether a sub-expression referenced by
# several parents is long enough to be pulled out into its own variable.
MAX_LINE_LENGTH = 110
LINE_DIFFERENCE = 10

# Formatting mode passed to ``black.format_str`` by ``pretty_decl``. Note its line length is wider than
# MAX_LINE_LENGTH.
BLACK_MODE = black.Mode(line_length=180)

# Placeholder printed where arguments go. If the visualizer inlines the real arguments, they replace this character.
ARG_STR = "·"

# Unary dunder methods, mapped to the prefix operator used when pretty printing them.
UNARY_METHODS = {
    "__pos__": "+",
    "__neg__": "-",
    "__invert__": "~",
}

# Binary dunder methods that may be used as functions, mapped to the infix operator used when pretty printing them.
# See https://docs.python.org/3/reference/datamodel.html
BINARY_METHODS = {
    # Comparison
    "__lt__": "<",
    "__le__": "<=",
    "__eq__": "==",
    "__ne__": "!=",
    "__gt__": ">",
    "__ge__": ">=",
    # Numeric
    "__add__": "+",
    "__sub__": "-",
    "__mul__": "*",
    "__matmul__": "@",
    "__truediv__": "/",
    "__floordiv__": "//",
    "__mod__": "%",
    # TODO: Support divmod, with tuple return value
    # "__divmod__": "divmod",
    # TODO: Three arg power
    "__pow__": "**",
    "__lshift__": "<<",
    "__rshift__": ">>",
    "__and__": "&",
    "__xor__": "^",
    "__or__": "|",
}

# Every kind of declaration that this module can traverse or pretty print.
AllDecls: TypeAlias = (
    RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
)
71
+
72
+
73
def pretty_decl(
    decls: Declarations, decl: AllDecls, *, wrapping_fn: str | None = None, ruleset_name: str | None = None
) -> str:
    """
    Pretty print a declaration as a small Python program.

    Any statements needed to bind shared or mutated sub-expressions come first, and the expression for ``decl``
    itself is on the last line. The program is formatted with black; if black cannot parse it, the unformatted
    program is returned instead.

    :param wrapping_fn: If given, the final expression is wrapped in a call to this function.
    :param ruleset_name: Forwarded to the pretty printer; used when ``decl`` is a (combined) ruleset.
    """
    # First pass: count how many parents every sub-declaration has.
    traverser = TraverseContext(decls)
    traverser(decl, toplevel=True)
    # Second pass: print, hoisting shared/mutated sub-expressions into ``printer.statements``.
    printer = traverser.pretty()
    final_expr = printer(decl, ruleset_name=ruleset_name)
    final_expr = f"{wrapping_fn}({final_expr})" if wrapping_fn else final_expr
    source = "\n".join(printer.statements + [final_expr])
    # TODO: Try replacing with ruff for speed
    # https://github.com/amyreese/ruff-api
    try:
        formatted = black.format_str(source, mode=BLACK_MODE)
    except black.parsing.InvalidInput:
        return source
    return formatted.strip()
94
+
95
+
96
def pretty_callable_ref(
    decls: Declarations,
    ref: CallableRef,
    first_arg: ExprDecl | None = None,
    bound_tp_params: tuple[JustTypeRef, ...] | None = None,
    include_all_args: bool = False,
) -> str:
    """
    Pretty print a callable reference for the visualization.

    Forms that are not a plain ``f(x, ...)`` call (operators, item access, properties, ...) are printed with
    ``ARG_STR`` placeholders standing in for their arguments. Plain calls are printed as just the callee, unless
    ``include_all_args`` is set, in which case one placeholder per declared argument is appended.

    :param first_arg: Optional expression to use as the first argument (e.g. the receiver of a method).
    :param include_all_args: Print plain calls with placeholder args; used for set_cost so that ``cost(E(...))``
        shows up as a call.
    """
    placeholder = UnboundVarDecl(ARG_STR)
    # Three placeholders is the most any form that is not a generic function call consumes
    placeholders: list[ExprDecl] = [placeholder] * 3
    if first_arg:
        placeholders = [first_arg, *placeholders]
    printer = PrettyContext(decls, defaultdict(int))
    rendered = printer._call_inner(ref, placeholders, bound_tp_params=bound_tp_params, parens=False)
    if not isinstance(rendered, tuple):
        # The placeholders were already substituted into an operator/method form.
        return rendered
    # A plain call: the dummy args would just be called on the function, so they are dropped.
    fn_str = rendered[0]
    if not include_all_args:
        return fn_str
    signature = decls.get_callable_decl(ref).signature
    assert isinstance(signature, FunctionSignature)
    arg_strs = (printer(placeholder, parens=False, unwrap_lit=True) for _ in signature.arg_types)
    return f"{fn_str}({', '.join(arg_strs)})"
128
+
129
+
130
+ # TODO: Add a different pretty callable ref that doesn't fill in holes but instead returns the function
131
+ # so that things like Math.__add__ will be represented properly
132
+
133
+
134
@dataclass
class TraverseContext:
    """
    State for traversing expressions (or declarations that contain expressions), so we can know how many parents each
    expression has.

    Call the instance on a root declaration with ``toplevel=True``, then use ``pretty()`` to build a
    ``PrettyContext`` seeded with the collected parent counts.
    """

    decls: Declarations

    # All expressions we have seen (and whose children's parent counts we have already incremented)
    _seen: set[AllDecls] = field(default_factory=set)
    # The number of parents for each expression
    parents: Counter[AllDecls] = field(default_factory=Counter)

    def pretty(self) -> PrettyContext:
        """
        Create a pretty context from the state of this traverse context.
        """
        return PrettyContext(self.decls, self.parents)

    def __call__(self, decl: AllDecls, toplevel: bool = False) -> None:  # noqa: C901, PLR0912
        """
        Count ``decl`` as having one more parent and, the first time it is seen, recurse into its children.

        :param toplevel: If true, ``decl`` is the root, so its own parent count is not incremented.
        """
        if not toplevel:
            self.parents[decl] += 1
        # Only the first visit recurses, so repeated occurrences of a declaration do not re-count its children
        if decl in self._seen:
            return
        match decl:
            case RewriteDecl(_, lhs, rhs, conditions) | BiRewriteDecl(_, lhs, rhs, conditions):
                self(lhs)
                self(rhs)
                for cond in conditions:
                    self(cond)
            case RuleDecl(head, body, _):
                for action in head:
                    self(action)
                for fact in body:
                    self(fact)
            case SetDecl(_, lhs, rhs) | UnionDecl(_, lhs, rhs) | EqDecl(_, lhs, rhs):
                self(lhs)
                self(rhs)
            case LetDecl(_, d) | ExprActionDecl(d) | ExprFactDecl(d):
                # These wrap a typed expression; traverse the inner expression
                self(d.expr)
            case ChangeDecl(_, d, _) | SaturateDecl(d) | RepeatDecl(d, _) | ActionCommandDecl(d):
                self(d)
            case PanicDecl(_) | UnboundVarDecl(_) | LetRefDecl(_) | LitDecl(_) | PyObjectDecl(_):
                # Leaves: nothing to traverse
                pass
            case SequenceDecl(decls) | RulesetDecl(decls):
                for de in decls:
                    # Default rewrites are skipped (PrettyContext refuses to print them)
                    if isinstance(de, DefaultRewriteDecl):
                        continue
                    self(de)
            case CallDecl(ref, exprs, _):
                match ref:
                    # NOTE(review): this only matches a FunctionRef whose positional field holds an
                    # UnnamedFunctionRef; PrettyContext treats UnnamedFunctionRef as its own ref kind, so this
                    # branch may be unreachable — confirm against declarations.
                    case FunctionRef(UnnamedFunctionRef(_, res)):
                        self(res.expr)
                    case _:
                        for e in exprs:
                            self(e.expr)
            case RunDecl(_, until):
                if until:
                    for f in until:
                        self(f)
            case PartialCallDecl(c):
                self(c)
            case CombinedRulesetDecl(_):
                pass
            case DefaultRewriteDecl():
                pass
            case SetCostDecl(_, e, c):
                self(e)
                self(c)
            case _:
                assert_never(decl)

        self._seen.add(decl)
208
+
209
+
210
+ @dataclass
211
+ class PrettyContext:
212
+ """
213
+
214
+ We need to build up a list of all the expressions we are pretty printing, so that we can see who has parents and who is mutated
215
+ and create temp variables for them.
216
+
217
+ """
218
+
219
+ decls: Declarations
220
+ parents: Mapping[AllDecls, int]
221
+
222
+ # All the expressions we have saved as names
223
+ names: dict[AllDecls, str] = field(default_factory=dict)
224
+ # A list of statements assigning variables or calling destructive ops
225
+ statements: list[str] = field(default_factory=list)
226
+ # Mapping of type to the number of times we have generated a name for that type, used to generate unique names
227
+ _gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
228
+
229
+ def __call__(
230
+ self, decl: AllDecls, *, unwrap_lit: bool = False, parens: bool = False, ruleset_name: str | None = None
231
+ ) -> str:
232
+ if decl in self.names:
233
+ return self.names[decl]
234
+ expr, tp_name = self.uncached(decl, unwrap_lit=unwrap_lit, parens=parens, ruleset_name=ruleset_name)
235
+ # We use a heuristic to decide whether to name this sub-expression as a variable
236
+ # The rough goal is to reduce the number of newlines, given our line length of ~180
237
+ # We determine it's worth making a new line for this expression if the total characters
238
+ # it would take up is > than some constant (~ line length).
239
+ line_diff: int = len(expr) - LINE_DIFFERENCE
240
+ n_parents = self.parents[decl]
241
+ if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH:
242
+ self.names[decl] = expr_name = self._name_expr(tp_name, expr, copy_identifier=False)
243
+ return expr_name
244
+ return expr
245
+
246
+ def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: C901, PLR0911, PLR0912
247
+ """
248
+ Returns a tuple of a string value of the decleration and the "type" to use when create a memoized cached version
249
+ for de-duplication.
250
+ """
251
+ match decl:
252
+ case LitDecl(value):
253
+ match value:
254
+ case None:
255
+ return "Unit()", "Unit"
256
+ case bool(b):
257
+ return str(b) if unwrap_lit else f"Bool({b})", "Bool"
258
+ case int(i):
259
+ return str(i) if unwrap_lit else f"i64({i})", "i64"
260
+ case float(f):
261
+ return str(f) if unwrap_lit else f"f64({f})", "f64"
262
+ case str(s):
263
+ return repr(s) if unwrap_lit else f"String({s!r})", "String"
264
+ assert_never(value)
265
+ case UnboundVarDecl(name) | LetRefDecl(name):
266
+ return name, name
267
+ case CallDecl(_, _, _):
268
+ return self._call(decl, parens)
269
+ case PartialCallDecl(CallDecl(ref, typed_args, _)):
270
+ return self._pretty_partial(ref, [a.expr for a in typed_args], parens), "fn"
271
+ case PyObjectDecl(value):
272
+ return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
273
+ case ActionCommandDecl(action):
274
+ return self(action), "action"
275
+ case RewriteDecl(_, lhs, rhs, conditions) | BiRewriteDecl(_, lhs, rhs, conditions):
276
+ args = ", ".join(map(self, (rhs, *conditions)))
277
+ fn = "rewrite" if isinstance(decl, RewriteDecl) else "birewrite"
278
+ return f"{fn}({self(lhs)}).to({args})", "rewrite"
279
+ case RuleDecl(head, body, name):
280
+ l = ", ".join(map(self, body))
281
+ if name:
282
+ l += f", name={name}"
283
+ r = ", ".join(map(self, head))
284
+ return f"rule({l}).then({r})", "rule"
285
+ case SetDecl(_, lhs, rhs):
286
+ return f"set_({self(lhs)}).to({self(rhs)})", "action"
287
+ case UnionDecl(_, lhs, rhs):
288
+ return f"union({self(lhs)}).with_({self(rhs)})", "action"
289
+ case LetDecl(name, expr):
290
+ return f"let({name!r}, {self(expr.expr)})", "action"
291
+ case ExprActionDecl(expr):
292
+ return self(expr.expr), "action"
293
+ case ExprFactDecl(expr):
294
+ return self(expr.expr), "fact"
295
+ case ChangeDecl(_, expr, change):
296
+ return f"{change}({self(expr)})", "action"
297
+ case PanicDecl(s):
298
+ return f"panic({s!r})", "action"
299
+ case SetCostDecl(_, expr, cost):
300
+ return f"set_cost({self(expr)}, {self(cost, unwrap_lit=True)})", "action"
301
+ case EqDecl(_, left, right):
302
+ return f"eq({self(left)}).to({self(right)})", "fact"
303
+ case RulesetDecl(rules):
304
+ if ruleset_name:
305
+ return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
306
+ args = ", ".join(self(r) for r in rules if not isinstance(r, DefaultRewriteDecl))
307
+ return f"ruleset({args})", "ruleset"
308
+ case CombinedRulesetDecl(rulesets):
309
+ if ruleset_name:
310
+ rulesets = (*rulesets, f"name={ruleset_name!r})")
311
+ return f"unstable_combine_rulesets({', '.join(rulesets)})", "combined_ruleset"
312
+ case SaturateDecl(schedule):
313
+ return f"{self(schedule, parens=True)}.saturate()", "schedule"
314
+ case RepeatDecl(schedule, times):
315
+ return f"{self(schedule, parens=True)} * {times}", "schedule"
316
+ case SequenceDecl(schedules):
317
+ if len(schedules) == 2:
318
+ return f"{self(schedules[0], parens=True)} + {self(schedules[1], parens=True)}", "schedule"
319
+ args = ", ".join(map(self, schedules))
320
+ return f"seq({args})", "schedule"
321
+ case RunDecl(ruleset_name, until):
322
+ ruleset = self.decls._rulesets[ruleset_name]
323
+ ruleset_str = self(ruleset, ruleset_name=ruleset_name)
324
+ if not until:
325
+ return ruleset_str, "schedule"
326
+ args = ", ".join(map(self, until))
327
+ return f"run({ruleset_str}, {args})", "schedule"
328
+ case DefaultRewriteDecl():
329
+ msg = "default rewrites should not be pretty printed"
330
+ raise TypeError(msg)
331
+ assert_never(decl)
332
+
333
+ def _call(
334
+ self,
335
+ decl: CallDecl,
336
+ parens: bool,
337
+ ) -> tuple[str, str]:
338
+ """
339
+ Pretty print the call. Also returns if it was saved as a name.
340
+
341
+ :param parens: If true, wrap the call in parens if it is a binary method call.
342
+ """
343
+ args = [a.expr for a in decl.args]
344
+ ref = decl.callable
345
+ # Special case !=
346
+ if decl.callable == FunctionRef("!="):
347
+ l, r = self(args[0]), self(args[1])
348
+ return f"ne({l}).to({r})", "Unit"
349
+ signature = self.decls.get_callable_decl(ref).signature
350
+
351
+ # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
352
+ n_defaults = 0
353
+ # Dont try counting defaults for function application
354
+ if isinstance(signature, FunctionSignature):
355
+ for arg, default in zip(
356
+ reversed(args), reversed(signature.arg_defaults), strict=not signature.var_arg_type
357
+ ):
358
+ if arg != default:
359
+ break
360
+ n_defaults += 1
361
+ if n_defaults:
362
+ args = args[:-n_defaults]
363
+
364
+ # If this is a function application, the type is the first type arg of the function object
365
+ if signature == "fn-app":
366
+ tp_name = decl.args[0].tp.args[0].name
367
+ else:
368
+ assert isinstance(signature, FunctionSignature)
369
+ tp_name = signature.semantic_return_type.name
370
+ if isinstance(signature, FunctionSignature) and signature.mutates:
371
+ first_arg = args[0]
372
+ expr_str = self(first_arg)
373
+ # copy an identifier expression iff it has multiple parents (b/c then we can't mutate it directly)
374
+ has_multiple_parents = self.parents[first_arg] > 1
375
+ self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
376
+ # Set the first arg to be the name of the mutated arg and return the name
377
+ args[0] = LetRefDecl(expr_name)
378
+ else:
379
+ expr_name = None
380
+ res = self._call_inner(ref, args, decl.bound_tp_params, parens)
381
+ expr = (
382
+ (f"{res[0]}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})")
383
+ if isinstance(res, tuple)
384
+ else res
385
+ )
386
+ # If we have a name, then we mutated
387
+ if expr_name:
388
+ self.statements.append(expr)
389
+ return expr_name, tp_name
390
+ return expr, tp_name
391
+
392
+ def _call_inner( # noqa: C901, PLR0911, PLR0912
393
+ self,
394
+ ref: CallableRef,
395
+ args: list[ExprDecl],
396
+ bound_tp_params: tuple[JustTypeRef, ...] | None,
397
+ parens: bool,
398
+ ) -> tuple[str, list[ExprDecl]] | str:
399
+ """
400
+ Pretty print the call, returning either the full function call or a tuple of the function and the args.
401
+ """
402
+ match ref:
403
+ case FunctionRef(name):
404
+ return name, args
405
+ case ClassMethodRef(class_name, method_name):
406
+ tp_ref = JustTypeRef(class_name, bound_tp_params or ())
407
+ return f"{tp_ref}.{method_name}", args
408
+ case MethodRef(_class_name, method_name):
409
+ slf, *args = args
410
+ non_str_slf = slf
411
+ slf = self(slf, parens=True)
412
+ match method_name:
413
+ case _ if method_name in UNARY_METHODS:
414
+ expr = f"{UNARY_METHODS[method_name]}{slf}"
415
+ return f"({expr})" if parens else expr
416
+ case _ if method_name in BINARY_METHODS:
417
+ expr = f"{slf} {BINARY_METHODS[method_name]} {self(args[0], parens=True, unwrap_lit=True)}"
418
+ return f"({expr})" if parens else expr
419
+ case "__getitem__":
420
+ return f"{slf}[{self(args[0], unwrap_lit=True)}]"
421
+ case "__call__":
422
+ return slf, args
423
+ case "__delitem__":
424
+ return f"del {slf}[{self(args[0], unwrap_lit=True)}]"
425
+ case "__setitem__":
426
+ return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}"
427
+ case "__round__":
428
+ return "round", [non_str_slf, *args]
429
+ case _:
430
+ return f"{slf}.{method_name}", args
431
+ case ConstantRef(name):
432
+ return name
433
+ case ClassVariableRef(class_name, variable_name):
434
+ return f"{class_name}.{variable_name}"
435
+ case PropertyRef(_class_name, property_name):
436
+ return f"{self(args[0], parens=True)}.{property_name}"
437
+ case InitRef(class_name):
438
+ tp_ref = JustTypeRef(class_name, bound_tp_params or ())
439
+ return str(tp_ref), args
440
+ case UnnamedFunctionRef():
441
+ expr = self._pretty_function_body(ref, [])
442
+ return f"({expr})", args
443
+ assert_never(ref)
444
+
445
+ def _generate_name(self, typ: str) -> str:
446
+ self._gen_name_types[typ] += 1
447
+ return f"_{typ}_{self._gen_name_types[typ]}"
448
+
449
+ def _name_expr(self, tp_name: str, expr_str: str, copy_identifier: bool) -> str:
450
+ # tp_name =
451
+ # If the thing we are naming is already a variable, we don't need to name it
452
+ if expr_str.isidentifier():
453
+ if copy_identifier:
454
+ name = self._generate_name(tp_name)
455
+ self.statements.append(f"{name} = copy({expr_str})")
456
+ else:
457
+ name = expr_str
458
+ else:
459
+ name = self._generate_name(tp_name)
460
+ self.statements.append(f"{name} = {expr_str}")
461
+ return name
462
+
463
+ def _pretty_partial(self, ref: CallableRef, args: list[ExprDecl], parens: bool) -> str:
464
+ """
465
+ Returns a partial function call as a string.
466
+ """
467
+ match ref:
468
+ case FunctionRef(name):
469
+ fn = name
470
+ case UnnamedFunctionRef():
471
+ res = self._pretty_function_body(ref, args)
472
+ return f"({res})" if parens else res
473
+ case (
474
+ ClassMethodRef(class_name, method_name)
475
+ | MethodRef(class_name, method_name)
476
+ | PropertyRef(class_name, method_name)
477
+ ):
478
+ fn = f"{class_name}.{method_name}"
479
+ case InitRef(class_name):
480
+ fn = class_name
481
+ case ConstantRef(_):
482
+ msg = "Constants should not be callable"
483
+ raise NotImplementedError(msg)
484
+ case ClassVariableRef(_, _):
485
+ msg = "Class variables should not be callable"
486
+ raise NotADirectoryError(msg)
487
+ case _:
488
+ assert_never(ref)
489
+ if not args:
490
+ return fn
491
+ arg_strs = (
492
+ fn,
493
+ *(self(a, parens=False, unwrap_lit=True) for a in args),
494
+ )
495
+ return f"partial({', '.join(arg_strs)})"
496
+
497
+ def _pretty_function_body(self, fn: UnnamedFunctionRef, args: list[ExprDecl]) -> str:
498
+ """
499
+ Pretty print the body of a function, partially applying some arguments.
500
+ """
501
+ var_args = fn.args
502
+ replacements = {var_arg: TypedExprDecl(var_arg.tp, arg) for var_arg, arg in zip(var_args, args, strict=False)}
503
+ var_args = var_args[len(args) :]
504
+ res = replace_typed_expr(fn.res, replacements)
505
+ arg_names = fn.args[len(args) :]
506
+ prefix = "lambda"
507
+ if arg_names:
508
+ prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
509
+ return f"{prefix}: {self(res.expr)}"
egglog/py.typed ADDED
File without changes