egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +887 -0
- egglog/builtins.py +1144 -0
- egglog/config.py +8 -0
- egglog/conversion.py +290 -0
- egglog/declarations.py +964 -0
- egglog/deconstruct.py +176 -0
- egglog/egraph.py +2247 -0
- egglog/egraph_state.py +978 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +64 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/MoA.ipynb +617 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/any_expr.py +947 -0
- egglog/exp/any_expr_example.ipynb +408 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +427 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +566 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +888 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +111 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35798 -0
- egglog/visualizer_widget.py +39 -0
- egglog-12.0.0.dist-info/METADATA +93 -0
- egglog-12.0.0.dist-info/RECORD +48 -0
- egglog-12.0.0.dist-info/WHEEL +5 -0
- egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
egglog/exp/any_expr.py
ADDED
|
@@ -0,0 +1,947 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WIP
|
|
3
|
+
|
|
4
|
+
An `AnyExpr`, which can be used to trace arbitrary expressions.
|
|
5
|
+
|
|
6
|
+
Created from any Python object, it should forward any operations on it to the underlying Python object.
|
|
7
|
+
|
|
8
|
+
This will only happen when it needs to be "materialized" however, through operations like `__bool__` or `__iter__`.
|
|
9
|
+
|
|
10
|
+
Generally it will try to avoid materializing the underlying object, and instead just treat it as a black box.
|
|
11
|
+
"""
|
|
12
|
+
# mypy: disable-error-code="empty-body"
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import contextlib
|
|
17
|
+
import math
|
|
18
|
+
import operator
|
|
19
|
+
from collections.abc import Iterator
|
|
20
|
+
from copy import copy
|
|
21
|
+
from functools import reduce
|
|
22
|
+
from typing import Any, TypeAlias
|
|
23
|
+
|
|
24
|
+
from egglog import *
|
|
25
|
+
from egglog.exp.program_gen import *
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AnyExpr(Expr):
|
|
29
|
+
"""
|
|
30
|
+
Wraps an arbitrary Python object.
|
|
31
|
+
|
|
32
|
+
Any operations on it will be forwarded to the underlying object when needed.
|
|
33
|
+
|
|
34
|
+
Attempts to implement as many operations from https://docs.python.org/3/reference/datamodel.html as possible.
|
|
35
|
+
|
|
36
|
+
Can be converted from any Python object:
|
|
37
|
+
|
|
38
|
+
>>> AnyExpr(42) + 42
|
|
39
|
+
AnyExpr(A(42) + A(42))
|
|
40
|
+
|
|
41
|
+
Will also convert tuples and lists item by item:
|
|
42
|
+
|
|
43
|
+
>>> AnyExpr((1, 2,)) + (5, 6)
|
|
44
|
+
AnyExpr(append(append(A(()), A(1)), A(2)) + append(append(A(()), A(5)), A(6)))
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
def __init__(self, obj: ALike) -> None: ...
|
|
48
|
+
|
|
49
|
+
__match_args__ = ("egglog_any_expr_value",)
|
|
50
|
+
|
|
51
|
+
@method(preserve=True) # type: ignore[prop-decorator]
|
|
52
|
+
@property
|
|
53
|
+
def egglog_any_expr_value(self) -> A:
|
|
54
|
+
"""
|
|
55
|
+
Return the underlying Python object, if it was constructued with one.
|
|
56
|
+
|
|
57
|
+
Long method name so it doesn't conflict with any user-defined properties.
|
|
58
|
+
|
|
59
|
+
>>> AnyExpr(10).egglog_any_expr_value
|
|
60
|
+
A(10)
|
|
61
|
+
"""
|
|
62
|
+
match get_callable_args(self, AnyExpr):
|
|
63
|
+
case (A() as any_expr,):
|
|
64
|
+
return any_expr
|
|
65
|
+
raise ExprValueError(self, "AnyExpr")
|
|
66
|
+
|
|
67
|
+
@method(preserve=True)
|
|
68
|
+
def __bytes__(self) -> bytes:
|
|
69
|
+
"""
|
|
70
|
+
>>> bytes(AnyExpr(b"hello"))
|
|
71
|
+
b'hello'
|
|
72
|
+
"""
|
|
73
|
+
return any_eval(bytes_(self))
|
|
74
|
+
|
|
75
|
+
@method(preserve=True)
|
|
76
|
+
def __bool__(self) -> bool:
|
|
77
|
+
"""
|
|
78
|
+
>>> bool(AnyExpr(True))
|
|
79
|
+
True
|
|
80
|
+
>>> bool(AnyExpr(False))
|
|
81
|
+
False
|
|
82
|
+
"""
|
|
83
|
+
return any_eval(bool_(self))
|
|
84
|
+
|
|
85
|
+
@method(preserve=True)
|
|
86
|
+
def __eq__(self, other: object) -> AnyExpr: # type: ignore[override]
|
|
87
|
+
"""
|
|
88
|
+
>>> bool(AnyExpr(1) == AnyExpr(1))
|
|
89
|
+
True
|
|
90
|
+
>>> bool(AnyExpr(1) == AnyExpr(2))
|
|
91
|
+
False
|
|
92
|
+
"""
|
|
93
|
+
return with_assert(self.egglog_any_expr_value == other)
|
|
94
|
+
|
|
95
|
+
@method(preserve=True)
|
|
96
|
+
def __ne__(self, other: object) -> AnyExpr: # type: ignore[override]
|
|
97
|
+
"""
|
|
98
|
+
>>> bool(AnyExpr(1) != AnyExpr(2))
|
|
99
|
+
True
|
|
100
|
+
>>> bool(AnyExpr(1) != AnyExpr(1))
|
|
101
|
+
False
|
|
102
|
+
"""
|
|
103
|
+
return with_assert(self.egglog_any_expr_value != other)
|
|
104
|
+
|
|
105
|
+
@method(preserve=True)
|
|
106
|
+
def __lt__(self, other: object) -> AnyExpr:
|
|
107
|
+
"""
|
|
108
|
+
>>> bool(AnyExpr(1) < AnyExpr(2))
|
|
109
|
+
True
|
|
110
|
+
>>> bool(AnyExpr(2) < AnyExpr(1))
|
|
111
|
+
False
|
|
112
|
+
"""
|
|
113
|
+
return with_assert(self.egglog_any_expr_value < other)
|
|
114
|
+
|
|
115
|
+
@method(preserve=True)
|
|
116
|
+
def __le__(self, other: object) -> AnyExpr:
|
|
117
|
+
"""
|
|
118
|
+
>>> bool(AnyExpr(2) <= AnyExpr(2))
|
|
119
|
+
True
|
|
120
|
+
>>> bool(AnyExpr(3) <= AnyExpr(2))
|
|
121
|
+
False
|
|
122
|
+
"""
|
|
123
|
+
return with_assert(self.egglog_any_expr_value <= other)
|
|
124
|
+
|
|
125
|
+
@method(preserve=True)
|
|
126
|
+
def __gt__(self, other: object) -> AnyExpr:
|
|
127
|
+
"""
|
|
128
|
+
>>> bool(AnyExpr(3) > AnyExpr(2))
|
|
129
|
+
True
|
|
130
|
+
>>> bool(AnyExpr(2) > AnyExpr(3))
|
|
131
|
+
False
|
|
132
|
+
"""
|
|
133
|
+
return with_assert(self.egglog_any_expr_value > other)
|
|
134
|
+
|
|
135
|
+
@method(preserve=True)
|
|
136
|
+
def __ge__(self, other: object) -> AnyExpr:
|
|
137
|
+
"""
|
|
138
|
+
>>> bool(AnyExpr(3) >= AnyExpr(3))
|
|
139
|
+
True
|
|
140
|
+
>>> bool(AnyExpr(2) >= AnyExpr(3))
|
|
141
|
+
False
|
|
142
|
+
"""
|
|
143
|
+
return with_assert(self.egglog_any_expr_value >= other)
|
|
144
|
+
|
|
145
|
+
@method(preserve=True)
|
|
146
|
+
def __hash__(self) -> int:
|
|
147
|
+
"""
|
|
148
|
+
Turn the underlying object into a hash.
|
|
149
|
+
|
|
150
|
+
>>> hash(AnyExpr("hello")) == hash("hello")
|
|
151
|
+
True
|
|
152
|
+
"""
|
|
153
|
+
return hash(any_eval(self.egglog_any_expr_value))
|
|
154
|
+
|
|
155
|
+
@method(preserve=True)
|
|
156
|
+
def __getattr__(self, name: StringLike) -> AnyExpr | Any:
|
|
157
|
+
"""
|
|
158
|
+
Get an attribute of the underlying object.
|
|
159
|
+
|
|
160
|
+
>>> int(AnyExpr([1, 2, 3]).index(2))
|
|
161
|
+
1
|
|
162
|
+
|
|
163
|
+
Also should work with hasattr:
|
|
164
|
+
>>> hasattr(AnyExpr([1, 2, 3]), "index")
|
|
165
|
+
True
|
|
166
|
+
>>> hasattr(AnyExpr([1, 2, 3]), "nonexistent")
|
|
167
|
+
False
|
|
168
|
+
"""
|
|
169
|
+
inner = self.egglog_any_expr_value
|
|
170
|
+
# Need to raise attribute error if it doesn't exist, since this is called for hasattr
|
|
171
|
+
if not any_eval(hasattr_(inner, name)):
|
|
172
|
+
raise AttributeError(f"{self} has no attribute {name}")
|
|
173
|
+
egraph = _get_current_egraph()
|
|
174
|
+
res = inner.__getattr__(name)
|
|
175
|
+
egraph.register(res)
|
|
176
|
+
egraph.run(any_expr_schedule)
|
|
177
|
+
if egraph.check_bool(getattr_eager(inner, name)):
|
|
178
|
+
return any_eval(res)
|
|
179
|
+
return with_assert(res)
|
|
180
|
+
|
|
181
|
+
# TODO: Not working for now
|
|
182
|
+
# @method(mutates_self=True)
|
|
183
|
+
# def __setattr__(self, name: StringLike, value: object) -> None:
|
|
184
|
+
# """
|
|
185
|
+
# Set an attribute of the underlying object.
|
|
186
|
+
|
|
187
|
+
# >>> x = lambda: None
|
|
188
|
+
# >>> expr = AnyExpr(x)
|
|
189
|
+
# >>> expr.attr = 42
|
|
190
|
+
# >>> int(expr.attr)
|
|
191
|
+
# 42
|
|
192
|
+
# """
|
|
193
|
+
|
|
194
|
+
# TODO: delattr
|
|
195
|
+
# TODO: __get__/__set__?
|
|
196
|
+
|
|
197
|
+
@method(preserve=True)
|
|
198
|
+
def __len__(self) -> int:
|
|
199
|
+
"""
|
|
200
|
+
Get the length of the underlying object.
|
|
201
|
+
|
|
202
|
+
>>> len(AnyExpr([1, 2, 3]))
|
|
203
|
+
3
|
|
204
|
+
"""
|
|
205
|
+
return any_eval(len_(self))
|
|
206
|
+
|
|
207
|
+
@method(preserve=True)
|
|
208
|
+
def __call__(self, *args: object, **kwargs: object) -> AnyExpr:
|
|
209
|
+
"""
|
|
210
|
+
Call the underlying object.
|
|
211
|
+
|
|
212
|
+
>>> int(AnyExpr(int)(42))
|
|
213
|
+
42
|
|
214
|
+
>>> int(AnyExpr(lambda *x, **y: len(x) + len(y))(1, 2, a=3, b=4))
|
|
215
|
+
4
|
|
216
|
+
"""
|
|
217
|
+
args_expr = A(())
|
|
218
|
+
for a in args:
|
|
219
|
+
args_expr = append(args_expr, a)
|
|
220
|
+
kwargs_expr = A({})
|
|
221
|
+
for k, v in kwargs.items():
|
|
222
|
+
kwargs_expr = set_kwarg(kwargs_expr, k, v)
|
|
223
|
+
return with_assert(self.egglog_any_expr_value(args_expr, kwargs_expr))
|
|
224
|
+
|
|
225
|
+
@method(preserve=True)
|
|
226
|
+
def __getitem__(self, key: object) -> AnyExpr:
|
|
227
|
+
"""
|
|
228
|
+
Get an item from the underlying object.
|
|
229
|
+
|
|
230
|
+
>>> int(AnyExpr([1, 2, 3])[1])
|
|
231
|
+
2
|
|
232
|
+
"""
|
|
233
|
+
return with_assert(self.egglog_any_expr_value[key])
|
|
234
|
+
|
|
235
|
+
@method(preserve=True)
|
|
236
|
+
def __setitem__(self, key: object, value: object) -> None:
|
|
237
|
+
"""
|
|
238
|
+
Set an item in the underlying object.
|
|
239
|
+
|
|
240
|
+
>>> x = [1, 2, 3]
|
|
241
|
+
>>> expr = AnyExpr(x)
|
|
242
|
+
>>> expr[1] = 42
|
|
243
|
+
>>> int(expr[1])
|
|
244
|
+
42
|
|
245
|
+
"""
|
|
246
|
+
any_expr_inner = self.egglog_any_expr_value
|
|
247
|
+
any_expr_inner[key] = value
|
|
248
|
+
self.__replace_expr__(AnyExpr(with_assert(any_expr_inner)))
|
|
249
|
+
|
|
250
|
+
@method(preserve=True)
|
|
251
|
+
def __delitem__(self, key: object) -> None:
|
|
252
|
+
"""
|
|
253
|
+
Delete an item from the underlying object.
|
|
254
|
+
|
|
255
|
+
>>> x = [1, 2, 3]
|
|
256
|
+
>>> expr = AnyExpr(x)
|
|
257
|
+
>>> del expr[1]
|
|
258
|
+
>>> len(expr)
|
|
259
|
+
2
|
|
260
|
+
"""
|
|
261
|
+
any_expr_inner = self.egglog_any_expr_value
|
|
262
|
+
del any_expr_inner[key]
|
|
263
|
+
self.__replace_expr__(AnyExpr(with_assert(any_expr_inner)))
|
|
264
|
+
|
|
265
|
+
# TODO: support real iterators
|
|
266
|
+
@method(preserve=True)
|
|
267
|
+
def __iter__(self) -> Iterator[AnyExpr]:
|
|
268
|
+
"""
|
|
269
|
+
Iterate over the underlying object.
|
|
270
|
+
|
|
271
|
+
>>> list(AnyExpr((1, 2)))
|
|
272
|
+
[AnyExpr(append(append(A(()), A(1)), A(2))[A(0)]), AnyExpr(append(append(A(()), A(1)), A(2))[A(1)])]
|
|
273
|
+
"""
|
|
274
|
+
return iter(self[i] for i in range(len(self)))
|
|
275
|
+
|
|
276
|
+
# TODO: Not working for now
|
|
277
|
+
# @method(preserve=True)
|
|
278
|
+
# def __reversed__(self) -> Iterator[AnyExpr]:
|
|
279
|
+
# """
|
|
280
|
+
# Reverse iterate over the underlying object.
|
|
281
|
+
|
|
282
|
+
# >>> list(reversed(AnyExpr([1, 2, 3])))
|
|
283
|
+
# [AnyExpr(3), AnyExpr(2), AnyExpr(1)]
|
|
284
|
+
# """
|
|
285
|
+
# return map(AnyExpr, any_eval(reversed_op(self)))
|
|
286
|
+
|
|
287
|
+
@method(preserve=True)
|
|
288
|
+
def __contains__(self, item: object) -> bool:
|
|
289
|
+
"""
|
|
290
|
+
Check if the underlying object contains an item.
|
|
291
|
+
|
|
292
|
+
>>> class A:
|
|
293
|
+
... def __contains__(self, item):
|
|
294
|
+
... return item == 42
|
|
295
|
+
>>> 42 in AnyExpr(A())
|
|
296
|
+
True
|
|
297
|
+
>>> 2 in AnyExpr(A())
|
|
298
|
+
False
|
|
299
|
+
"""
|
|
300
|
+
return any_eval(contains(self, item))
|
|
301
|
+
|
|
302
|
+
##
|
|
303
|
+
# Emulating numeric types
|
|
304
|
+
# https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
|
|
305
|
+
##
|
|
306
|
+
|
|
307
|
+
@method(preserve=True)
|
|
308
|
+
def __add__(self, other: object) -> AnyExpr:
|
|
309
|
+
"""
|
|
310
|
+
>>> int(AnyExpr(1) + 2)
|
|
311
|
+
3
|
|
312
|
+
"""
|
|
313
|
+
return with_assert(self.egglog_any_expr_value + other)
|
|
314
|
+
|
|
315
|
+
@method(preserve=True)
|
|
316
|
+
def __sub__(self, other: object) -> AnyExpr:
|
|
317
|
+
"""
|
|
318
|
+
>>> int(AnyExpr(3) - 2)
|
|
319
|
+
1
|
|
320
|
+
"""
|
|
321
|
+
return with_assert(self.egglog_any_expr_value - other)
|
|
322
|
+
|
|
323
|
+
@method(preserve=True)
|
|
324
|
+
def __mul__(self, other: object) -> AnyExpr:
|
|
325
|
+
"""
|
|
326
|
+
# >>> int(AnyExpr(3) * 2)
|
|
327
|
+
# 6
|
|
328
|
+
>>> 4 * AnyExpr(3)
|
|
329
|
+
AnyExpr(A(4) * A(3))
|
|
330
|
+
"""
|
|
331
|
+
return with_assert(self.egglog_any_expr_value * other)
|
|
332
|
+
|
|
333
|
+
@method(preserve=True)
|
|
334
|
+
def __matmul__(self, other: object) -> AnyExpr:
|
|
335
|
+
"""
|
|
336
|
+
>>> class Matrix:
|
|
337
|
+
... def __matmul__(self, other):
|
|
338
|
+
... return 42
|
|
339
|
+
>>> int(AnyExpr(Matrix()) @ Matrix())
|
|
340
|
+
42
|
|
341
|
+
"""
|
|
342
|
+
return with_assert(self.egglog_any_expr_value @ other)
|
|
343
|
+
|
|
344
|
+
@method(preserve=True)
|
|
345
|
+
def __truediv__(self, other: object) -> AnyExpr:
|
|
346
|
+
"""
|
|
347
|
+
>>> float(AnyExpr(3) / 2)
|
|
348
|
+
1.5
|
|
349
|
+
"""
|
|
350
|
+
return with_assert(self.egglog_any_expr_value / other)
|
|
351
|
+
|
|
352
|
+
@method(preserve=True)
|
|
353
|
+
def __floordiv__(self, other: object) -> AnyExpr:
|
|
354
|
+
"""
|
|
355
|
+
>>> int(AnyExpr(3) // 2)
|
|
356
|
+
1
|
|
357
|
+
"""
|
|
358
|
+
return with_assert(self.egglog_any_expr_value // other)
|
|
359
|
+
|
|
360
|
+
@method(preserve=True)
|
|
361
|
+
def __mod__(self, other: object) -> AnyExpr:
|
|
362
|
+
"""
|
|
363
|
+
>>> int(AnyExpr(3) % 2)
|
|
364
|
+
1
|
|
365
|
+
"""
|
|
366
|
+
return with_assert(self.egglog_any_expr_value % other)
|
|
367
|
+
|
|
368
|
+
@method(preserve=True)
|
|
369
|
+
def __divmod__(self, other: object) -> AnyExpr:
|
|
370
|
+
"""
|
|
371
|
+
>>> div, mod = divmod(AnyExpr(3), 2)
|
|
372
|
+
>>> int(div)
|
|
373
|
+
1
|
|
374
|
+
>>> int(mod)
|
|
375
|
+
1
|
|
376
|
+
"""
|
|
377
|
+
return with_assert(divmod(self.egglog_any_expr_value, other))
|
|
378
|
+
|
|
379
|
+
# TODO: Support modulo
|
|
380
|
+
@method(preserve=True)
|
|
381
|
+
def __pow__(self, other: object) -> AnyExpr:
|
|
382
|
+
"""
|
|
383
|
+
>>> int(AnyExpr(3) ** 2)
|
|
384
|
+
9
|
|
385
|
+
"""
|
|
386
|
+
return with_assert(self.egglog_any_expr_value**other)
|
|
387
|
+
|
|
388
|
+
@method(preserve=True)
|
|
389
|
+
def __lshift__(self, other: object) -> AnyExpr:
|
|
390
|
+
"""
|
|
391
|
+
>>> int(AnyExpr(1) << 2)
|
|
392
|
+
4
|
|
393
|
+
"""
|
|
394
|
+
return with_assert(self.egglog_any_expr_value << other)
|
|
395
|
+
|
|
396
|
+
@method(preserve=True)
|
|
397
|
+
def __rshift__(self, other: object) -> AnyExpr:
|
|
398
|
+
"""
|
|
399
|
+
>>> int(AnyExpr(4) >> 2)
|
|
400
|
+
1
|
|
401
|
+
"""
|
|
402
|
+
return with_assert(self.egglog_any_expr_value >> other)
|
|
403
|
+
|
|
404
|
+
@method(preserve=True)
|
|
405
|
+
def __and__(self, other: object) -> AnyExpr:
|
|
406
|
+
"""
|
|
407
|
+
>>> int(AnyExpr(6) & 3)
|
|
408
|
+
2
|
|
409
|
+
"""
|
|
410
|
+
return with_assert(self.egglog_any_expr_value & other)
|
|
411
|
+
|
|
412
|
+
@method(preserve=True)
|
|
413
|
+
def __xor__(self, other: object) -> AnyExpr:
|
|
414
|
+
"""
|
|
415
|
+
>>> int(AnyExpr(6) ^ 3)
|
|
416
|
+
5
|
|
417
|
+
"""
|
|
418
|
+
return with_assert(self.egglog_any_expr_value ^ other)
|
|
419
|
+
|
|
420
|
+
@method(preserve=True)
|
|
421
|
+
def __or__(self, other: object) -> AnyExpr:
|
|
422
|
+
"""
|
|
423
|
+
>>> int(AnyExpr(6) | 3)
|
|
424
|
+
7
|
|
425
|
+
"""
|
|
426
|
+
return with_assert(self.egglog_any_expr_value | other)
|
|
427
|
+
|
|
428
|
+
@method(preserve=True)
|
|
429
|
+
def __neg__(self) -> AnyExpr:
|
|
430
|
+
"""
|
|
431
|
+
>>> int(-AnyExpr(3))
|
|
432
|
+
-3
|
|
433
|
+
"""
|
|
434
|
+
return with_assert(-self.egglog_any_expr_value)
|
|
435
|
+
|
|
436
|
+
@method(preserve=True)
|
|
437
|
+
def __pos__(self) -> AnyExpr:
|
|
438
|
+
"""
|
|
439
|
+
>>> int(+AnyExpr(3))
|
|
440
|
+
3
|
|
441
|
+
"""
|
|
442
|
+
return with_assert(+self.egglog_any_expr_value)
|
|
443
|
+
|
|
444
|
+
@method(preserve=True)
|
|
445
|
+
def __abs__(self) -> AnyExpr:
|
|
446
|
+
"""
|
|
447
|
+
>>> int(abs(AnyExpr(-3)))
|
|
448
|
+
3
|
|
449
|
+
"""
|
|
450
|
+
return with_assert(abs(self.egglog_any_expr_value))
|
|
451
|
+
|
|
452
|
+
@method(preserve=True)
|
|
453
|
+
def __complex__(self) -> complex:
|
|
454
|
+
"""
|
|
455
|
+
>>> complex(AnyExpr(3+4j))
|
|
456
|
+
(3+4j)
|
|
457
|
+
"""
|
|
458
|
+
return any_eval(complex_(self))
|
|
459
|
+
|
|
460
|
+
@method(preserve=True)
|
|
461
|
+
def __int__(self) -> int:
|
|
462
|
+
"""
|
|
463
|
+
>>> int(AnyExpr(42))
|
|
464
|
+
42
|
|
465
|
+
"""
|
|
466
|
+
return any_eval(int_(self))
|
|
467
|
+
|
|
468
|
+
@method(preserve=True)
|
|
469
|
+
def __float__(self) -> float:
|
|
470
|
+
"""
|
|
471
|
+
>>> float(AnyExpr(3.14))
|
|
472
|
+
3.14
|
|
473
|
+
"""
|
|
474
|
+
return any_eval(float_(self))
|
|
475
|
+
|
|
476
|
+
@method(preserve=True)
|
|
477
|
+
def __index__(self) -> int:
|
|
478
|
+
"""
|
|
479
|
+
>>> import operator
|
|
480
|
+
>>> operator.index(AnyExpr(42))
|
|
481
|
+
42
|
|
482
|
+
"""
|
|
483
|
+
return any_eval(index(self))
|
|
484
|
+
|
|
485
|
+
# TODO: support ndigits with optional int
|
|
486
|
+
@method(preserve=True)
|
|
487
|
+
def __round__(self) -> AnyExpr:
|
|
488
|
+
"""
|
|
489
|
+
>>> int(round(AnyExpr(3.6)))
|
|
490
|
+
4
|
|
491
|
+
"""
|
|
492
|
+
return with_assert(round(self.egglog_any_expr_value))
|
|
493
|
+
|
|
494
|
+
@method(preserve=True)
|
|
495
|
+
def __trunc__(self) -> AnyExpr:
|
|
496
|
+
"""
|
|
497
|
+
>>> import math
|
|
498
|
+
>>> int(math.trunc(AnyExpr(3.6)))
|
|
499
|
+
3
|
|
500
|
+
"""
|
|
501
|
+
return with_assert(math.trunc(self.egglog_any_expr_value))
|
|
502
|
+
|
|
503
|
+
@method(preserve=True)
|
|
504
|
+
def __floor__(self) -> AnyExpr:
|
|
505
|
+
"""
|
|
506
|
+
>>> import math
|
|
507
|
+
>>> int(math.floor(AnyExpr(3.6)))
|
|
508
|
+
3
|
|
509
|
+
"""
|
|
510
|
+
return with_assert(math.floor(self.egglog_any_expr_value))
|
|
511
|
+
|
|
512
|
+
@method(preserve=True)
|
|
513
|
+
def __ceil__(self) -> AnyExpr:
|
|
514
|
+
"""
|
|
515
|
+
>>> import math
|
|
516
|
+
>>> int(math.ceil(AnyExpr(3.4)))
|
|
517
|
+
4
|
|
518
|
+
"""
|
|
519
|
+
return with_assert(math.ceil(self.egglog_any_expr_value))
|
|
520
|
+
|
|
521
|
+
# TODO: https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
class A(Expr):
    """
    Symbolic expression over arbitrary Python objects; the payload wrapped by ``AnyExpr``.

    ``A(PyObject(obj))`` wraps a concrete Python object. The special methods below are egglog
    declarations (``...`` bodies); the rules that evaluate them once their operands are wrapped
    Python objects are in ``any_expr_ruleset``.
    """

    def __init__(self, obj: object) -> None: ...

    __match_args__ = ("egglog_any_expr_value",)

    @method(preserve=True)  # type: ignore[prop-decorator]
    @property
    def egglog_any_expr_value(self) -> object:
        """
        Return the underlying Python object, if it was constructed with one.

        Long method name so it doesn't conflict with any user-defined properties.

        >>> A(10).egglog_any_expr_value
        10

        :raises ExprValueError: if this expression is not of the form ``A(PyObject(obj))``.
        """
        match get_callable_args(self, A):
            case (PyObject(obj),):
                return obj
        raise ExprValueError(self, "A")

    def __eq__(self, other: ALike) -> A: ...  # type: ignore[override]
    def __ne__(self, other: ALike) -> A: ...  # type: ignore[override]
    def __lt__(self, other: ALike) -> A: ...
    def __le__(self, other: ALike) -> A: ...
    def __gt__(self, other: ALike) -> A: ...
    def __ge__(self, other: ALike) -> A: ...
    def __getattr__(self, name: StringLike) -> A: ...
    # ``args`` is the positional-argument tuple expression and ``kwargs`` the keyword dict expression
    # (as built by ``AnyExpr.__call__``). The body is never executed, so the mutable default is not shared
    # state — presumably it is only converted into an ``A`` expression; confirm against egglog docs.
    def __call__(self, args: ALike = (), kwargs: ALike = {}) -> A: ...  # noqa: B006
    def __getitem__(self, key: ALike) -> A: ...
    # Mutating declarations: applying them rewrites the receiving expression in place
    # (see ``AnyExpr.__setitem__`` / ``AnyExpr.__delitem__``).
    def __setitem__(self, key: ALike, value: ALike) -> None: ...
    def __delitem__(self, key: ALike) -> None: ...
    def __add__(self, other: ALike) -> A: ...
    def __sub__(self, other: ALike) -> A: ...
    def __mul__(self, other: ALike) -> A: ...
    def __matmul__(self, other: ALike) -> A: ...
    def __truediv__(self, other: ALike) -> A: ...
    def __floordiv__(self, other: ALike) -> A: ...
    def __mod__(self, other: ALike) -> A: ...
    def __divmod__(self, other: ALike) -> A: ...
    def __pow__(self, other: ALike) -> A: ...
    def __lshift__(self, other: ALike) -> A: ...
    def __rshift__(self, other: ALike) -> A: ...
    def __and__(self, other: ALike) -> A: ...
    def __xor__(self, other: ALike) -> A: ...
    def __or__(self, other: ALike) -> A: ...
    def __neg__(self) -> A: ...
    def __pos__(self) -> A: ...
    def __abs__(self) -> A: ...
    def __round__(self) -> A: ...
    def __trunc__(self) -> A: ...
    def __floor__(self) -> A: ...
    def __ceil__(self) -> A: ...
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
# Conversions into ``AnyExpr``: an ``A`` is wrapped directly; any other Python object is first
# wrapped as ``A(PyObject(obj))``.
converter(A, AnyExpr, AnyExpr)
converter(object, AnyExpr, lambda x: AnyExpr(A(PyObject(x))))

# Conversions into ``A``: unwrap an ``AnyExpr``, wrap a ``PyObject``, or wrap any other Python
# object in a ``PyObject``. All use ``cost=10`` (presumably to rank them against other conversion
# paths — confirm against egglog's ``converter`` documentation).
converter(AnyExpr, A, lambda a: a.egglog_any_expr_value, cost=10)
converter(PyObject, A, A, cost=10)
converter(object, A, lambda x: A(PyObject(x)), cost=10)

# Anything accepted where an ``A`` is expected; non-``A`` values go through the converters above.
ALike: TypeAlias = A | object
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
@function()
def given(expr: ALike, condition: ALike) -> A:
    """
    ``expr`` tagged with the path ``condition`` under which it was traced.

    Built by ``with_assert`` and ``any_eval``. Once ``expr`` is a wrapped Python object the
    condition is dropped during evaluation (see the ``given`` rule in ``any_expr_ruleset``).
    """
    ...
@function
def bytes_(expr: ALike) -> A:
    """Symbolic ``bytes(expr)``."""
    ...
@function
def bool_(expr: ALike) -> A:
    """Symbolic ``bool(expr)``."""
    ...
@function
def hasattr_(expr: ALike, name: StringLike) -> A:
    """Symbolic ``hasattr(expr, name)``; ``any_eval`` never records these as path conditions."""
    ...
@function
def getattr_eager(expr: ALike, name: StringLike) -> Unit:
    """
    Marker: when present for ``(expr, name)``, ``AnyExpr.__getattr__`` evaluates the attribute
    immediately and returns the plain Python value instead of a traced ``AnyExpr``.

    No rule in this module section sets it; presumably it is set by user rules — confirm.
    """
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
@function
def len_(expr: ALike) -> A:
    """Symbolic ``len(expr)``."""
    ...
@function
def append(expr: ALike, item: ALike) -> A:
    """
    Appends an item to a tuple.

    Evaluates to a new tuple ``(*expr, item)``; the original tuple is not modified.
    """


@function
def set_kwarg(expr: ALike, key: StringLike, value: ALike) -> A:
    """
    Sets a value in a dict with a string key.

    Evaluates to a new dict ``{**expr, key: value}``; the original dict is not modified.
    """
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
@function
def contains(expr: ALike, item: ALike) -> A:
    """Symbolic ``item in expr`` (evaluated with ``operator.contains(expr, item)``)."""
    ...
@function
def complex_(expr: ALike) -> A:
    """Symbolic ``complex(expr)``."""
    ...
@function
def int_(expr: ALike) -> A:
    """Symbolic ``int(expr)``."""
    ...
@function
def float_(expr: ALike) -> A:
    """Symbolic ``float(expr)``."""
    ...
@function
def index(expr: ALike) -> A:
    """Symbolic ``operator.index(expr)``."""
    ...
@function
def slice_(start: ALike = None, stop: ALike = None, step: ALike = None) -> A:
    """Symbolic ``slice(start, stop, step)``; used when converting Python ``slice`` objects."""
    ...
@function
def list_(expr: ALike) -> A:
    """Symbolic ``list(expr)``; used when converting Python lists (from a tuple of their items)."""
    ...
@function
def not_(expr: ALike) -> A:
    """
    Logical negation of a condition, used by ``any_eval`` when recording a falsy result.

    ``any_expr_ruleset`` has no evaluation rule for it; it is only simplified by ``given_ruleset``.
    """
    ...
@function
def and_(left: ALike, right: ALike) -> A:
    """
    Conjunction of two conditions, produced by ``given_ruleset`` when flattening nested ``given``.

    ``any_expr_ruleset`` has no evaluation rule for it.
    """
    ...
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
# Special-case containers so that any expressions inside them are converted element by element
# (to ``A``) instead of the whole container being wrapped as one opaque ``PyObject``:
# - tuples become a chain of ``append`` calls starting from ``A(())``,
# - lists become ``list_`` of such a tuple,
# - slices become ``slice_`` of their start/stop/step.
converter(tuple, A, lambda x: reduce(append, x, A(())))
converter(list, A, lambda x: list_(tuple(x)))
converter(slice, A, lambda x: slice_(x.start, x.stop, x.step))
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
@ruleset
def any_expr_ruleset(x: PyObject, y: PyObject, z: PyObject, s: String, a: A):
    """
    Evaluation rules for ``A``.

    Once every operand of a symbolic operation is a wrapped Python object (``A(PyObject(...))``),
    rewrite the operation into a ``PyObject`` call of the corresponding builtin / ``operator``
    function wrapped back in ``A``, i.e. actually perform it on the underlying objects.
    """
    # Conversions and comparisons
    yield rewrite(bytes_(A(x))).to(A(PyObject(bytes)(x)))
    yield rewrite(bool_(A(x))).to(A(PyObject(bool)(x)))
    yield rewrite(A(x) == A(y)).to(A(PyObject(operator.eq)(x, y)))
    yield rewrite(A(x) != A(y)).to(A(PyObject(operator.ne)(x, y)))
    yield rewrite(A(x) < A(y)).to(A(PyObject(operator.lt)(x, y)))
    yield rewrite(A(x) <= A(y)).to(A(PyObject(operator.le)(x, y)))
    yield rewrite(A(x) > A(y)).to(A(PyObject(operator.gt)(x, y)))
    yield rewrite(A(x) >= A(y)).to(A(PyObject(operator.ge)(x, y)))
    # Attribute access; the attribute name is an egglog ``String`` converted to a Python ``str``.
    yield rewrite(A(x).__getattr__(s)).to(A(PyObject(getattr)(x, PyObject.from_string(s))))
    yield rewrite(hasattr_(A(x), s)).to(A(PyObject(hasattr)(x, PyObject.from_string(s))))
    yield rewrite(len_(A(x))).to(A(PyObject(len)(x)))
    # Calls: ``y`` is the positional-argument tuple and ``z`` the keyword dict
    # (presumably ``call_extended`` performs ``x(*y, **z)`` — confirm against PyObject docs).
    yield rewrite(A(x)(y, z)).to(A(x.call_extended(y, z)))
    # Building argument containers without mutating the inputs.
    yield rewrite(append(A(x), A(y))).to(A(PyObject(lambda t, v: (*t, v))(x, y)))
    yield rewrite(set_kwarg(A(x), s, A(y))).to(A(PyObject(lambda d, k, v: {**d, k: v})(x, PyObject.from_string(s), y)))
    yield rewrite(A(x)[A(y)]).to(A(PyObject(operator.getitem)(x, y)))
    # Item assignment / deletion are mutating declarations on ``A``: applying them to these locals turns
    # the locals into the symbolic ``__setitem__`` / ``__delitem__`` terms to match. Evaluation mutates
    # the wrapped object in place and then returns that same object.
    setitem_any = A(x)
    setitem_any[A(y)] = A(z)
    yield rewrite(setitem_any).to(A(PyObject(lambda obj, k, v: operator.setitem(obj, k, v) or obj)(x, y, z)))
    delitem_any = A(x)
    del delitem_any[A(y)]
    yield rewrite(delitem_any).to(A(PyObject(lambda obj, k: operator.delitem(obj, k) or obj)(x, y)))
    yield rewrite(contains(A(x), A(y))).to(A(PyObject(operator.contains)(x, y)))
    # Binary arithmetic / bitwise operators
    yield rewrite(A(x) + A(y)).to(A(PyObject(operator.add)(x, y)))
    yield rewrite(A(x) - A(y)).to(A(PyObject(operator.sub)(x, y)))
    yield rewrite(A(x) * A(y)).to(A(PyObject(operator.mul)(x, y)))
    yield rewrite(A(x) @ A(y)).to(A(PyObject(operator.matmul)(x, y)))
    yield rewrite(A(x) / A(y)).to(A(PyObject(operator.truediv)(x, y)))
    yield rewrite(A(x) // A(y)).to(A(PyObject(operator.floordiv)(x, y)))
    yield rewrite(A(x) % A(y)).to(A(PyObject(operator.mod)(x, y)))
    yield rewrite(divmod(A(x), A(y))).to(A(PyObject(divmod)(x, y)))
    yield rewrite(A(x) ** A(y)).to(A(PyObject(operator.pow)(x, y)))
    yield rewrite(A(x) << A(y)).to(A(PyObject(operator.lshift)(x, y)))
    yield rewrite(A(x) >> A(y)).to(A(PyObject(operator.rshift)(x, y)))
    yield rewrite(A(x) & A(y)).to(A(PyObject(operator.and_)(x, y)))
    yield rewrite(A(x) ^ A(y)).to(A(PyObject(operator.xor)(x, y)))
    yield rewrite(A(x) | A(y)).to(A(PyObject(operator.or_)(x, y)))
    # Unary operators and numeric conversions
    yield rewrite(-A(x)).to(A(PyObject(operator.neg)(x)))
    yield rewrite(+A(x)).to(A(PyObject(operator.pos)(x)))
    yield rewrite(abs(A(x))).to(A(PyObject(operator.abs)(x)))
    yield rewrite(complex_(A(x))).to(A(PyObject(complex)(x)))
    yield rewrite(int_(A(x))).to(A(PyObject(int)(x)))
    yield rewrite(float_(A(x))).to(A(PyObject(float)(x)))
    yield rewrite(index(A(x))).to(A(PyObject(operator.index)(x)))
    yield rewrite(round(A(x))).to(A(PyObject(round)(x)))
    yield rewrite(math.trunc(A(x))).to(A(PyObject(math.trunc)(x)))
    yield rewrite(math.floor(A(x))).to(A(PyObject(math.floor)(x)))
    yield rewrite(math.ceil(A(x))).to(A(PyObject(math.ceil)(x)))
    # Container constructors used by the tuple/list/slice converters
    yield rewrite(list_(A(x))).to(A(PyObject(list)(x)))
    yield rewrite(slice_(A(x), A(y), A(z))).to(A(PyObject(slice)(x, y, z)))

    # Given: a path condition does not change the value of an already-evaluated object.
    yield rewrite(given(A(x), a)).to(A(x))
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
# Run ``any_expr_ruleset`` to saturation; used by ``any_eval`` and ``AnyExpr.__getattr__`` to reduce
# symbolic operations on wrapped objects down to ``A(PyObject(...))`` values.
any_expr_schedule = any_expr_ruleset.saturate()
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def any_eval(self: A) -> Any:
|
|
706
|
+
"""
|
|
707
|
+
Evaluate the AnyExpr to get its underlying Python value.
|
|
708
|
+
|
|
709
|
+
Runs rules if it's not already resolved
|
|
710
|
+
"""
|
|
711
|
+
global _LAST_ASSERT
|
|
712
|
+
egraph = _get_current_egraph()
|
|
713
|
+
# 1. First see if it's already a primitive value
|
|
714
|
+
try:
|
|
715
|
+
return self.egglog_any_expr_value
|
|
716
|
+
except ExprValueError:
|
|
717
|
+
pass
|
|
718
|
+
# 2. If not, try to extract it from the egraph
|
|
719
|
+
expr = egraph.extract(self)
|
|
720
|
+
try:
|
|
721
|
+
res = expr.egglog_any_expr_value
|
|
722
|
+
except ExprValueError:
|
|
723
|
+
# 3. If that isn't one, then try running the schedule to extract it
|
|
724
|
+
egraph.register(expr)
|
|
725
|
+
egraph.run(any_expr_schedule)
|
|
726
|
+
expr = egraph.extract(expr)
|
|
727
|
+
res = expr.egglog_any_expr_value
|
|
728
|
+
# Don't save hasattr asserts
|
|
729
|
+
if get_callable_fn(self) != hasattr_:
|
|
730
|
+
# If we are calling bool_ same as just asserting vlaues
|
|
731
|
+
match get_callable_args(self, bool_):
|
|
732
|
+
case (A() as inner,):
|
|
733
|
+
self = inner
|
|
734
|
+
if eq(expr).to(A(True)):
|
|
735
|
+
asserted = self
|
|
736
|
+
_LAST_ASSERT = with_assert(self).egglog_any_expr_value
|
|
737
|
+
elif eq(expr).to(A(False)):
|
|
738
|
+
match get_callable_args(self, A.__eq__):
|
|
739
|
+
case (A() as left, A() as right):
|
|
740
|
+
asserted = left != right
|
|
741
|
+
case _:
|
|
742
|
+
match get_callable_args(self, A.__ne__):
|
|
743
|
+
case (A() as left, A() as right):
|
|
744
|
+
asserted = left == right
|
|
745
|
+
case _:
|
|
746
|
+
asserted = not_(self)
|
|
747
|
+
else:
|
|
748
|
+
asserted = self == expr
|
|
749
|
+
# _LAST_ASSERT = (
|
|
750
|
+
# asserted if _LAST_ASSERT is None or eq(_LAST_ASSERT).to(asserted) else and_(_LAST_ASSERT, asserted)
|
|
751
|
+
# )
|
|
752
|
+
_LAST_ASSERT = given(asserted, _LAST_ASSERT) if _LAST_ASSERT is not None else asserted
|
|
753
|
+
return res
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
# E-graph installed by ``set_any_expr_egraph``; ``None`` when no tracing context is active.
_CURRENT_EGRAPH: None | EGraph = None
# Path condition accumulated by ``any_eval`` (nested ``given`` terms); ``with_assert`` attaches it to
# newly traced expressions while an e-graph is active.
_LAST_ASSERT: None | A = None
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
@contextlib.contextmanager
def set_any_expr_egraph(egraph: EGraph) -> Iterator[None]:
    """
    Context manager that will set the current egraph. It will be set back after.

    While active, ``any_eval`` runs against ``egraph`` and records path conditions in
    ``_LAST_ASSERT``, which ``with_assert`` attaches to newly traced expressions. Each context
    starts with no recorded conditions. On exit (including via an exception) the previous e-graph
    and recorded conditions are restored, so contexts may be nested.

    :param egraph: the e-graph to use inside the ``with`` block.
    """
    global _CURRENT_EGRAPH, _LAST_ASSERT
    # Save and restore rather than ``assert``-ing a clean global state: asserts vanish under
    # ``python -O``, and restoring makes nested contexts safe.
    previous_egraph, previous_assert = _CURRENT_EGRAPH, _LAST_ASSERT
    _CURRENT_EGRAPH = egraph
    _LAST_ASSERT = None
    try:
        yield
    finally:
        _CURRENT_EGRAPH = previous_egraph
        _LAST_ASSERT = previous_assert
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
def _get_current_egraph() -> EGraph:
    """
    Return the e-graph installed by ``set_any_expr_egraph``.

    When none is installed (or it is falsy), a brand-new ``EGraph`` is created on every call, so
    nothing is shared between calls made outside a tracing context.
    """
    if _CURRENT_EGRAPH:
        return _CURRENT_EGRAPH
    return EGraph()
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def with_assert(expr: A) -> AnyExpr:
    """
    Wrap ``expr`` in an ``AnyExpr``, attaching the accumulated path condition.

    Inside an active ``set_any_expr_egraph`` context with at least one recorded condition, the
    result is ``AnyExpr(given(expr, _LAST_ASSERT))``, so traced expressions stay consistent with
    the values already observed by ``any_eval``. Otherwise it is simply ``AnyExpr(expr)``.
    """
    # Possible refinement (kept out of the code for now): if ``expr`` is already a ``given``, merge
    # its condition with ``_LAST_ASSERT`` via ``and_`` instead of nesting another ``given``.
    if not _CURRENT_EGRAPH or _LAST_ASSERT is None:
        return AnyExpr(expr)
    return AnyExpr(given(expr, _LAST_ASSERT))
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
@ruleset
def given_ruleset(x: A, y: A, z: A):
    """
    Simplifications for path-condition terms built by ``any_eval`` / ``with_assert``.

    All rewrites use ``subsume=True``, so the matched form is replaced rather than kept as an
    alternative (per egglog's subsumption semantics). This ruleset is not part of
    ``any_expr_schedule``; where it is run is outside this section.
    """
    # Push a negation through a condition wrapper: not (x given y) -> (not x) given y.
    yield rewrite(not_(given(x, y)), subsume=True).to(given(not_(x), y))
    # Flatten nested conditions into a single conjunction.
    yield rewrite(given(given(x, y), z), subsume=True).to(given(x, and_(y, z)))
    # Drop duplicated conjuncts.
    yield rewrite(and_(x, x), subsume=True).to(x)
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
@function
def any_expr_program(x: AnyExpr) -> Program:
    r"""
    Convert an AnyExpr to a Program.

    This function has no default body. Its meaning comes from
    ``any_program_ruleset``, which rewrites ``any_expr_program(AnyExpr(a))`` to
    ``a_program(a)``. The example below goes through the string-rendering helper
    ``any_expr_source``.

    >>> any_expr_source(AnyExpr(42) == 10)
    '(42 == 10)'
    """
|
|
814
|
+
|
|
815
|
+
|
|
816
|
+
@function
def a_program(x: A) -> Program:
    """
    Convert an A to a Program.

    This function has no default body. ``any_program_ruleset`` rewrites
    ``a_program(<node>)`` for each kind of ``A`` node into the corresponding
    Python source fragment.
    """
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
def w(p: Program) -> Program:
    """Wrap ``p`` in parentheses, producing the program text ``(p)``."""
    opening = Program("(")
    return opening + p + ")"
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
def ca(p: ProgramLike, *args: ProgramLike) -> Program:
    """
    Render a call expression ``p(arg0, arg1, ...)`` as a Program.

    :param p: The callee. It is converted to a ``Program``, so a plain string
        such as ``"len"`` works.
    :param args: The call arguments, emitted in order and separated by ``", "``.
        With no arguments the result is ``p()``.
    :returns: The program text for the call.
    """
    args_expr = Program("")
    # Guard against the empty case: ``args[-1]`` would raise IndexError.
    if args:
        # Every argument except the last is followed by a separator.
        for arg in args[:-1]:
            args_expr += arg + ", "
        args_expr += args[-1]
    return convert(p, Program) + Program("(") + args_expr + Program(")")
|
|
833
|
+
|
|
834
|
+
|
|
835
|
+
# Python types whose ``repr`` is emitted directly (inline) into generated source
# by ``any_program_ruleset``. Objects of any other type go through
# ``Program.assign()`` instead.
INLINE_TYPES = (
    int,
    str,
    float,
    bytes,
    bool,
    type(None),
    tuple,
    dict,
)
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
@ruleset
def any_program_ruleset(a: A, b: A, c: A, p: PyObject, s: String):
    """
    Rewrite rules that lower ``A`` expressions to ``Program`` source fragments.

    ``any_expr_program(AnyExpr(a))`` is forwarded to ``a_program(a)``, and each
    ``a_program(<node>)`` is rewritten into the Python source text for that node.

    Every binary and unary operator, ``not``, and ``and`` is wrapped in
    parentheses via ``w``, so the emitted text keeps the expression tree's
    grouping regardless of Python operator precedence. Calls and non-inline
    Python objects are bound via ``Program.assign()``. Item assignment and
    deletion are attached via ``Program.statement()``.
    """
    # Shared helpers, built once and reused by several rules below.
    py_repr = PyObject(repr)
    is_inline = PyObject(lambda x: isinstance(x, INLINE_TYPES))(p).to_bool()

    yield rewrite(any_expr_program(AnyExpr(a)), subsume=True).to(a_program(a))

    # Wrapped Python objects. Values of an INLINE_TYPES type are emitted as their
    # repr; anything else is routed through ``.assign()``.
    # NOTE(review): a negative inline number (repr "-2") used as the left operand
    # of ``**`` still renders as "(-2 ** 2)", which Python parses as -(2 ** 2).
    yield rewrite(a_program(A(p)), subsume=True).to(
        Program(py_repr(p).to_string()),
        is_inline == Bool(True),
    )
    yield rewrite(a_program(A(p)), subsume=True).to(
        Program(py_repr(p).to_string()).assign(),
        is_inline == Bool(False),
    )

    # ``bytes(x)``, matching the other builtin conversions (bool/int/float/...).
    # Arbitrary Python objects have no ``.bytes()`` method.
    yield rewrite(a_program(bytes_(a)), subsume=True).to(ca("bytes", a_program(a)))
    yield rewrite(a_program(bool_(a)), subsume=True).to(ca("bool", a_program(a)))

    # Comparisons
    yield rewrite(a_program(a == b), subsume=True).to(w(a_program(a) + " == " + a_program(b)))
    yield rewrite(a_program(a != b), subsume=True).to(w(a_program(a) + " != " + a_program(b)))
    yield rewrite(a_program(a < b), subsume=True).to(w(a_program(a) + " < " + a_program(b)))
    yield rewrite(a_program(a <= b), subsume=True).to(w(a_program(a) + " <= " + a_program(b)))
    yield rewrite(a_program(a > b), subsume=True).to(w(a_program(a) + " > " + a_program(b)))
    yield rewrite(a_program(a >= b), subsume=True).to(w(a_program(a) + " >= " + a_program(b)))

    # Attribute access, hasattr, len
    yield rewrite(a_program(a.__getattr__(s)), subsume=True).to(a_program(a) + "." + s)
    yield rewrite(a_program(hasattr_(a, s)), subsume=True).to(
        ca("hasattr", a_program(a), py_repr(PyObject.from_string(s)).to_string())
    )
    yield rewrite(a_program(len_(a)), subsume=True).to(ca("len", a_program(a)))

    # Calls: ``a(*b, **c)``, bound via ``.assign()``.
    yield rewrite(a_program(a(b, c)), subsume=True).to(
        ca(a_program(a), "*" + a_program(b), "**" + a_program(c)).assign()
    )
    # Positional-argument tuple extended by one element: ``(*a, b)``.
    yield rewrite(a_program(append(a, b)), subsume=True).to(ca("", "*" + a_program(a), a_program(b)))
    # Keyword-argument dict extended by one entry: ``{**a, 's': b}``.
    yield rewrite(a_program(set_kwarg(a, s, b)), subsume=True).to(
        "{**" + a_program(a) + ", " + py_repr(PyObject.from_string(s)).to_string() + ": " + a_program(b) + "}"
    )

    # Item access, assignment, and deletion. For mutation, ``a`` is bound via
    # ``.assign()`` and the mutation is attached with ``.statement(...)``.
    yield rewrite(a_program(a[b]), subsume=True).to(a_program(a) + "[" + a_program(b) + "]")
    assigned_a = a_program(a).assign()
    setitem_a = copy(a)
    setitem_a[b] = c
    yield rewrite(a_program(setitem_a), subsume=True).to(
        assigned_a.statement(assigned_a + "[" + a_program(b) + "] = " + a_program(c))
    )
    delitem_a = copy(a)
    del delitem_a[b]
    yield rewrite(a_program(delitem_a), subsume=True).to(
        assigned_a.statement("del " + assigned_a + "[" + a_program(b) + "]")
    )

    # NOTE(review): this renders ``contains(a, b)`` as ``a in b``. Confirm the
    # argument order against ``contains``'s definition; ``operator.contains(a, b)``
    # means ``b in a``.
    yield rewrite(a_program(contains(a, b)), subsume=True).to(w(a_program(a) + " in " + a_program(b)))

    # Arithmetic and bitwise binary operators
    yield rewrite(a_program(a + b), subsume=True).to(w(a_program(a) + " + " + a_program(b)))
    yield rewrite(a_program(a - b), subsume=True).to(w(a_program(a) + " - " + a_program(b)))
    yield rewrite(a_program(a * b), subsume=True).to(w(a_program(a) + " * " + a_program(b)))
    yield rewrite(a_program(a @ b), subsume=True).to(w(a_program(a) + " @ " + a_program(b)))
    yield rewrite(a_program(a / b), subsume=True).to(w(a_program(a) + " / " + a_program(b)))
    yield rewrite(a_program(a // b), subsume=True).to(w(a_program(a) + " // " + a_program(b)))
    yield rewrite(a_program(a % b), subsume=True).to(w(a_program(a) + " % " + a_program(b)))
    yield rewrite(a_program(divmod(a, b)), subsume=True).to(ca("divmod", a_program(a), a_program(b)))
    yield rewrite(a_program(a**b), subsume=True).to(w(a_program(a) + " ** " + a_program(b)))
    yield rewrite(a_program(a << b), subsume=True).to(w(a_program(a) + " << " + a_program(b)))
    yield rewrite(a_program(a >> b), subsume=True).to(w(a_program(a) + " >> " + a_program(b)))
    yield rewrite(a_program(a & b), subsume=True).to(w(a_program(a) + " & " + a_program(b)))
    yield rewrite(a_program(a ^ b), subsume=True).to(w(a_program(a) + " ^ " + a_program(b)))
    yield rewrite(a_program(a | b), subsume=True).to(w(a_program(a) + " | " + a_program(b)))

    # Unary operators are parenthesized like the binary ones. Without the
    # parentheses, ``(-x) ** 2`` would render as ``(-x ** 2)`` (== -(x ** 2)) and
    # ``(-x).y`` as ``-x.y`` (== -(x.y)).
    yield rewrite(a_program(-a), subsume=True).to(w("-" + a_program(a)))
    yield rewrite(a_program(+a), subsume=True).to(w("+" + a_program(a)))

    # Numeric and container builtins
    yield rewrite(a_program(abs(a)), subsume=True).to(ca("abs", a_program(a)))
    yield rewrite(a_program(complex_(a)), subsume=True).to(ca("complex", a_program(a)))
    yield rewrite(a_program(int_(a)), subsume=True).to(ca("int", a_program(a)))
    yield rewrite(a_program(float_(a)), subsume=True).to(ca("float", a_program(a)))
    yield rewrite(a_program(index(a)), subsume=True).to(ca("operator.index", a_program(a)))
    yield rewrite(a_program(round(a)), subsume=True).to(ca("round", a_program(a)))
    yield rewrite(a_program(math.trunc(a)), subsume=True).to(ca("math.trunc", a_program(a)))
    yield rewrite(a_program(math.floor(a)), subsume=True).to(ca("math.floor", a_program(a)))
    yield rewrite(a_program(math.ceil(a)), subsume=True).to(ca("math.ceil", a_program(a)))
    yield rewrite(a_program(list_(a)), subsume=True).to(ca("list", a_program(a)))
    yield rewrite(a_program(slice_(a, b, c)), subsume=True).to(ca("slice", a_program(a), a_program(b), a_program(c)))

    # Boolean connectives
    yield rewrite(a_program(not_(a)), subsume=True).to(w("not " + a_program(a)))
    yield rewrite(a_program(and_(a, b)), subsume=True).to(w(a_program(a) + " and " + a_program(b)))

    # Given: emit ``assert <condition>`` as a statement attached to the guarded expression.
    yield rewrite(a_program(given(a, b)), subsume=True).to(a_program(a).statement("assert " + a_program(b)))
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
# Saturate the AnyExpr -> Program lowering rules first, then saturate the
# program-generation rules over the result.
any_program_schedule = (
    any_program_ruleset.saturate()
    + program_gen_ruleset.saturate()
)
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
def any_expr_source(x: AnyExpr) -> str:
    """
    Render ``x`` as Python source text.

    Two phases run on separate egraphs:

    1. ``a_program`` of the wrapped ``A`` expression is saturated with
       ``any_program_ruleset``, and the resulting ``Program`` term is extracted.
    2. The extracted program is compiled and saturated with
       ``program_gen_ruleset``. Its statements and final expression are then
       joined, extracted, and returned as a ``str``.

    :param x: The expression to render.
    :returns: The generated Python source.
    """
    # Use a new local name instead of rebinding the AnyExpr parameter to an A.
    expr = x.egglog_any_expr_value
    program = a_program(expr)

    # Phase 1: lower A -> Program.
    lowering = EGraph()
    lowering.register(program)
    lowering.run(any_program_ruleset.saturate())
    res_program = lowering.extract(program)

    # Phase 2: generate source text from the extracted Program.
    codegen = EGraph()
    codegen.register(res_program.compile())
    codegen.run(program_gen_ruleset.saturate())
    res = join(res_program.statements, res_program.expr)
    return codegen.extract(res).value
|
|
942
|
+
|
|
943
|
+
|
|
944
|
+
if __name__ == "__main__":
    # Manual smoke check. It is guarded so that importing this module has no
    # side effects (no module-level ``x`` binding and no printing).
    x = AnyExpr([42])
    print(x[0] + 10)
|