egglog 7.1.0__cp312-none-win_amd64.whl → 8.0.0__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. See the package's registry page for more details.

@@ -0,0 +1,145 @@
1
+ """
2
+ In progress module
3
+
4
+ https://gist.github.com/sklam/5e5737137d48d6e5b816d14a90076f1d
5
+
6
+ """
7
+
8
+ # %%
9
+ # mypy: disable-error-code="empty-body"
10
+ from __future__ import annotations
11
+
12
+ from egglog import *
13
+ from egglog.exp.array_api import *
14
+
15
+
16
+ class ShapeAPI(Expr):
17
+ @method(unextractable=True)
18
+ def __init__(self, dims: TupleIntLike) -> None: ...
19
+
20
+ @method(unextractable=True)
21
+ def deselect(self, axis: TupleIntLike) -> ShapeAPI: ...
22
+
23
+ @method(unextractable=True)
24
+ def select(self, axis: TupleIntLike) -> ShapeAPI: ...
25
+
26
+ @method(unextractable=True)
27
+ def to_tuple(self) -> TupleInt: ...
28
+
29
+
30
+ @array_api_ruleset.register
31
+ def shape_api_ruleset(dims: TupleInt, axis: TupleInt): # noqa: ANN201
32
+ s = ShapeAPI(dims)
33
+ yield rewrite(s.deselect(axis)).to(
34
+ ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
35
+ )
36
+ yield rewrite(s.select(axis)).to(
37
+ ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: axis.contains(i)).map(lambda i: dims[i]))
38
+ )
39
+ yield rewrite(s.to_tuple()).to(dims)
40
+
41
+
42
+ class OptionalLoopNestAPI(Expr):
43
+ def __init__(self, value: LoopNestAPI) -> None: ...
44
+
45
+ NONE: ClassVar[OptionalLoopNestAPI]
46
+
47
+ def unwrap(self) -> LoopNestAPI: ...
48
+
49
+
50
+ class LoopNestAPI(Expr):
51
+ def __init__(self, dim: Int, inner: OptionalLoopNestAPI) -> None: ...
52
+
53
+ @classmethod
54
+ def from_tuple(cls, args: TupleInt) -> OptionalLoopNestAPI: ...
55
+
56
+ @method(preserve=True)
57
+ def __iter__(self) -> Iterator[TupleInt]:
58
+ return iter(self.indices)
59
+
60
+ @property
61
+ def indices(self) -> TupleTupleInt: ...
62
+
63
+ def get_dims(self) -> TupleInt: ...
64
+
65
+ def fold(self, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArrayLike) -> NDArray: ...
66
+
67
+
68
+ @function
69
+ def tuple_tuple_int_reduce_ndarray(
70
+ xs: TupleTupleInt, fn: Callable[[NDArray, TupleInt], NDArray], init: NDArray
71
+ ) -> NDArray: ...
72
+
73
+
74
+ @function
75
+ def tuple_int_map_tuple_int(xs: TupleInt, fn: Callable[[Int], TupleInt]) -> TupleTupleInt: ...
76
+
77
+
78
+ @function
79
+ def tuple_tuple_int_product(xs: TupleTupleInt) -> TupleTupleInt: ...
80
+
81
+
82
+ @array_api_ruleset.register
83
+ def _loopnest_api_ruleset(
84
+ head: Int,
85
+ tail: TupleInt,
86
+ lna: LoopNestAPI,
87
+ fn: Callable[[NDArray, TupleInt], NDArray],
88
+ init: NDArray,
89
+ dim: Int,
90
+ idx_fn: Callable[[Int], Int],
91
+ i: i64,
92
+ ):
93
+ # from_tuple
94
+ yield rewrite(LoopNestAPI.from_tuple(TupleInt(0, idx_fn))).to(OptionalLoopNestAPI.NONE)
95
+ yield rewrite(LoopNestAPI.from_tuple(TupleInt(Int(i), idx_fn))).to(
96
+ OptionalLoopNestAPI(
97
+ LoopNestAPI(idx_fn(Int(0)), LoopNestAPI.from_tuple(TupleInt(Int(i - 1), lambda i: idx_fn(i + 1))))
98
+ ),
99
+ ne(i).to(i64(0)),
100
+ )
101
+ # reduce
102
+ yield rewrite(lna.fold(fn, init)).to(tuple_tuple_int_reduce_ndarray(lna.indices, fn, init))
103
+ # get_dims
104
+ yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI.NONE).get_dims()).to(TupleInt.single(dim))
105
+ yield rewrite(LoopNestAPI(dim, OptionalLoopNestAPI(lna)).get_dims()).to(TupleInt.single(dim) + lna.get_dims())
106
+ # indices
107
+ yield rewrite(lna.indices).to(tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range)))
108
+
109
+
110
+ @function(ruleset=array_api_ruleset, unextractable=True)
111
+ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray: # noqa: N803
112
+ # peel off the outer shape for result array
113
+ outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
114
+ # get only the inner shape for reduction
115
+ reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()
116
+
117
+ return NDArray(
118
+ outshape,
119
+ X.dtype,
120
+ lambda k: sqrt(
121
+ LoopNestAPI.from_tuple(reduce_axis)
122
+ .unwrap()
123
+ .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
124
+ ).to_value(),
125
+ )
126
+
127
+
128
+ # %%
129
+ egraph = EGraph(save_egglog_string=True)
130
+
131
+ X = NDArray.var("X")
132
+ assume_shape(X, (3, 2, 3, 4))
133
+ val = linalg_norm(X, (0, 1))
134
+ egraph.register(val.shape)
135
+ egraph.run(array_api_ruleset.saturate())
136
+ egraph.extract_multiple(val.shape, 10)
137
+
138
+ # %%
139
+ egraph = EGraph()
140
+ egraph.register(val.shape[2])
141
+ egraph.run(array_api_ruleset.saturate())
142
+ egraph.display(split_functions=[Int, TRUE, FALSE], n_inline_leaves=2)
143
+
144
+
145
+ # %%
@@ -67,6 +67,6 @@ def _unique_inverse(x: NDArray, i: Int):
67
67
  return [
68
68
  # Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values
69
69
  rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i)), subsume=True).to(
70
- x == NDArray.scalar(unique_values(x).index(TupleInt(i)))
70
+ x == NDArray.scalar(unique_values(x).index((i,)))
71
71
  ),
