egglog 7.0.0__cp310-none-win_amd64.whl → 7.2.0__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. See the package's registry page for more details.

egglog/egraph_state.py CHANGED
@@ -19,7 +19,7 @@ from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
19
19
  if TYPE_CHECKING:
20
20
  from collections.abc import Iterable
21
21
 
22
- __all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"]
22
+ __all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT", "_rule_var_name"]
23
23
 
24
24
  # Create a global sort for python objects, so we can store them without an e-graph instance
25
25
  # Needed when serializing commands to egg commands when creating modules
@@ -87,17 +87,27 @@ class EGraphState:
87
87
  """
88
88
  Registers a ruleset if it's not already registered.
89
89
  """
90
- if name not in self.rulesets:
91
- if name:
92
- self.egraph.run_program(bindings.AddRuleset(name))
93
- rules = self.rulesets[name] = set()
94
- else:
95
- rules = self.rulesets[name]
96
- for rule in self.__egg_decls__._rulesets[name].rules:
97
- if rule in rules:
98
- continue
99
- self.egraph.run_program(self.command_to_egg(rule, name))
100
- rules.add(rule)
90
+ match self.__egg_decls__._rulesets[name]:
91
+ case RulesetDecl(rules):
92
+ if name not in self.rulesets:
93
+ if name:
94
+ self.egraph.run_program(bindings.AddRuleset(name))
95
+ added_rules = self.rulesets[name] = set()
96
+ else:
97
+ added_rules = self.rulesets[name]
98
+ for rule in rules:
99
+ if rule in added_rules:
100
+ continue
101
+ cmd = self.command_to_egg(rule, name)
102
+ self.egraph.run_program(cmd)
103
+ added_rules.add(rule)
104
+ case CombinedRulesetDecl(rulesets):
105
+ if name in self.rulesets:
106
+ return
107
+ self.rulesets[name] = set()
108
+ for ruleset in rulesets:
109
+ self.ruleset_to_egg(ruleset)
110
+ self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
101
111
 
102
112
  def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
103
113
  match cmd:
@@ -106,8 +116,8 @@ class EGraphState:
106
116
  case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
107
117
  self.type_ref_to_egg(tp)
108
118
  rewrite = bindings.Rewrite(
109
- self.expr_to_egg(lhs),
110
- self.expr_to_egg(rhs),
119
+ self._expr_to_egg(lhs),
120
+ self._expr_to_egg(rhs),
111
121
  [self.fact_to_egg(c) for c in conditions],
112
122
  )
113
123
  return (
@@ -121,6 +131,16 @@ class EGraphState:
121
131
  [self.fact_to_egg(f) for f in body],
122
132
  )
123
133
  return bindings.RuleCommand(name or "", ruleset, rule)
134
+ case DefaultRewriteDecl(ref, expr):
135
+ decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
136
+ sig = decl.signature
137
+ assert isinstance(sig, FunctionSignature)
138
+ args = tuple(
139
+ TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name)))
140
+ for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
141
+ )
142
+ rewrite_decl = RewriteDecl(sig.semantic_return_type.to_just(), CallDecl(ref, args), expr, (), False)
143
+ return self.command_to_egg(rewrite_decl, ruleset)
124
144
  case _:
125
145
  assert_never(cmd)
126
146
 
@@ -130,13 +150,13 @@ class EGraphState:
130
150
  return bindings.Let(name, self.typed_expr_to_egg(typed_expr))
131
151
  case SetDecl(tp, call, rhs):
132
152
  self.type_ref_to_egg(tp)
133
- call_ = self.expr_to_egg(call)
134
- return bindings.Set(call_.name, call_.args, self.expr_to_egg(rhs))
153
+ call_ = self._expr_to_egg(call)
154
+ return bindings.Set(call_.name, call_.args, self._expr_to_egg(rhs))
135
155
  case ExprActionDecl(typed_expr):
136
156
  return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
137
157
  case ChangeDecl(tp, call, change):
138
158
  self.type_ref_to_egg(tp)
139
- call_ = self.expr_to_egg(call)
159
+ call_ = self._expr_to_egg(call)
140
160
  egg_change: bindings._Change
141
161
  match change:
142
162
  case "delete":
@@ -148,7 +168,7 @@ class EGraphState:
148
168
  return bindings.Change(egg_change, call_.name, call_.args)
