egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +887 -0
- egglog/builtins.py +1144 -0
- egglog/config.py +8 -0
- egglog/conversion.py +290 -0
- egglog/declarations.py +964 -0
- egglog/deconstruct.py +176 -0
- egglog/egraph.py +2247 -0
- egglog/egraph_state.py +978 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +64 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/MoA.ipynb +617 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/any_expr.py +947 -0
- egglog/exp/any_expr_example.ipynb +408 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +427 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +566 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +888 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +111 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35798 -0
- egglog/visualizer_widget.py +39 -0
- egglog-12.0.0.dist-info/METADATA +93 -0
- egglog-12.0.0.dist-info/RECORD +48 -0
- egglog-12.0.0.dist-info/WHEEL +5 -0
- egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
egglog/pretty.py
ADDED
|
@@ -0,0 +1,566 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pretty printing for declarations.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import ast
|
|
8
|
+
from collections import Counter, defaultdict
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import TYPE_CHECKING, TypeAlias, assert_never
|
|
11
|
+
|
|
12
|
+
import black
|
|
13
|
+
import cloudpickle
|
|
14
|
+
|
|
15
|
+
from .declarations import *
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from collections.abc import Mapping
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
__all__ = [
    "BINARY_METHODS",
    "UNARY_METHODS",
    "pretty_callable_ref",
    "pretty_decl",
]
# Used by the naming heuristic in `PrettyContext.__call__`: a sub-expression with more than one parent is bound to a
# temporary variable when `n_parents * (len(expr) - LINE_DIFFERENCE) > MAX_LINE_LENGTH`.
MAX_LINE_LENGTH = 110
LINE_DIFFERENCE = 10
# Formatting mode for the black pass applied to the generated program in `pretty_decl`.
BLACK_MODE = black.Mode(line_length=180)

# Use this special character in place of the args, so that if the args are inlined
# in the viz, they will replace it
ARG_STR = "·"

# Special methods which we might want to use as functions
# Mapping to the operator they represent for pretty printing them
# https://docs.python.org/3/reference/datamodel.html
BINARY_METHODS = {
    "__lt__": "<",
    "__le__": "<=",
    "__eq__": "==",
    "__ne__": "!=",
    "__gt__": ">",
    "__ge__": ">=",
    # Numeric
    "__add__": "+",
    "__sub__": "-",
    "__mul__": "*",
    "__matmul__": "@",
    "__truediv__": "/",
    "__floordiv__": "//",
    "__mod__": "%",
    # TODO: Support divmod, with tuple return value
    # "__divmod__": "divmod",
    # TODO: Three arg power
    "__pow__": "**",
    "__lshift__": "<<",
    "__rshift__": ">>",
    "__and__": "&",
    "__xor__": "^",
    "__or__": "|",
}


# Unary special methods rendered as prefix operators, e.g. `-x`.
UNARY_METHODS = {
    "__pos__": "+",
    "__neg__": "-",
    "__invert__": "~",
}

# Unary special methods rendered as a call to the named function, e.g. `abs(x)` / `floor(x)`.
NAMED_UNARY_METHODS = {
    "__abs__": "abs",
    "__round__": "round",
    "__trunc__": "trunc",
    "__floor__": "floor",
    "__ceil__": "ceil",
}