72
72
  ]
@@ -56,22 +56,42 @@ def _int_program(i64_: i64, i: Int, j: Int):
56
56
 
57
57
 
58
58
  @function
59
- def tuple_int_program(x: TupleInt) -> Program: ...
59
+ def tuple_int_program(x: TupleInt) -> Program:
60
+ ...
61
+ # Could be rewritten as a fold, but we don't support generic folds yet
62
+ # return x.fold(Program("("), lambda acc, i: acc + ", " + int_program(i)) + ")"
60
63
 
61
64
 
62
65
  @function
63
- def tuple_int_program_inner(x: TupleInt) -> Program: ...
66
+ def tuple_int_program_inner(x: TupleInt) -> Program:
67
+ """
68
+ Returns the tuple w/ out the parenthesis
69
+ """
64
70
 
65
71
 
66
72
  @array_api_program_gen_ruleset.register
67
- def _tuple_int_program(i: Int, j: Int, ti: TupleInt, ti1: TupleInt, ti2: TupleInt):
73
+ def _tuple_int_program(i: Int, ti: TupleInt, k: i64, idx_fn: Callable[[Int], Int], vec_int: Vec[Int]):
68
74
  yield rewrite(int_program(ti[i])).to(tuple_int_program(ti) + "[" + int_program(i) + "]")
75
+ yield rewrite(int_program(ti.length())).to(Program("len(") + tuple_int_program(ti) + ")")
69
76
 
70
77
  yield rewrite(tuple_int_program(ti)).to(Program("(") + tuple_int_program_inner(ti) + ")")