149
169
  case UnionDecl(tp, lhs, rhs):
150
170
  self.type_ref_to_egg(tp)
151
- return bindings.Union(self.expr_to_egg(lhs), self.expr_to_egg(rhs))
171
+ return bindings.Union(self._expr_to_egg(lhs), self._expr_to_egg(rhs))
152
172
  case PanicDecl(name):
153
173
  return bindings.Panic(name)
154
174
  case _:
@@ -158,7 +178,7 @@ class EGraphState:
158
178
  match fact:
159
179
  case EqDecl(tp, exprs):
160
180
  self.type_ref_to_egg(tp)
161
- return bindings.Eq([self.expr_to_egg(e) for e in exprs])
181
+ return bindings.Eq([self._expr_to_egg(e) for e in exprs])
162
182
  case ExprFactDecl(typed_expr):
163
183
  return bindings.Fact(self.typed_expr_to_egg(typed_expr))
164
184
  case _:
@@ -184,14 +204,16 @@ class EGraphState:
184
204
  )
185
205
  case FunctionDecl():
186
206
  if not decl.builtin:
207
+ signature = decl.signature
208
+ assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
187
209
  egg_fn_decl = bindings.FunctionDecl(
188
210
  egg_name,
189
211
  bindings.Schema(
190
- [self.type_ref_to_egg(a.to_just()) for a in decl.arg_types],
191
- self.type_ref_to_egg(decl.semantic_return_type.to_just()),
212
+ [self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
213
+ self.type_ref_to_egg(signature.semantic_return_type.to_just()),
192
214
  ),
193
- self.expr_to_egg(decl.default) if decl.default else None,
194
- self.expr_to_egg(decl.merge) if decl.merge else None,
215
+ self._expr_to_egg(decl.default) if decl.default else None,
216
+ self._expr_to_egg(decl.merge) if decl.merge else None,
195
217
  [self.action_to_egg(a) for a in decl.on_merge],
196
218
  decl.cost,
197
219
  decl.unextractable,
@@ -212,25 +234,30 @@ class EGraphState:
212
234
  decl = self.__egg_decls__._classes[ref.name]
213
235
  self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
214
236
  if not decl.builtin or ref.args:
215
- self.egraph.run_program(
216
- bindings.Sort(
217
- egg_name,
218
- (
219
- (
220
- self.type_ref_to_egg(JustTypeRef(ref.name)),
221
- [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args],
222
- )
223
- if ref.args
224
- else None
225
- ),
226
- )
227
- )
237
+ if ref.args:
238
+ if ref.name == "UnstableFn":
239
+ # UnstableFn is a special case, where the rest of args are collected into a call
240
+ type_args: list[bindings._Expr] = [
241
+ bindings.Call(
242
+ self.type_ref_to_egg(ref.args[1]),
243
+ [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
244
+ ),
245
+ bindings.Var(self.type_ref_to_egg(ref.args[0])),
246
+ ]
247
+ else:
248
+ type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
249
+ args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
250
+ else:
251
+ args = None
252
+ self.egraph.run_program(bindings.Sort(egg_name, args))
228
253
  # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
229
254
  # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
230
255
  # even if you never use that function.
231
256
  if decl.builtin:
232
257
  for method in decl.class_methods:
233
258
  self.callable_ref_to_egg(ClassMethodRef(ref.name, method))
259
+ if decl.init:
260
+ self.callable_ref_to_egg(InitRef(ref.name))
234
261
 
235
262
  return egg_name
236
263
 
@@ -247,15 +274,15 @@ class EGraphState:
247
274
 
248
275
  def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
249
276
  self.type_ref_to_egg(typed_expr_decl.tp)
250
- return self.expr_to_egg(typed_expr_decl.expr)
277
+ return self._expr_to_egg(typed_expr_decl.expr)
251
278
 
252
279
  @overload
253
- def expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
280
+ def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
254
281
 
255
282
  @overload
256
- def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
283
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
257
284
 
258
- def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
285
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
259
286
  """
260
287
  Convert an ExprDecl to an egg expression.
261
288
 
@@ -292,6 +319,9 @@ class EGraphState:
292
319
  res = bindings.Call(egg_fn, egg_args)
293
320
  case PyObjectDecl(value):
294
321
  res = GLOBAL_PY_OBJECT_SORT.store(value)
322
+ case PartialCallDecl(call_decl):
323
+ egg_fn_call = self._expr_to_egg(call_decl)
324
+ res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
295
325
  case _:
296
326
  assert_never(expr_decl.expr)
297
327
 
@@ -338,6 +368,8 @@ def _generate_callable_egg_name(ref: CallableRef) -> str:
338
368
  | PropertyRef(cls_name, name)
339
369
  ):
340
370
  return f"{cls_name}_{name}"
371
+ case InitRef(cls_name):
372
+ return f"{cls_name}___init__"
341
373
  case _:
342
374
  assert_never(ref)
343
375
 
@@ -371,26 +403,50 @@ class FromEggState:
371
403
  if term.name == "py-object":
372
404
  call = bindings.termdag_term_to_expr(self.termdag, term)
373
405
  expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
406
+ if term.name == "unstable-fn":
407
+ # Get function name
408
+ fn_term, *arg_terms = term.args
409
+ fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
410
+ assert isinstance(fn_value.expr, LitDecl)
411
+ fn_name = fn_value.expr.value
412
+ assert isinstance(fn_name, str)
413
+
414
+ # Resolve what types the partiallied applied args are
415
+ assert tp.name == "UnstableFn"
416
+ call_decl = self.from_call(tp.args[0], bindings.TermApp(fn_name, arg_terms))
417
+ expr_decl = PartialCallDecl(call_decl)
374
418
  else:
375
419
  expr_decl = self.from_call(tp, term)
376
420
  else:
377
421
  assert_never(term)
378
422
  return TypedExprDecl(tp, expr_decl)
379
423
 
380
- def from_call(self, tp: JustTypeRef, term: bindings.TermApp) -> CallDecl:
424
+ def from_call(
425
+ self,
426
+ tp: JustTypeRef,
427
+ term: bindings.TermApp, # additional_arg_tps: tuple[JustTypeRef, ...]
428
+ ) -> CallDecl:
381
429
  """
382
430
  Convert a call to a CallDecl.
383
431
 
384
432
  There could be Python call refs which match the call, so we need to find the correct one.
433
+
434
+ The additional_arg_tps are known types for arguments that come after the term args, used to infer types
435
+ for partially applied functions, where we know the types of the later args, but not of the earlier ones where
436
+ we have values for.
385
437
  """
386
438
  # Find the first callable ref that matches the call
387
439
  for callable_ref in self.state.egg_fn_to_callable_refs[term.name]:
388
440
  # If this is a classmethod, we might need the type params that were bound for this type
389
441
  # This could be multiple types if the classmethod is ambiguous, like map create.
390
442
  possible_types: Iterable[JustTypeRef | None]
391
- fn_decl = self.decls.get_callable_decl(callable_ref).to_function_decl()
392
- if isinstance(callable_ref, ClassMethodRef):
393
- possible_types = self.state._get_possible_types(callable_ref.class_name)
443
+ signature = self.decls.get_callable_decl(callable_ref).to_function_decl().signature
444
+ assert isinstance(signature, FunctionSignature)
445
+ if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef):
446
+ # Need OR in case we have class method whose class whas never added as a sort, which would happen
447
+ # if the class method didn't return that type and no other function did. In this case, we don't need
448
+ # to care about the type vars and we we don't need to bind any possible type.
449
+ possible_types = self.state._get_possible_types(callable_ref.class_name) or [None]
394
450
  cls_name = callable_ref.class_name
395
451
  else:
396
452
  possible_types = [None]
@@ -399,19 +455,33 @@ class FromEggState:
399
455
  tcs = TypeConstraintSolver(self.decls)
400
456
  if possible_type and possible_type.args:
401
457
  tcs.bind_class(possible_type)
402
-
403
458
  try:
404
459
  arg_types, bound_tp_params = tcs.infer_arg_types(
405
- fn_decl.arg_types, fn_decl.semantic_return_type, fn_decl.var_arg_type, tp, cls_name
460
+ signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
406
461
  )
407
462
  except TypeConstraintError:
408
463
  continue
409
- args: list[TypedExprDecl] = []
410
- for a, tp in zip(term.args, arg_types, strict=False):
411
- try:
412
- res = self.cache[a]
413
- except KeyError:
414
- res = self.cache[a] = self.from_expr(tp, self.termdag.nodes[a])
415
- args.append(res)
416
- return CallDecl(callable_ref, tuple(args), bound_tp_params)
464
+ args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
465
+
466
+ return CallDecl(
467
+ callable_ref,
468
+ args,
469
+ # Don't include bound type params if this is just a method, we only needed them for type resolution
470
+ # but dont need to store them
471
+ bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
472
+ )
417
473
  raise ValueError(f"Could not find callable ref for call {term}")
474
+
475
+ def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
476
+ try:
477
+ return self.cache[term_id]
478
+ except KeyError:
479
+ res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id])
480
+ return res
481
+
482
+
483
+ def _rule_var_name(s: str) -> str:
484
+ """
485
+ Create a hidden variable name, for rewrites, so that let bindings or function won't conflict with it
486
+ """
487
+ return f"__var__{s}"
@@ -0,0 +1,50 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ """
3
+ Higher Order Functions
4
+ ======================
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING
10
+
11
+ from egglog import *
12
+
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Callable
15
+
16
+
17
+ class Math(Expr):
18
+ def __init__(self, i: i64Like) -> None: ...
19
+
20
+ def __add__(self, other: Math) -> Math: ...
21
+
22
+
23
+ class MathList(Expr):
24
+ def __init__(self) -> None: ...
25
+
26
+ def append(self, i: Math) -> MathList: ...
27
+
28
+ def map(self, f: Callable[[Math], Math]) -> MathList: ...
29
+
30
+
31
+ @ruleset
32
+ def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]): # noqa: ANN201
33
+ yield rewrite(Math(i) + Math(j)).to(Math(i + j))
34
+ yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
35
+ yield rewrite(MathList().map(f)).to(MathList())
36
+
37
+
38
+ @function(ruleset=math_ruleset)
39
+ def increment_by_one(x: Math) -> Math:
40
+ return x + Math(1)
41
+
42
+
43
+ egraph = EGraph()
44
+ x = egraph.let("x", MathList().append(Math(1)).append(Math(2)))
45
+ y = egraph.let("y", x.map(increment_by_one))
46
+ egraph.run(math_ruleset.saturate())
47
+
48
+ egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
49
+
50
+ egraph
egglog/exp/array_api.py CHANGED
@@ -18,7 +18,7 @@ from egglog.runtime import RuntimeExpr
18
18
  from .program_gen import *
19
19
 
20
20
  if TYPE_CHECKING:
21
- from collections.abc import Iterator
21
+ from collections.abc import Callable, Iterator
22
22
  from types import ModuleType
23
23
 
24
24
  # Pretend that exprs are numbers b/c sklearn does isinstance checks
@@ -257,7 +257,7 @@ class TupleInt(Expr):
257
257
 
258
258
  def __getitem__(self, i: Int) -> Int: ...
259
259
 
260
- def product(self) -> Int: ...
260
+ def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
261
261
 
262
262
 
263
263
  converter(
@@ -272,7 +272,7 @@ converter(
272
272
 
273
273
 
274
274
  @array_api_ruleset.register
275
- def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
275
+ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64, f: Callable[[Int, Int], Int]):
276
276
  return [
277
277
  rewrite(ti + TupleInt.EMPTY).to(ti),
278
278
  rewrite(TupleInt(i).length()).to(Int(1)),
@@ -281,10 +281,10 @@ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
281
281
  rewrite((TupleInt(i) + ti)[Int(0)]).to(i),
282
282
  # Rule for indexing > 0
283
283
  rule(eq(i).to((TupleInt(i2) + ti)[Int(k)]), k > 0).then(union(i).with_(ti[Int(k - 1)])),
284
- # Product
285
- rewrite(TupleInt(i).product()).to(i),
286
- rewrite((TupleInt(i) + ti).product()).to(i * ti.product()),
287
- rewrite(TupleInt.EMPTY.product()).to(Int(1)),
284
+ # fold
285
+ rewrite(TupleInt.EMPTY.fold(i, f)).to(i),
286
+ rewrite(TupleInt(i2).fold(i, f)).to(f(i, i2)),
287
+ rewrite((TupleInt(i2) + ti).fold(i, f)).to(ti.fold(f(i, i2), f)),
288
288
  ]
289
289
 
290
290
 
@@ -882,7 +882,10 @@ converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
882
882
 
883
883
  @function
884
884
  def asarray(
885
- a: NDArray, dtype: OptionalDType = OptionalDType.none, copy: OptionalBool = OptionalBool.none
885
+ a: NDArray,
886
+ dtype: OptionalDType = OptionalDType.none,
887
+ copy: OptionalBool = OptionalBool.none,
888
+ device: OptionalDevice = OptionalDevice.none,
886
889
  ) -> NDArray: ...
887
890
 
888
891
 
@@ -1346,7 +1349,7 @@ def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
1346
1349
 
1347
1350
  @array_api_ruleset.register
1348
1351
  def _size(x: NDArray):
1349
- yield rewrite(x.size).to(x.shape.product())
1352
+ yield rewrite(x.size).to(x.shape.fold(Int(1), Int.__mul__))
1350
1353
 
1351
1354
 
1352
1355
  @overload
egglog/pretty.py CHANGED
@@ -66,7 +66,7 @@ UNARY_METHODS = {
66
66
  "__invert__": "~",
67
67
  }
68
68
 
69
- AllDecls: TypeAlias = RulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
69
+ AllDecls: TypeAlias = RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
70
70
 
71
71
 
72
72
  def pretty_decl(
@@ -106,7 +106,7 @@ def pretty_callable_ref(
106
106
  """
107
107
  # Pass in three dummy args, which are the max used for any operation that
108
108
  # is not a generic function call
109
- args: list[ExprDecl] = [LitDecl(ARG_STR)] * 3
109
+ args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
110
110
  if first_arg:
111
111
  args.insert(0, first_arg)
112
112
  res = PrettyContext(decls, defaultdict(lambda: 0))._call_inner(
@@ -117,6 +117,10 @@ def pretty_callable_ref(
117
117
  return res[0] if isinstance(res, tuple) else res
118
118
 
119
119
 
120
+ # TODO: Add a different pretty callable ref that doesnt fill in wholes but instead returns the function
121
+ # so that things like Math.__add__ will be represented properly
122
+
123
+
120
124
  @dataclass
121
125
  class TraverseContext:
122
126
  """
@@ -162,6 +166,8 @@ class TraverseContext:
162
166
  pass
163
167
  case EqDecl(_, decls) | SequenceDecl(decls) | RulesetDecl(decls):
164
168
  for de in decls:
169
+ if isinstance(de, DefaultRewriteDecl):
170
+ continue
165
171
  self(de)
166
172
  case CallDecl(_, exprs, _):
167
173
  for e in exprs:
@@ -170,6 +176,12 @@ class TraverseContext:
170
176
  if until:
171
177
  for f in until:
172
178
  self(f)
179
+ case PartialCallDecl(c):
180
+ self(c)
181
+ case CombinedRulesetDecl(_):
182
+ pass
183
+ case DefaultRewriteDecl():
184
+ pass
173
185
  case _:
174
186
  assert_never(decl)
175
187
 
@@ -231,6 +243,9 @@ class PrettyContext:
231
243
  return name, name
232
244
  case CallDecl(_, _, _):
233
245
  return self._call(decl, parens)
246
+ case PartialCallDecl(CallDecl(ref, typed_args, _)):
247
+ arg_strs = (_pretty_callable(ref), *(self(a.expr, parens=False, unwrap_lit=True) for a in typed_args))
248
+ return f"UnstableFn({', '.join(arg_strs)})", "fn"
234
249
  case PyObjectDecl(value):
235
250
  return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
236
251
  case ActionCommandDecl(action):
@@ -265,8 +280,12 @@ class PrettyContext:
265
280
  case RulesetDecl(rules):
266
281
  if ruleset_name:
267
282
  return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
268
- args = ", ".join(map(self, rules))
283
+ args = ", ".join(self(r) for r in rules if not isinstance(r, DefaultRewriteDecl))
269
284
  return f"ruleset({args})", "ruleset"
285
+ case CombinedRulesetDecl(rulesets):
286
+ if ruleset_name:
287
+ rulesets = (*rulesets, f"name={ruleset_name!r})")
288
+ return f"unstable_combine_rulesets({', '.join(rulesets)})", "combined_ruleset"
270
289
  case SaturateDecl(schedule):
271
290
  return f"{self(schedule, parens=True)}.saturate()", "schedule"
272
291
  case RepeatDecl(schedule, times):
@@ -283,6 +302,9 @@ class PrettyContext:
283
302
  return ruleset_str, "schedule"
284
303
  args = ", ".join(map(self, until))
285
304
  return f"run({ruleset_str}, {args})", "schedule"
305
+ case DefaultRewriteDecl():
306
+ msg = "default rewrites should not be pretty printed"
307
+ raise TypeError(msg)
286
308
  assert_never(decl)
287
309
 
288
310
  def _call(
@@ -302,19 +324,28 @@ class PrettyContext:
302
324
  l, r = self(args[0]), self(args[1])
303
325
  return f"ne({l}).to({r})", "Unit"
304
326
  function_decl = self.decls.get_callable_decl(ref).to_function_decl()
327
+ signature = function_decl.signature
328
+
305
329
  # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
306
330
  n_defaults = 0
307
- for arg, default in zip(
308
- reversed(args), reversed(function_decl.arg_defaults), strict=not function_decl.var_arg_type
309
- ):
310
- if arg != default:
311
- break
312
- n_defaults += 1
331
+ # Dont try counting defaults for function application
332
+ if isinstance(signature, FunctionSignature):
333
+ for arg, default in zip(
334
+ reversed(args), reversed(signature.arg_defaults), strict=not signature.var_arg_type
335
+ ):
336
+ if arg != default:
337
+ break
338
+ n_defaults += 1
313
339
  if n_defaults:
314
340
  args = args[:-n_defaults]
315
341
 
316
- tp_name = function_decl.semantic_return_type.name
317
- if function_decl.mutates:
342
+ # If this is a function application, the type is the first type arg of the function object
343
+ if signature == "fn-app":
344
+ tp_name = decl.args[0].tp.args[0].name
345
+ else:
346
+ assert isinstance(signature, FunctionSignature)
347
+ tp_name = signature.semantic_return_type.name
348
+ if isinstance(signature, FunctionSignature) and signature.mutates:
318
349
  first_arg = args[0]
319
350
  expr_str = self(first_arg)
320
351
  # copy an identifier expression iff it has multiple parents (b/c then we can't mutate it directly)
@@ -346,10 +377,8 @@ class PrettyContext:
346
377
  case FunctionRef(name):
347
378
  return name, args
348
379
  case ClassMethodRef(class_name, method_name):
349
- fn_str = str(JustTypeRef(class_name, bound_tp_params or ()))
350
- if method_name != "__init__":
351
- fn_str += f".{method_name}"
352
- return fn_str, args
380
+ tp_ref = JustTypeRef(class_name, bound_tp_params or ())
381
+ return f"{tp_ref}.{method_name}", args
353
382
  case MethodRef(_class_name, method_name):
354
383
  slf, *args = args
355
384
  slf = self(slf, parens=True)
@@ -376,6 +405,9 @@ class PrettyContext:
376
405
  return f"{class_name}.{variable_name}"
377
406
  case PropertyRef(_class_name, property_name):
378
407
  return f"{self(args[0], parens=True)}.{property_name}"
408
+ case InitRef(class_name):
409
+ tp_ref = JustTypeRef(class_name, bound_tp_params or ())
410
+ return str(tp_ref), args
379
411
  assert_never(ref)
380
412
 
381
413
  def _generate_name(self, typ: str) -> str:
@@ -397,6 +429,30 @@ class PrettyContext:
397
429
  return name
398
430
 
399
431
 
432
+ def _pretty_callable(ref: CallableRef) -> str:
433
+ """
434
+ Returns a function call as a string.
435
+ """
436
+ match ref:
437
+ case FunctionRef(name):
438
+ return name
439
+ case (
440
+ ClassMethodRef(class_name, method_name)
441
+ | MethodRef(class_name, method_name)
442
+ | PropertyRef(class_name, method_name)
443
+ ):
444
+ return f"{class_name}.{method_name}"
445
+ case InitRef(class_name):
446
+ return class_name
447
+ case ConstantRef(_):
448
+ msg = "Constants should not be callable"
449
+ raise NotImplementedError(msg)
450
+ case ClassVariableRef(_, _):
451
+ msg = "Class variables should not be callable"
452
+ raise NotADirectoryError(msg)
453
+ assert_never(ref)
454
+
455
+
400
456
  def _plot_line_length(expr: object): # pragma: no cover
401
457
  """
402
458
  Plots the number of line lengths based on different max lengths