# Union of the declaration kinds accepted by `pretty_decl` and the traverse/pretty contexts.
AllDecls: TypeAlias = (
    RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl | BackOffDecl
)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def pretty_decl(
    decls: Declarations, decl: AllDecls, *, wrapping_fn: str | None = None, ruleset_ident: Ident | None = None
) -> str:
    """
    Pretty print a declaration as Python source.

    Shared sub-expressions may be hoisted into temporary-variable statements; those statements come first and the
    expression for ``decl`` itself (optionally wrapped as ``wrapping_fn(...)``) is the last line. The whole program is
    then formatted with black. If black cannot parse it, the unformatted program is returned instead.

    :param wrapping_fn: Name of a function to wrap the final expression in, if any.
    :param ruleset_ident: Identifier to use when ``decl`` is a ruleset declaration.
    """
    # First pass: count how many parents each sub-expression has.
    traverser = TraverseContext(decls)
    traverser(decl, toplevel=True)

    # Second pass: render, using the parent counts to decide which sub-expressions get their own variable.
    printer = traverser.pretty()
    rendered = printer(decl, ruleset_ident=ruleset_ident)
    final_line = f"{wrapping_fn}({rendered})" if wrapping_fn else rendered
    source = "\n".join(printer.statements + [final_line])

    try:
        # TODO: Try replacing with ruff for speed
        # https://github.com/amyreese/ruff-api
        formatted = black.format_str(source, mode=BLACK_MODE)
    except black.parsing.InvalidInput:
        return source
    return formatted.strip()
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def pretty_callable_ref(
    decls: Declarations,
    ref: CallableRef,
    first_arg: ExprDecl | None = None,
    bound_tp_params: tuple[JustTypeRef, ...] | None = None,
    include_all_args: bool = False,
) -> str:
    """
    Pretty print a callable reference, using a dummy value for
    the args if the function is not in the form `f(x, ...)`.

    To be used in the visualization.

    :param first_arg: If given, used as the first argument (e.g. the receiver of a method) instead of a dummy.
    :param bound_tp_params: Type parameters rendered for class-method and constructor references.
    :param include_all_args: If the callable renders as a plain call, emit one ``ARG_STR`` placeholder per declared
        argument instead of dropping the args. Used for set_cost so that ``cost(E(...))`` shows up as a call.
    """
    # Pass in three dummy args, which are the max used for any operation that
    # is not a generic function call
    args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * 3
    if first_arg is not None:
        args.insert(0, first_arg)
    # Every parent count is 0 here, so `PrettyContext.__call__` never hoists sub-expressions into variables
    context = PrettyContext(decls, defaultdict(int))
    res = context._call_inner(ref, args, bound_tp_params=bound_tp_params, parens=False)
    # A plain string means the args were already rendered inline (operators, indexing, attributes, ...)
    if not isinstance(res, tuple):
        return res
    # Otherwise we got `(function, args)`. The args would just be applied to the function, and since they are
    # dummies, only the function is returned...
    if not include_all_args:
        return res[0]
    # ...unless we want one placeholder per declared argument, so we need to figure out how many to use
    signature = decls.get_callable_decl(ref).signature
    assert isinstance(signature, FunctionSignature)
    placeholders: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * len(signature.arg_types)
    rendered_args = ", ".join(context(a, parens=False, unwrap_lit=True) for a in placeholders)
    return f"{res[0]}({rendered_args})"
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# TODO: Add a different pretty callable ref that doesn't fill in holes but instead returns the function
|
|
142
|
+
# so that things like Math.__add__ will be represented properly
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@dataclass
|
|
146
|
+
class TraverseContext:
|
|
147
|
+
"""
|
|
148
|
+
State for traversing expressions (or declerations that contain expressions), so we can know how many parents each
|
|
149
|
+
expression has.
|
|
150
|
+
"""
|
|
151
|
+
|
|
152
|
+
decls: Declarations
|
|
153
|
+
|
|
154
|
+
# All expressions we have seen (incremented the parent counts of all children)
|
|
155
|
+
_seen: set[AllDecls] = field(default_factory=set)
|
|
156
|
+
# The number of parents for each expressions
|
|
157
|
+
parents: Counter[AllDecls] = field(default_factory=Counter)
|
|
158
|
+
|
|
159
|
+
def pretty(self) -> PrettyContext:
|
|
160
|
+
"""
|
|
161
|
+
Create a pretty context from the state of this traverse context.
|
|
162
|
+
"""
|
|
163
|
+
return PrettyContext(self.decls, self.parents)
|
|
164
|
+
|
|
165
|
+
def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901, PLR0912
|
|
166
|
+
if not toplevel:
|
|
167
|
+
self.parents[decl] += 1
|
|
168
|
+
if decl in self._seen:
|
|
169
|
+
return
|
|
170
|
+
match decl:
|
|
171
|
+
case RewriteDecl(_, lhs, rhs, conditions) | BiRewriteDecl(_, lhs, rhs, conditions):
|
|
172
|
+
self(lhs)
|
|
173
|
+
self(rhs)
|
|
174
|
+
for cond in conditions:
|
|
175
|
+
self(cond)
|
|
176
|
+
case RuleDecl(head, body, _):
|
|
177
|
+
for action in head:
|
|
178
|
+
self(action)
|
|
179
|
+
for fact in body:
|
|
180
|
+
self(fact)
|
|
181
|
+
case SetDecl(_, lhs, rhs) | UnionDecl(_, lhs, rhs) | EqDecl(_, lhs, rhs):
|
|
182
|
+
self(lhs)
|
|
183
|
+
self(rhs)
|
|
184
|
+
case LetDecl(_, d) | ExprActionDecl(d) | ExprFactDecl(d):
|
|
185
|
+
self(d.expr)
|
|
186
|
+
case ChangeDecl(_, d, _) | SaturateDecl(d) | RepeatDecl(d, _) | ActionCommandDecl(d):
|
|
187
|
+
self(d)
|
|
188
|
+
case PanicDecl(_) | UnboundVarDecl(_) | LetRefDecl(_) | LitDecl(_) | PyObjectDecl(_):
|
|
189
|
+
pass
|
|
190
|
+
case SequenceDecl(decls) | RulesetDecl(decls):
|
|
191
|
+
for de in decls:
|
|
192
|
+
if isinstance(de, DefaultRewriteDecl):
|
|
193
|
+
continue
|
|
194
|
+
self(de)
|
|
195
|
+
case CallDecl(ref, exprs, _) | GetCostDecl(ref, exprs):
|
|
196
|
+
match ref:
|
|
197
|
+
case FunctionRef(UnnamedFunctionRef(_, res)):
|
|
198
|
+
self(res.expr)
|
|
199
|
+
case _:
|
|
200
|
+
for e in exprs:
|
|
201
|
+
self(e.expr)
|
|
202
|
+
case RunDecl(_, until, scheduler):
|
|
203
|
+
if until:
|
|
204
|
+
for f in until:
|
|
205
|
+
self(f)
|
|
206
|
+
if scheduler:
|
|
207
|
+
self(scheduler)
|
|
208
|
+
case PartialCallDecl(c):
|
|
209
|
+
self(c)
|
|
210
|
+
case CombinedRulesetDecl(_):
|
|
211
|
+
pass
|
|
212
|
+
case DefaultRewriteDecl():
|
|
213
|
+
pass
|
|
214
|
+
case SetCostDecl(_, e, c):
|
|
215
|
+
self(e)
|
|
216
|
+
self(c)
|
|
217
|
+
case BackOffDecl() | ValueDecl():
|
|
218
|
+
pass
|
|
219
|
+
case LetSchedulerDecl(scheduler, schedule):
|
|
220
|
+
self(scheduler)
|
|
221
|
+
self(schedule)
|
|
222
|
+
case GetCostDecl(ref, args):
|
|
223
|
+
self(CallDecl(ref, args))
|
|
224
|
+
case _:
|
|
225
|
+
assert_never(decl)
|
|
226
|
+
|
|
227
|
+
self._seen.add(decl)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@dataclass
|
|
231
|
+
class PrettyContext:
|
|
232
|
+
"""
|
|
233
|
+
|
|
234
|
+
We need to build up a list of all the expressions we are pretty printing, so that we can see who has parents and who is mutated
|
|
235
|
+
and create temp variables for them.
|
|
236
|
+
|
|
237
|
+
"""
|
|
238
|
+
|
|
239
|
+
decls: Declarations
|
|
240
|
+
parents: Mapping[AllDecls, int]
|
|
241
|
+
|
|
242
|
+
# All the expressions we have saved as names
|
|
243
|
+
names: dict[AllDecls, str] = field(default_factory=dict)
|
|
244
|
+
# A list of statements assigning variables or calling destructive ops
|
|
245
|
+
statements: list[str] = field(default_factory=list)
|
|
246
|
+
# Mapping of type to the number of times we have generated a name for that type, used to generate unique names
|
|
247
|
+
_gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
|
|
248
|
+
|
|
249
|
+
def __call__(
|
|
250
|
+
self, decl: AllDecls, *, unwrap_lit: bool = False, parens: bool = False, ruleset_ident: Ident | None = None
|
|
251
|
+
) -> str:
|
|
252
|
+
if decl in self.names:
|
|
253
|
+
return self.names[decl]
|
|
254
|
+
expr, tp_name = self.uncached(decl, unwrap_lit=unwrap_lit, parens=parens, ruleset_ident=ruleset_ident)
|
|
255
|
+
# We use a heuristic to decide whether to name this sub-expression as a variable
|
|
256
|
+
# The rough goal is to reduce the number of newlines, given our line length of ~180
|
|
257
|
+
# We determine it's worth making a new line for this expression if the total characters
|
|
258
|
+
# it would take up is > than some constant (~ line length).
|
|
259
|
+
line_diff: int = len(expr) - LINE_DIFFERENCE
|
|
260
|
+
n_parents = self.parents[decl]
|
|
261
|
+
if n_parents > 1 and (
|
|
262
|
+
n_parents * line_diff > MAX_LINE_LENGTH
|
|
263
|
+
# Schedulers with multiple parents need to be the same object, b/c are created with hidden UUIDs
|
|
264
|
+
or tp_name == "scheduler"
|
|
265
|
+
):
|
|
266
|
+
self.names[decl] = expr_name = self._name_expr(tp_name, expr, copy_identifier=False)
|
|
267
|
+
return expr_name
|
|
268
|
+
return expr
|
|
269
|
+
|
|
270
|
+
def uncached( # noqa: C901, PLR0911, PLR0912
|
|
271
|
+
self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_ident: Ident | None
|
|
272
|
+
) -> tuple[str, str]:
|
|
273
|
+
"""
|
|
274
|
+
Returns a tuple of a string value of the decleration and the "type" to use when create a memoized cached version
|
|
275
|
+
for de-duplication.
|
|
276
|
+
"""
|
|
277
|
+
match decl:
|
|
278
|
+
case LitDecl(value):
|
|
279
|
+
match value:
|
|
280
|
+
case None:
|
|
281
|
+
return "Unit()", "Unit"
|
|
282
|
+
case bool(b):
|
|
283
|
+
return str(b) if unwrap_lit else f"Bool({b})", "Bool"
|
|
284
|
+
case int(i):
|
|
285
|
+
return str(i) if unwrap_lit else f"i64({i})", "i64"
|
|
286
|
+
case float(f):
|
|
287
|
+
return str(f) if unwrap_lit else f"f64({f})", "f64"
|
|
288
|
+
case str(s):
|
|
289
|
+
return repr(s) if unwrap_lit else f"String({s!r})", "String"
|
|
290
|
+
assert_never(value)
|
|
291
|
+
case UnboundVarDecl(name) | LetRefDecl(name):
|
|
292
|
+
return name, name
|
|
293
|
+
case CallDecl(_, _, _):
|
|
294
|
+
return self._call(decl, parens)
|
|
295
|
+
case PartialCallDecl(CallDecl(ref, typed_args, _)):
|
|
296
|
+
return self._pretty_partial(ref, [a.expr for a in typed_args], parens), "fn"
|
|
297
|
+
case PyObjectDecl(pickled):
|
|
298
|
+
value = cloudpickle.loads(pickled)
|
|
299
|
+
value_str = repr(value)
|
|
300
|
+
if not is_valid_python_expr(value_str):
|
|
301
|
+
# If this isn't a valid python expr, represent as string
|
|
302
|
+
value_str = f"eval({value_str!r})"
|
|
303
|
+
return value_str if unwrap_lit else f"PyObject({value_str})", "PyObject"
|
|
304
|
+
case ActionCommandDecl(action):
|
|
305
|
+
return self(action), "action"
|
|
306
|
+
case RewriteDecl(_, lhs, rhs, conditions) | BiRewriteDecl(_, lhs, rhs, conditions):
|
|
307
|
+
args = ", ".join(map(self, (rhs, *conditions)))
|
|
308
|
+
fn = "rewrite" if isinstance(decl, RewriteDecl) else "birewrite"
|
|
309
|
+
return f"{fn}({self(lhs)}).to({args})", "rewrite"
|
|
310
|
+
case RuleDecl(head, body, name):
|
|
311
|
+
l = ", ".join(map(self, body))
|
|
312
|
+
if name:
|
|
313
|
+
l += f", name={name}"
|
|
314
|
+
r = ", ".join(map(self, head))
|
|
315
|
+
return f"rule({l}).then({r})", "rule"
|
|
316
|
+
case SetDecl(_, lhs, rhs):
|
|
317
|
+
return f"set_({self(lhs)}).to({self(rhs)})", "action"
|
|
318
|
+
case UnionDecl(_, lhs, rhs):
|
|
319
|
+
return f"union({self(lhs)}).with_({self(rhs)})", "action"
|
|
320
|
+
case LetDecl(name, expr):
|
|
321
|
+
return f"let({name!r}, {self(expr.expr)})", "action"
|
|
322
|
+
case ExprActionDecl(expr):
|
|
323
|
+
return self(expr.expr), "action"
|
|
324
|
+
case ExprFactDecl(expr):
|
|
325
|
+
return self(expr.expr), "fact"
|
|
326
|
+
case ChangeDecl(_, expr, change):
|
|
327
|
+
return f"{change}({self(expr)})", "action"
|
|
328
|
+
case PanicDecl(s):
|
|
329
|
+
return f"panic({s!r})", "action"
|
|
330
|
+
case SetCostDecl(_, expr, cost):
|
|
331
|
+
return f"set_cost({self(expr)}, {self(cost, unwrap_lit=True)})", "action"
|
|
332
|
+
case EqDecl(_, left, right):
|
|
333
|
+
return f"eq({self(left)}).to({self(right)})", "fact"
|
|
334
|
+
case RulesetDecl(rules):
|
|
335
|
+
if ruleset_ident:
|
|
336
|
+
return f"ruleset(name={ruleset_ident.name!r})", f"ruleset_{ruleset_ident.name}"
|
|
337
|
+
args = ", ".join(self(r) for r in rules if not isinstance(r, DefaultRewriteDecl))
|
|
338
|
+
return f"ruleset({args})", "ruleset"
|
|
339
|
+
case CombinedRulesetDecl(rulesets):
|
|
340
|
+
list_args = [r.name for r in rulesets]
|
|
341
|
+
if ruleset_ident:
|
|
342
|
+
list_args.append(f"name={ruleset_ident.name!r})")
|
|
343
|
+
return (f"unstable_combine_rulesets({', '.join(list_args)})", "combined_ruleset")
|
|
344
|
+
case SaturateDecl(schedule):
|
|
345
|
+
return f"{self(schedule, parens=True)}.saturate()", "schedule"
|
|
346
|
+
case RepeatDecl(schedule, times):
|
|
347
|
+
return f"{self(schedule, parens=True)} * {times}", "schedule"
|
|
348
|
+
case SequenceDecl(schedules):
|
|
349
|
+
if len(schedules) == 2:
|
|
350
|
+
return f"{self(schedules[0], parens=True)} + {self(schedules[1], parens=True)}", "schedule"
|
|
351
|
+
args = ", ".join(map(self, schedules))
|
|
352
|
+
return f"seq({args})", "schedule"
|
|
353
|
+
case LetSchedulerDecl(scheduler, schedule):
|
|
354
|
+
return f"{self(scheduler, parens=True)}.scope({self(schedule, parens=True)})", "schedule"
|
|
355
|
+
case RunDecl(ruleset_ident, until, scheduler):
|
|
356
|
+
ruleset = self.decls._rulesets[ruleset_ident]
|
|
357
|
+
ruleset_str = self(ruleset, ruleset_ident=ruleset_ident)
|
|
358
|
+
if not until and not scheduler:
|
|
359
|
+
return ruleset_str, "schedule"
|
|
360
|
+
arg_lst = list(map(self, until or []))
|
|
361
|
+
if scheduler:
|
|
362
|
+
arg_lst.append(f"scheduler={self(scheduler)}")
|
|
363
|
+
return f"run({ruleset_str}, {', '.join(arg_lst)})", "schedule"
|
|
364
|
+
case DefaultRewriteDecl():
|
|
365
|
+
msg = "default rewrites should not be pretty printed"
|
|
366
|
+
raise TypeError(msg)
|
|
367
|
+
case BackOffDecl(_, match_limit, ban_length):
|
|
368
|
+
list_args = []
|
|
369
|
+
if match_limit is not None:
|
|
370
|
+
list_args.append(f"match_limit={match_limit}")
|
|
371
|
+
if ban_length is not None:
|
|
372
|
+
list_args.append(f"ban_length={ban_length}")
|
|
373
|
+
return f"back_off({', '.join(list_args)})", "scheduler"
|
|
374
|
+
case ValueDecl(value):
|
|
375
|
+
return str(value), "value"
|
|
376
|
+
case GetCostDecl(ref, args):
|
|
377
|
+
return f"get_cost({self(CallDecl(ref, args))})", "get_cost"
|
|
378
|
+
assert_never(decl)
|
|
379
|
+
|
|
380
|
+
def _call(
|
|
381
|
+
self,
|
|
382
|
+
decl: CallDecl,
|
|
383
|
+
parens: bool,
|
|
384
|
+
) -> tuple[str, str]:
|
|
385
|
+
"""
|
|
386
|
+
Pretty print the call. Also returns if it was saved as a name.
|
|
387
|
+
|
|
388
|
+
:param parens: If true, wrap the call in parens if it is a binary method call.
|
|
389
|
+
"""
|
|
390
|
+
args = [a.expr for a in decl.args]
|
|
391
|
+
ref = decl.callable
|
|
392
|
+
# Special case !=
|
|
393
|
+
if decl.callable == FunctionRef(Ident.builtin("!=")):
|
|
394
|
+
l, r = self(args[0]), self(args[1])
|
|
395
|
+
return f"ne({l}).to({r})", "Unit"
|
|
396
|
+
signature = self.decls.get_callable_decl(ref).signature
|
|
397
|
+
|
|
398
|
+
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
|
|
399
|
+
n_defaults = 0
|
|
400
|
+
# Dont try counting defaults for function application
|
|
401
|
+
if isinstance(signature, FunctionSignature):
|
|
402
|
+
for arg, default in zip(
|
|
403
|
+
reversed(args), reversed(signature.arg_defaults), strict=not signature.var_arg_type
|
|
404
|
+
):
|
|
405
|
+
if arg != default:
|
|
406
|
+
break
|
|
407
|
+
n_defaults += 1
|
|
408
|
+
if n_defaults:
|
|
409
|
+
args = args[:-n_defaults]
|
|
410
|
+
|
|
411
|
+
# If this is a function application, the type is the first type arg of the function object
|
|
412
|
+
if signature == "fn-app":
|
|
413
|
+
tp_name = decl.args[0].tp.args[0].ident.name
|
|
414
|
+
else:
|
|
415
|
+
assert isinstance(signature, FunctionSignature)
|
|
416
|
+
tp_name = signature.semantic_return_type.ident.name
|
|
417
|
+
if isinstance(signature, FunctionSignature) and signature.mutates:
|
|
418
|
+
first_arg = args[0]
|
|
419
|
+
expr_str = self(first_arg)
|
|
420
|
+
# copy an identifier expression iff it has multiple parents (b/c then we can't mutate it directly)
|
|
421
|
+
has_multiple_parents = self.parents[first_arg] > 1
|
|
422
|
+
self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
|
|
423
|
+
# Set the first arg to be the name of the mutated arg and return the name
|
|
424
|
+
args[0] = LetRefDecl(expr_name)
|
|
425
|
+
else:
|
|
426
|
+
expr_name = None
|
|
427
|
+
res = self._call_inner(ref, args, decl.bound_tp_params, parens)
|
|
428
|
+
expr = (
|
|
429
|
+
(f"{res[0]}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})")
|
|
430
|
+
if isinstance(res, tuple)
|
|
431
|
+
else res
|
|
432
|
+
)
|
|
433
|
+
# If we have a name, then we mutated
|
|
434
|
+
if expr_name:
|
|
435
|
+
self.statements.append(expr)
|
|
436
|
+
return expr_name, tp_name
|
|
437
|
+
return expr, tp_name
|
|
438
|
+
|
|
439
|
+
def _call_inner( # noqa: C901, PLR0911, PLR0912
|
|
440
|
+
self,
|
|
441
|
+
ref: CallableRef,
|
|
442
|
+
args: list[ExprDecl],
|
|
443
|
+
bound_tp_params: tuple[JustTypeRef, ...] | None,
|
|
444
|
+
parens: bool,
|
|
445
|
+
) -> tuple[str, list[ExprDecl]] | str:
|
|
446
|
+
"""
|
|
447
|
+
Pretty print the call, returning either the full function call or a tuple of the function and the args.
|
|
448
|
+
"""
|
|
449
|
+
match ref:
|
|
450
|
+
case FunctionRef(Ident(name)):
|
|
451
|
+
return name, args
|
|
452
|
+
case ClassMethodRef(class_name, method_name):
|
|
453
|
+
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
454
|
+
return f"{tp_ref}.{method_name}", args
|
|
455
|
+
case MethodRef(_class_name, method_name):
|
|
456
|
+
slf, *args = args
|
|
457
|
+
non_str_slf = slf
|
|
458
|
+
slf = self(slf, parens=True)
|
|
459
|
+
match method_name:
|
|
460
|
+
case _ if method_name in UNARY_METHODS:
|
|
461
|
+
expr = f"{UNARY_METHODS[method_name]}{slf}"
|
|
462
|
+
return f"({expr})" if parens else expr
|
|
463
|
+
case _ if method_name in BINARY_METHODS:
|
|
464
|
+
expr = f"{slf} {BINARY_METHODS[method_name]} {self(args[0], parens=True, unwrap_lit=True)}"
|
|
465
|
+
return f"({expr})" if parens else expr
|
|
466
|
+
case "__getitem__":
|
|
467
|
+
return f"{slf}[{self(args[0], unwrap_lit=True)}]"
|
|
468
|
+
case "__call__":
|
|
469
|
+
return slf, args
|
|
470
|
+
case "__delitem__":
|
|
471
|
+
return f"del {slf}[{self(args[0], unwrap_lit=True)}]"
|
|
472
|
+
case "__setitem__":
|
|
473
|
+
return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}"
|
|
474
|
+
case _ if method_name in NAMED_UNARY_METHODS:
|
|
475
|
+
return NAMED_UNARY_METHODS[method_name], [non_str_slf, *args]
|
|
476
|
+
case "__getattr__" if isinstance(args[0], LitDecl) and isinstance(args[0].value, str):
|
|
477
|
+
return f"{slf}.{args[0].value}"
|
|
478
|
+
case _:
|
|
479
|
+
return f"{slf}.{method_name}", args
|
|
480
|
+
case ConstantRef(Ident(name)):
|
|
481
|
+
return name
|
|
482
|
+
case ClassVariableRef(Ident(class_name), variable_name):
|
|
483
|
+
return f"{class_name}.{variable_name}"
|
|
484
|
+
case PropertyRef(_class_name, property_name):
|
|
485
|
+
return f"{self(args[0], parens=True)}.{property_name}"
|
|
486
|
+
case InitRef(class_name):
|
|
487
|
+
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
488
|
+
return str(tp_ref), args
|
|
489
|
+
case UnnamedFunctionRef():
|
|
490
|
+
expr = self._pretty_function_body(ref, [])
|
|
491
|
+
return f"({expr})", args
|
|
492
|
+
assert_never(ref)
|
|
493
|
+
|
|
494
|
+
def _generate_name(self, typ: str) -> str:
|
|
495
|
+
self._gen_name_types[typ] += 1
|
|
496
|
+
return f"_{typ}_{self._gen_name_types[typ]}"
|
|
497
|
+
|
|
498
|
+
def _name_expr(self, tp_name: str, expr_str: str, copy_identifier: bool) -> str:
|
|
499
|
+
# tp_name =
|
|
500
|
+
# If the thing we are naming is already a variable, we don't need to name it
|
|
501
|
+
if expr_str.isidentifier():
|
|
502
|
+
if copy_identifier:
|
|
503
|
+
name = self._generate_name(tp_name)
|
|
504
|
+
self.statements.append(f"{name} = copy({expr_str})")
|
|
505
|
+
else:
|
|
506
|
+
name = expr_str
|
|
507
|
+
else:
|
|
508
|
+
name = self._generate_name(tp_name)
|
|
509
|
+
self.statements.append(f"{name} = {expr_str}")
|
|
510
|
+
return name
|
|
511
|
+
|
|
512
|
+
def _pretty_partial(self, ref: CallableRef, args: list[ExprDecl], parens: bool) -> str:
|
|
513
|
+
"""
|
|
514
|
+
Returns a partial function call as a string.
|
|
515
|
+
"""
|
|
516
|
+
match ref:
|
|
517
|
+
case FunctionRef(Ident(name)):
|
|
518
|
+
fn = name
|
|
519
|
+
case UnnamedFunctionRef():
|
|
520
|
+
res = self._pretty_function_body(ref, args)
|
|
521
|
+
return f"({res})" if parens else res
|
|
522
|
+
case (
|
|
523
|
+
ClassMethodRef(Ident(class_name), method_name)
|
|
524
|
+
| MethodRef(Ident(class_name), method_name)
|
|
525
|
+
| PropertyRef(Ident(class_name), method_name)
|
|
526
|
+
):
|
|
527
|
+
fn = f"{class_name}.{method_name}"
|
|
528
|
+
case InitRef(Ident(class_name)):
|
|
529
|
+
fn = class_name
|
|
530
|
+
case ConstantRef(_):
|
|
531
|
+
msg = "Constants should not be callable"
|
|
532
|
+
raise NotImplementedError(msg)
|
|
533
|
+
case ClassVariableRef(_, _):
|
|
534
|
+
msg = "Class variables should not be callable"
|
|
535
|
+
raise NotADirectoryError(msg)
|
|
536
|
+
case _:
|
|
537
|
+
assert_never(ref)
|
|
538
|
+
if not args:
|
|
539
|
+
return fn
|
|
540
|
+
arg_strs = (
|
|
541
|
+
fn,
|
|
542
|
+
*(self(a, parens=False, unwrap_lit=True) for a in args),
|
|
543
|
+
)
|
|
544
|
+
return f"partial({', '.join(arg_strs)})"
|
|
545
|
+
|
|
546
|
+
def _pretty_function_body(self, fn: UnnamedFunctionRef, args: list[ExprDecl]) -> str:
|
|
547
|
+
"""
|
|
548
|
+
Pretty print the body of a function, partially applying some arguments.
|
|
549
|
+
"""
|
|
550
|
+
var_args = fn.args
|
|
551
|
+
replacements = {var_arg: TypedExprDecl(var_arg.tp, arg) for var_arg, arg in zip(var_args, args, strict=False)}
|
|
552
|
+
var_args = var_args[len(args) :]
|
|
553
|
+
res = replace_typed_expr(fn.res, replacements)
|
|
554
|
+
arg_names = fn.args[len(args) :]
|
|
555
|
+
prefix = "lambda"
|
|
556
|
+
if arg_names:
|
|
557
|
+
prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
|
|
558
|
+
return f"{prefix}: {self(res.expr)}"
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
def is_valid_python_expr(s: str) -> bool:
    """
    Return whether ``s`` parses as a single Python expression.

    Used to decide whether an object's ``repr`` can be emitted directly as source code.
    """
    try:
        ast.parse(s, mode="eval")
    except (SyntaxError, ValueError):
        # Before Python 3.12, source containing null bytes raises ValueError instead of SyntaxError
        return False
    return True
|
egglog/py.typed
ADDED
|
File without changes
|