71
- yield rewrite(tuple_int_program_inner(ti1 + ti2)).to(
72
- tuple_int_program_inner(ti1) + " " + tuple_int_program_inner(ti2)
78
+
79
+ yield rewrite(tuple_int_program_inner(TupleInt(0, idx_fn))).to(Program(""))
80
+
81
+ yield rewrite(tuple_int_program_inner(TupleInt(Int(k), idx_fn))).to(
82
+ int_program(idx_fn(Int(0))) + ", " + tuple_int_program_inner(TupleInt(Int(k - 1), lambda i: idx_fn(i + 1))),
83
+ ne(k).to(i64(0)),
84
+ )
85
+
86
+ yield rewrite(tuple_int_program_inner(TupleInt.from_vec(Vec[Int]()))).to(Program(""))
87
+ yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
88
+ int_program(vec_int[0]) + ", " + tuple_int_program_inner(TupleInt.from_vec(vec_int.remove(0))),
89
+ vec_int.length() > 1,
90
+ )
91
+ yield rewrite(tuple_int_program_inner(TupleInt.from_vec(vec_int))).to(
92
+ int_program(vec_int[0]) + ",",
93
+ eq(vec_int.length()).to(i64(1)),
73
94
  )
74
- yield rewrite(tuple_int_program_inner(TupleInt(i))).to(int_program(i) + ",")
75
95
 
76
96
 
77
97
  @function
@@ -248,12 +268,31 @@ def multi_axis_index_key_program(x: MultiAxisIndexKey) -> Program: ...
248
268
 
249
269
 
250
270
  @array_api_program_gen_ruleset.register
251
- def _multi_axis_index_key_program(l: MultiAxisIndexKey, r: MultiAxisIndexKey, item: MultiAxisIndexKeyItem):
252
- yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(item))).to(multi_axis_index_key_item_program(item))
253
- yield rewrite(multi_axis_index_key_program(l + r)).to(
254
- multi_axis_index_key_program(l) + ", " + multi_axis_index_key_program(r)
271
+ def _multi_axis_index_key_program(
272
+ idx_fn: Callable[[Int], MultiAxisIndexKeyItem], k: i64, vec: Vec[MultiAxisIndexKeyItem], i: MultiAxisIndexKeyItem
273
+ ):
274
+ yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(0, idx_fn))).to(Program(""))
275
+
276
+ yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey(Int(k), idx_fn))).to(
277
+ multi_axis_index_key_item_program(idx_fn(Int(0)))
278
+ + ", "
279
+ + multi_axis_index_key_program(MultiAxisIndexKey(Int(k - 1), lambda i: idx_fn(i + 1))),
280
+ ne(k).to(i64(0)),
281
+ )
282
+
283
+ yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem]()))).to(
284
+ Program("")
285
+ )
286
+ yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
287
+ multi_axis_index_key_item_program(vec[0]) + ",",
288
+ eq(vec.length()).to(i64(1)),
289
+ )
290
+ yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec))).to(
291
+ multi_axis_index_key_item_program(vec[0])
292
+ + ", "
293
+ + multi_axis_index_key_program(MultiAxisIndexKey.from_vec(vec.remove(0))),
294
+ vec.length() > 1,
255
295
  )
256
- yield rewrite(multi_axis_index_key_program(MultiAxisIndexKey.EMPTY)).to(Program("()"))
257
296
 
258
297
 
259
298
  @function
@@ -266,7 +305,7 @@ def _index_key_program(i: Int, s: Slice, key: MultiAxisIndexKey, a: NDArray):
266
305
  yield rewrite(index_key_program(IndexKey.int(i))).to(int_program(i))
267
306
  yield rewrite(index_key_program(IndexKey.slice(s))).to(slice_program(s))
268
307
  yield rewrite(index_key_program(IndexKey.multi_axis(key))).to(multi_axis_index_key_program(key))
269
- yield rewrite(index_key_program(ndarray_index(a))).to(ndarray_program(a))
308
+ yield rewrite(index_key_program(IndexKey.ndarray(a))).to(ndarray_program(a))
270
309
 
271
310
 
272
311
  @function
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from functools import partial
5
+ from inspect import Parameter, signature
6
+ from typing import Any, TypeVar, cast
7
+
8
+ __all__ = ["functionalize"]
9
+
10
+
11
+ T = TypeVar("T", bound=Callable)
12
+
13
+
14
+ # TODO: Add `to_lift` param so that we only transform those with vars in them to args
15
+
16
+
17
+ def functionalize(f: T, get_annotation: Callable[[object], type | None]) -> T:
18
+ """
19
+ Takes a function and returns a new function with all names (co_names) and free variables (co_freevars) added as arguments
20
+ and then partially applied with their values. The second arg, get_annotation, will be applied to all values
21
+ to get their type annotation. If it is None, that arg will not be added as a parameter.
22
+
23
+ For example if you have:
24
+
25
+ def get_annotation(x): return int if x <= 10 else None
26
+
27
+ g = 10
28
+ def f(a, a2):
29
+ def h(b: Z):
30
+ return a + a2 + b + g
31
+
32
+ return functionalize(h, get_annotation)
33
+ res = f(9, 11)
34
+
35
+ It should be equivalent to (according to body, signature, and annotations) (Note that the new arguments will be positional only):
36
+
37
+ def h(a: get_annotation(a), g: get_annotation(g), b: Z):
38
+ return a + b + g
39
+ res = partial(h, a, g)
40
+ """
41
+ code = f.__code__
42
+ names = tuple(n for n in code.co_names if n in f.__globals__)
43
+ free_vars = code.co_freevars
44
+
45
+ global_values: list[Any] = [f.__globals__[name] for name in names]
46
+ free_var_values = [cell.cell_contents for cell in f.__closure__] if f.__closure__ else []
47
+ assert len(free_var_values) == len(free_vars), "Free vars and their values do not match"
48
+ global_values_filtered = [
49
+ (i, name, value, annotation)
50
+ for i, (name, value) in enumerate(zip(names, global_values, strict=True))
51
+ if (annotation := get_annotation(value)) is not None
52
+ ]
53
+ free_var_values_filtered = [
54
+ (i, name, value, annotation)
55
+ for i, (name, value) in enumerate(zip(free_vars, free_var_values, strict=True))
56
+ if (annotation := get_annotation(value)) is not None
57
+ ]
58
+ additional_arg_filtered = global_values_filtered + free_var_values_filtered
59
+
60
+ # Create a wrapper function
61
+ def wrapper(*args):
62
+ # Split args into names, free vars and other args
63
+ name_args, free_var_args, rest_args = (
64
+ args[: (n_names := len(global_values_filtered))],
65
+ args[n_names : (n_args := len(additional_arg_filtered))],
66
+ args[n_args:],
67
+ )
68
+ # Update globals with names
69
+ f.__globals__.update({
70
+ name: arg for (_, name, _, _), arg in zip(global_values_filtered, name_args, strict=False)
71
+ })
72
+ # update function free vars with free var args
73
+ for (i, _, _, _), value in zip(free_var_values_filtered, free_var_args, strict=True):
74
+ assert f.__closure__, "Function does not have closure"
75
+ f.__closure__[i].cell_contents = value
76
+ return f(*rest_args)
77
+
78
+ # Set the signature of the new function to a signature with the free vars and names added as arguments
79
+ orig_signature = signature(f)
80
+ wrapper.__signature__ = orig_signature.replace( # type: ignore[attr-defined]
81
+ parameters=[
82
+ *[Parameter(n, Parameter.POSITIONAL_OR_KEYWORD) for _, n, _, _ in additional_arg_filtered],
83
+ *orig_signature.parameters.values(),
84
+ ]
85
+ )
86
+ # Set the annotations of the new function to the annotations of the original function + annotations of passed in values
87
+ wrapper.__annotations__ = f.__annotations__ | {n: a for _, n, _, a in additional_arg_filtered}
88
+ wrapper.__name__ = f.__name__
89
+
90
+ # Partially apply the wrapper function with the current values of the free vars
91
+ return cast(T, partial(wrapper, *(v for _, _, v, _ in additional_arg_filtered)))
egglog/pretty.py CHANGED
@@ -16,6 +16,7 @@ from .declarations import *
16
16
  if TYPE_CHECKING:
17
17
  from collections.abc import Mapping
18
18
 
19
+
19
20
  __all__ = [
20
21
  "pretty_decl",
21
22
  "pretty_callable_ref",
@@ -77,9 +78,9 @@ def pretty_decl(
77
78
 
78
79
  This will use re-format the result and put the expression on the last line, preceeded by the statements.
79
80
  """
80
- traverse = TraverseContext()
81
+ traverse = TraverseContext(decls)
81
82
  traverse(decl, toplevel=True)
82
- pretty = traverse.pretty(decls)
83
+ pretty = traverse.pretty()
83
84
  expr = pretty(decl, ruleset_name=ruleset_name)
84
85
  if wrapping_fn:
85
86
  expr = f"{wrapping_fn}({expr})"
@@ -106,15 +107,20 @@ def pretty_callable_ref(
106
107
  """
107
108
  # Pass in three dummy args, which are the max used for any operation that
108
109
  # is not a generic function call
109
- args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
110
+ args: list[ExprDecl] = [VarDecl(ARG_STR, False)] * 3
110
111
  if first_arg:
111
112
  args.insert(0, first_arg)
112
- res = PrettyContext(decls, defaultdict(lambda: 0))._call_inner(
113
- ref, args, bound_tp_params=bound_tp_params, parens=False
114
- )
113
+ context = PrettyContext(decls, defaultdict(lambda: 0))
114
+ res = context._call_inner(ref, args, bound_tp_params=bound_tp_params, parens=False)
115
115
  # Either returns a function or a function with args. If args are provided, they would just be called,
116
116
  # on the function, so return them, because they are dummies
117
- return res[0] if isinstance(res, tuple) else res
117
+ if isinstance(res, tuple):
118
+ name = res[0]
119
+ # if this is an unnamed function, return it but don't partially apply any args
120
+ if isinstance(name, UnnamedFunctionRef):
121
+ return context._pretty_function_body(name, [])
122
+ return name
123
+ return res
118
124
 
119
125
 
120
126
  # TODO: Add a different pretty callable ref that doesnt fill in wholes but instead returns the function
@@ -128,16 +134,18 @@ class TraverseContext:
128
134
  expression has.
129
135
  """
130
136
 
137
+ decls: Declarations
138
+
131
139
  # All expressions we have seen (incremented the parent counts of all children)
132
140
  _seen: set[AllDecls] = field(default_factory=set)
133
141
  # The number of parents for each expressions
134
142
  parents: Counter[AllDecls] = field(default_factory=Counter)
135
143
 
136
- def pretty(self, decls: Declarations) -> PrettyContext:
144
+ def pretty(self) -> PrettyContext:
137
145
  """
138
146
  Create a pretty context from the state of this traverse context.
139
147
  """
140
- return PrettyContext(decls, self.parents)
148
+ return PrettyContext(self.decls, self.parents)
141
149
 
142
150
  def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901
143
151
  if not toplevel:
@@ -166,10 +174,16 @@ class TraverseContext:
166
174
  pass
167
175
  case EqDecl(_, decls) | SequenceDecl(decls) | RulesetDecl(decls):
168
176
  for de in decls:
177
+ if isinstance(de, DefaultRewriteDecl):
178
+ continue
169
179
  self(de)
170
- case CallDecl(_, exprs, _):
171
- for e in exprs:
172
- self(e.expr)
180
+ case CallDecl(ref, exprs, _):
181
+ match ref:
182
+ case FunctionRef(UnnamedFunctionRef(_, res)):
183
+ self(res.expr)
184
+ case _:
185
+ for e in exprs:
186
+ self(e.expr)
173
187
  case RunDecl(_, until):
174
188
  if until:
175
189
  for f in until:
@@ -178,6 +192,8 @@ class TraverseContext:
178
192
  self(c)
179
193
  case CombinedRulesetDecl(_):
180
194
  pass
195
+ case DefaultRewriteDecl():
196
+ pass
181
197
  case _:
182
198
  assert_never(decl)
183
199
 
@@ -240,8 +256,7 @@ class PrettyContext:
240
256
  case CallDecl(_, _, _):
241
257
  return self._call(decl, parens)
242
258
  case PartialCallDecl(CallDecl(ref, typed_args, _)):
243
- arg_strs = (_pretty_callable(ref), *(self(a.expr, parens=False, unwrap_lit=True) for a in typed_args))
244
- return f"UnstableFn({', '.join(arg_strs)})", "fn"
259
+ return self._pretty_partial(ref, [a.expr for a in typed_args]), "fn"
245
260
  case PyObjectDecl(value):
246
261
  return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
247
262
  case ActionCommandDecl(action):
@@ -276,7 +291,7 @@ class PrettyContext:
276
291
  case RulesetDecl(rules):
277
292
  if ruleset_name:
278
293
  return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
279
- args = ", ".join(map(self, rules))
294
+ args = ", ".join(self(r) for r in rules if not isinstance(r, DefaultRewriteDecl))
280
295
  return f"ruleset({args})", "ruleset"
281
296
  case CombinedRulesetDecl(rulesets):
282
297
  if ruleset_name:
@@ -298,6 +313,9 @@ class PrettyContext:
298
313
  return ruleset_str, "schedule"
299
314
  args = ", ".join(map(self, until))
300
315
  return f"run({ruleset_str}, {args})", "schedule"
316
+ case DefaultRewriteDecl():
317
+ msg = "default rewrites should not be pretty printed"
318
+ raise TypeError(msg)
301
319
  assert_never(decl)
302
320
 
303
321
  def _call(
@@ -345,12 +363,16 @@ class PrettyContext:
345
363
  has_multiple_parents = self.parents[first_arg] > 1
346
364
  self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
347
365
  # Set the first arg to be the name of the mutated arg and return the name
348
- args[0] = VarDecl(expr_name)
366
+ args[0] = VarDecl(expr_name, True)
349
367
  else:
350
368
  expr_name = None
351
369
  res = self._call_inner(ref, args, decl.bound_tp_params, parens)
352
370
  expr = (
353
- f"{res[0]}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})"
371
+ (
372
+ f"{name}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})"
373
+ if isinstance((name := res[0]), str)
374
+ else ((called := self._pretty_function_body(name, res[1])) if not parens else f"({called})")
375
+ )
354
376
  if isinstance(res, tuple)
355
377
  else res
356
378
  )
@@ -361,8 +383,12 @@ class PrettyContext:
361
383
  return expr, tp_name
362
384
 
363
385
  def _call_inner( # noqa: PLR0911
364
- self, ref: CallableRef, args: list[ExprDecl], bound_tp_params: tuple[JustTypeRef, ...] | None, parens: bool
365
- ) -> tuple[str, list[ExprDecl]] | str:
386
+ self,
387
+ ref: CallableRef,
388
+ args: list[ExprDecl],
389
+ bound_tp_params: tuple[JustTypeRef, ...] | None,
390
+ parens: bool,
391
+ ) -> tuple[str | UnnamedFunctionRef, list[ExprDecl]] | str:
366
392
  """
367
393
  Pretty print the call, returning either the full function call or a tuple of the function and the args.
368
394
  """
@@ -370,10 +396,8 @@ class PrettyContext:
370
396
  case FunctionRef(name):
371
397
  return name, args
372
398
  case ClassMethodRef(class_name, method_name):
373
- fn_str = str(JustTypeRef(class_name, bound_tp_params or ()))
374
- if method_name != "__init__":
375
- fn_str += f".{method_name}"
376
- return fn_str, args
399
+ tp_ref = JustTypeRef(class_name, bound_tp_params or ())
400
+ return f"{tp_ref}.{method_name}", args
377
401
  case MethodRef(_class_name, method_name):
378
402
  slf, *args = args
379
403
  slf = self(slf, parens=True)
@@ -400,6 +424,11 @@ class PrettyContext:
400
424
  return f"{class_name}.{variable_name}"
401
425
  case PropertyRef(_class_name, property_name):
402
426
  return f"{self(args[0], parens=True)}.{property_name}"
427
+ case InitRef(class_name):
428
+ tp_ref = JustTypeRef(class_name, bound_tp_params or ())
429
+ return str(tp_ref), args
430
+ case UnnamedFunctionRef():
431
+ return ref, args
403
432
  assert_never(ref)
404
433
 
405
434
  def _generate_name(self, typ: str) -> str:
@@ -420,27 +449,52 @@ class PrettyContext:
420
449
  self.statements.append(f"{name} = {expr_str}")
421
450
  return name
422
451
 
452
+ def _pretty_partial(self, ref: CallableRef, args: list[ExprDecl]) -> str:
453
+ """
454
+ Returns a partial function call as a string.
455
+ """
456
+ match ref:
457
+ case FunctionRef(name):
458
+ fn = name
459
+ case UnnamedFunctionRef():
460
+ return self._pretty_function_body(ref, args)
461
+ case (
462
+ ClassMethodRef(class_name, method_name)
463
+ | MethodRef(class_name, method_name)
464
+ | PropertyRef(class_name, method_name)
465
+ ):
466
+ fn = f"{class_name}.{method_name}"
467
+ case InitRef(class_name):
468
+ fn = class_name
469
+ case ConstantRef(_):
470
+ msg = "Constants should not be callable"
471
+ raise NotImplementedError(msg)
472
+ case ClassVariableRef(_, _):
473
+ msg = "Class variables should not be callable"
474
+ raise NotADirectoryError(msg)
475
+ case _:
476
+ assert_never(ref)
477
+ if not args:
478
+ return fn
479
+ arg_strs = (
480
+ fn,
481
+ *(self(a, parens=False, unwrap_lit=True) for a in args),
482
+ )
483
+ return f"partial({', '.join(arg_strs)})"
423
484
 
424
- def _pretty_callable(ref: CallableRef) -> str:
425
- """
426
- Returns a function call as a string.
427
- """
428
- match ref:
429
- case FunctionRef(name):
430
- return name
431
- case (
432
- ClassMethodRef(class_name, method_name)
433
- | MethodRef(class_name, method_name)
434
- | PropertyRef(class_name, method_name)
435
- ):
436
- return f"{class_name}.{method_name}"
437
- case ConstantRef(_):
438
- msg = "Constants should not be callable"
439
- raise NotImplementedError(msg)
440
- case ClassVariableRef(_, _):
441
- msg = "Class variables should not be callable"
442
- raise NotADirectoryError(msg)
443
- assert_never(ref)
485
+ def _pretty_function_body(self, fn: UnnamedFunctionRef, args: list[ExprDecl]) -> str:
486
+ """
487
+ Pretty print the body of a function, partially applying some arguments.
488
+ """
489
+ var_args = fn.args
490
+ replacements = {var_arg: TypedExprDecl(var_arg.tp, arg) for var_arg, arg in zip(var_args, args, strict=False)}
491
+ var_args = var_args[len(args) :]
492
+ res = replace_typed_expr(fn.res, replacements)
493
+ arg_names = fn.args[len(args) :]
494
+ prefix = "lambda"
495
+ if arg_names:
496
+ prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
497
+ return f"{prefix}: {self(res.expr)}"
444
498
 
445
499
 
446
500
  def _plot_line_length(expr: object): # pragma: no cover