egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +887 -0
  4. egglog/builtins.py +1144 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +290 -0
  7. egglog/declarations.py +964 -0
  8. egglog/deconstruct.py +176 -0
  9. egglog/egraph.py +2247 -0
  10. egglog/egraph_state.py +978 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +64 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/MoA.ipynb +617 -0
  26. egglog/exp/__init__.py +3 -0
  27. egglog/exp/any_expr.py +947 -0
  28. egglog/exp/any_expr_example.ipynb +408 -0
  29. egglog/exp/array_api.py +2019 -0
  30. egglog/exp/array_api_jit.py +51 -0
  31. egglog/exp/array_api_loopnest.py +74 -0
  32. egglog/exp/array_api_numba.py +69 -0
  33. egglog/exp/array_api_program_gen.py +510 -0
  34. egglog/exp/program_gen.py +427 -0
  35. egglog/exp/siu_examples.py +32 -0
  36. egglog/ipython_magic.py +41 -0
  37. egglog/pretty.py +566 -0
  38. egglog/py.typed +0 -0
  39. egglog/runtime.py +888 -0
  40. egglog/thunk.py +97 -0
  41. egglog/type_constraint_solver.py +111 -0
  42. egglog/visualizer.css +1 -0
  43. egglog/visualizer.js +35798 -0
  44. egglog/visualizer_widget.py +39 -0
  45. egglog-12.0.0.dist-info/METADATA +93 -0
  46. egglog-12.0.0.dist-info/RECORD +48 -0
  47. egglog-12.0.0.dist-info/WHEEL +5 -0
  48. egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
egglog/exp/any_expr.py ADDED
@@ -0,0 +1,947 @@
1
+ """
2
+ WIP
3
+
4
+ An `AnyExpr`, which can be used to trace arbitrary expressions.
5
+
6
+ Created from any Python object, it should forward any operations on it to the underlying Python object.
7
+
8
+ This will only happen when it needs to be "materialized" however, through operations like `__bool__` or `__iter__`.
9
+
10
+ Generally it will try to avoid materializing the underlying object, and instead just treat it as a black box.
11
+ """
12
+ # mypy: disable-error-code="empty-body"
13
+
14
+ from __future__ import annotations
15
+
16
+ import contextlib
17
+ import math
18
+ import operator
19
+ from collections.abc import Iterator
20
+ from copy import copy
21
+ from functools import reduce
22
+ from typing import Any, TypeAlias
23
+
24
+ from egglog import *
25
+ from egglog.exp.program_gen import *
26
+
27
+
28
+ class AnyExpr(Expr):
29
+ """
30
+ Wraps an arbitrary Python object.
31
+
32
+ Any operations on it will be forwarded to the underlying object when needed.
33
+
34
+ Attempts to implement as many operations from https://docs.python.org/3/reference/datamodel.html as possible.
35
+
36
+ Can be converted from any Python object:
37
+
38
+ >>> AnyExpr(42) + 42
39
+ AnyExpr(A(42) + A(42))
40
+
41
+ Will also convert tuples and lists item by item:
42
+
43
+ >>> AnyExpr((1, 2,)) + (5, 6)
44
+ AnyExpr(append(append(A(()), A(1)), A(2)) + append(append(A(()), A(5)), A(6)))
45
+ """
46
+
47
+ def __init__(self, obj: ALike) -> None: ...
48
+
49
+ __match_args__ = ("egglog_any_expr_value",)
50
+
51
+ @method(preserve=True) # type: ignore[prop-decorator]
52
+ @property
53
+ def egglog_any_expr_value(self) -> A:
54
+ """
55
+ Return the underlying Python object, if it was constructued with one.
56
+
57
+ Long method name so it doesn't conflict with any user-defined properties.
58
+
59
+ >>> AnyExpr(10).egglog_any_expr_value
60
+ A(10)
61
+ """
62
+ match get_callable_args(self, AnyExpr):
63
+ case (A() as any_expr,):
64
+ return any_expr
65
+ raise ExprValueError(self, "AnyExpr")
66
+
67
+ @method(preserve=True)
68
+ def __bytes__(self) -> bytes:
69
+ """
70
+ >>> bytes(AnyExpr(b"hello"))
71
+ b'hello'
72
+ """
73
+ return any_eval(bytes_(self))
74
+
75
+ @method(preserve=True)
76
+ def __bool__(self) -> bool:
77
+ """
78
+ >>> bool(AnyExpr(True))
79
+ True
80
+ >>> bool(AnyExpr(False))
81
+ False
82
+ """
83
+ return any_eval(bool_(self))
84
+
85
+ @method(preserve=True)
86
+ def __eq__(self, other: object) -> AnyExpr: # type: ignore[override]
87
+ """
88
+ >>> bool(AnyExpr(1) == AnyExpr(1))
89
+ True
90
+ >>> bool(AnyExpr(1) == AnyExpr(2))
91
+ False
92
+ """
93
+ return with_assert(self.egglog_any_expr_value == other)
94
+
95
+ @method(preserve=True)
96
+ def __ne__(self, other: object) -> AnyExpr: # type: ignore[override]
97
+ """
98
+ >>> bool(AnyExpr(1) != AnyExpr(2))
99
+ True
100
+ >>> bool(AnyExpr(1) != AnyExpr(1))
101
+ False
102
+ """
103
+ return with_assert(self.egglog_any_expr_value != other)
104
+
105
+ @method(preserve=True)
106
+ def __lt__(self, other: object) -> AnyExpr:
107
+ """
108
+ >>> bool(AnyExpr(1) < AnyExpr(2))
109
+ True
110
+ >>> bool(AnyExpr(2) < AnyExpr(1))
111
+ False
112
+ """
113
+ return with_assert(self.egglog_any_expr_value < other)
114
+
115
+ @method(preserve=True)
116
+ def __le__(self, other: object) -> AnyExpr:
117
+ """
118
+ >>> bool(AnyExpr(2) <= AnyExpr(2))
119
+ True
120
+ >>> bool(AnyExpr(3) <= AnyExpr(2))
121
+ False
122
+ """
123
+ return with_assert(self.egglog_any_expr_value <= other)
124
+
125
+ @method(preserve=True)
126
+ def __gt__(self, other: object) -> AnyExpr:
127
+ """
128
+ >>> bool(AnyExpr(3) > AnyExpr(2))
129
+ True
130
+ >>> bool(AnyExpr(2) > AnyExpr(3))
131
+ False
132
+ """
133
+ return with_assert(self.egglog_any_expr_value > other)
134
+
135
+ @method(preserve=True)
136
+ def __ge__(self, other: object) -> AnyExpr:
137
+ """
138
+ >>> bool(AnyExpr(3) >= AnyExpr(3))
139
+ True
140
+ >>> bool(AnyExpr(2) >= AnyExpr(3))
141
+ False
142
+ """
143
+ return with_assert(self.egglog_any_expr_value >= other)
144
+
145
+ @method(preserve=True)
146
+ def __hash__(self) -> int:
147
+ """
148
+ Turn the underlying object into a hash.
149
+
150
+ >>> hash(AnyExpr("hello")) == hash("hello")
151
+ True
152
+ """
153
+ return hash(any_eval(self.egglog_any_expr_value))
154
+
155
+ @method(preserve=True)
156
+ def __getattr__(self, name: StringLike) -> AnyExpr | Any:
157
+ """
158
+ Get an attribute of the underlying object.
159
+
160
+ >>> int(AnyExpr([1, 2, 3]).index(2))
161
+ 1
162
+
163
+ Also should work with hasattr:
164
+ >>> hasattr(AnyExpr([1, 2, 3]), "index")
165
+ True
166
+ >>> hasattr(AnyExpr([1, 2, 3]), "nonexistent")
167
+ False
168
+ """
169
+ inner = self.egglog_any_expr_value
170
+ # Need to raise attribute error if it doesn't exist, since this is called for hasattr
171
+ if not any_eval(hasattr_(inner, name)):
172
+ raise AttributeError(f"{self} has no attribute {name}")
173
+ egraph = _get_current_egraph()
174
+ res = inner.__getattr__(name)
175
+ egraph.register(res)
176
+ egraph.run(any_expr_schedule)
177
+ if egraph.check_bool(getattr_eager(inner, name)):
178
+ return any_eval(res)
179
+ return with_assert(res)
180
+
181
+ # TODO: Not working for now
182
+ # @method(mutates_self=True)
183
+ # def __setattr__(self, name: StringLike, value: object) -> None:
184
+ # """
185
+ # Set an attribute of the underlying object.
186
+
187
+ # >>> x = lambda: None
188
+ # >>> expr = AnyExpr(x)
189
+ # >>> expr.attr = 42
190
+ # >>> int(expr.attr)
191
+ # 42
192
+ # """
193
+
194
+ # TODO: delattr
195
+ # TODO: __get__/__set__?
196
+
197
+ @method(preserve=True)
198
+ def __len__(self) -> int:
199
+ """
200
+ Get the length of the underlying object.
201
+
202
+ >>> len(AnyExpr([1, 2, 3]))
203
+ 3
204
+ """
205
+ return any_eval(len_(self))
206
+
207
+ @method(preserve=True)
208
+ def __call__(self, *args: object, **kwargs: object) -> AnyExpr:
209
+ """
210
+ Call the underlying object.
211
+
212
+ >>> int(AnyExpr(int)(42))
213
+ 42
214
+ >>> int(AnyExpr(lambda *x, **y: len(x) + len(y))(1, 2, a=3, b=4))
215
+ 4
216
+ """
217
+ args_expr = A(())
218
+ for a in args:
219
+ args_expr = append(args_expr, a)
220
+ kwargs_expr = A({})
221
+ for k, v in kwargs.items():
222
+ kwargs_expr = set_kwarg(kwargs_expr, k, v)
223
+ return with_assert(self.egglog_any_expr_value(args_expr, kwargs_expr))
224
+
225
+ @method(preserve=True)
226
+ def __getitem__(self, key: object) -> AnyExpr:
227
+ """
228
+ Get an item from the underlying object.
229
+
230
+ >>> int(AnyExpr([1, 2, 3])[1])
231
+ 2
232
+ """
233
+ return with_assert(self.egglog_any_expr_value[key])
234
+
235
+ @method(preserve=True)
236
+ def __setitem__(self, key: object, value: object) -> None:
237
+ """
238
+ Set an item in the underlying object.
239
+
240
+ >>> x = [1, 2, 3]
241
+ >>> expr = AnyExpr(x)
242
+ >>> expr[1] = 42
243
+ >>> int(expr[1])
244
+ 42
245
+ """
246
+ any_expr_inner = self.egglog_any_expr_value
247
+ any_expr_inner[key] = value
248
+ self.__replace_expr__(AnyExpr(with_assert(any_expr_inner)))
249
+
250
+ @method(preserve=True)
251
+ def __delitem__(self, key: object) -> None:
252
+ """
253
+ Delete an item from the underlying object.
254
+
255
+ >>> x = [1, 2, 3]
256
+ >>> expr = AnyExpr(x)
257
+ >>> del expr[1]
258
+ >>> len(expr)
259
+ 2
260
+ """
261
+ any_expr_inner = self.egglog_any_expr_value
262
+ del any_expr_inner[key]
263
+ self.__replace_expr__(AnyExpr(with_assert(any_expr_inner)))
264
+
265
+ # TODO: support real iterators
266
+ @method(preserve=True)
267
+ def __iter__(self) -> Iterator[AnyExpr]:
268
+ """
269
+ Iterate over the underlying object.
270
+
271
+ >>> list(AnyExpr((1, 2)))
272
+ [AnyExpr(append(append(A(()), A(1)), A(2))[A(0)]), AnyExpr(append(append(A(()), A(1)), A(2))[A(1)])]
273
+ """
274
+ return iter(self[i] for i in range(len(self)))
275
+
276
+ # TODO: Not working for now
277
+ # @method(preserve=True)
278
+ # def __reversed__(self) -> Iterator[AnyExpr]:
279
+ # """
280
+ # Reverse iterate over the underlying object.
281
+
282
+ # >>> list(reversed(AnyExpr([1, 2, 3])))
283
+ # [AnyExpr(3), AnyExpr(2), AnyExpr(1)]
284
+ # """
285
+ # return map(AnyExpr, any_eval(reversed_op(self)))
286
+
287
+ @method(preserve=True)
288
+ def __contains__(self, item: object) -> bool:
289
+ """
290
+ Check if the underlying object contains an item.
291
+
292
+ >>> class A:
293
+ ... def __contains__(self, item):
294
+ ... return item == 42
295
+ >>> 42 in AnyExpr(A())
296
+ True
297
+ >>> 2 in AnyExpr(A())
298
+ False
299
+ """
300
+ return any_eval(contains(self, item))
301
+
302
+ ##
303
+ # Emulating numeric types
304
+ # https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
305
+ ##
306
+
307
+ @method(preserve=True)
308
+ def __add__(self, other: object) -> AnyExpr:
309
+ """
310
+ >>> int(AnyExpr(1) + 2)
311
+ 3
312
+ """
313
+ return with_assert(self.egglog_any_expr_value + other)
314
+
315
+ @method(preserve=True)
316
+ def __sub__(self, other: object) -> AnyExpr:
317
+ """
318
+ >>> int(AnyExpr(3) - 2)
319
+ 1
320
+ """
321
+ return with_assert(self.egglog_any_expr_value - other)
322
+
323
+ @method(preserve=True)
324
+ def __mul__(self, other: object) -> AnyExpr:
325
+ """
326
+ # >>> int(AnyExpr(3) * 2)
327
+ # 6
328
+ >>> 4 * AnyExpr(3)
329
+ AnyExpr(A(4) * A(3))
330
+ """
331
+ return with_assert(self.egglog_any_expr_value * other)
332
+
333
+ @method(preserve=True)
334
+ def __matmul__(self, other: object) -> AnyExpr:
335
+ """
336
+ >>> class Matrix:
337
+ ... def __matmul__(self, other):
338
+ ... return 42
339
+ >>> int(AnyExpr(Matrix()) @ Matrix())
340
+ 42
341
+ """
342
+ return with_assert(self.egglog_any_expr_value @ other)
343
+
344
+ @method(preserve=True)
345
+ def __truediv__(self, other: object) -> AnyExpr:
346
+ """
347
+ >>> float(AnyExpr(3) / 2)
348
+ 1.5
349
+ """
350
+ return with_assert(self.egglog_any_expr_value / other)
351
+
352
+ @method(preserve=True)
353
+ def __floordiv__(self, other: object) -> AnyExpr:
354
+ """
355
+ >>> int(AnyExpr(3) // 2)
356
+ 1
357
+ """
358
+ return with_assert(self.egglog_any_expr_value // other)
359
+
360
+ @method(preserve=True)
361
+ def __mod__(self, other: object) -> AnyExpr:
362
+ """
363
+ >>> int(AnyExpr(3) % 2)
364
+ 1
365
+ """
366
+ return with_assert(self.egglog_any_expr_value % other)
367
+
368
+ @method(preserve=True)
369
+ def __divmod__(self, other: object) -> AnyExpr:
370
+ """
371
+ >>> div, mod = divmod(AnyExpr(3), 2)
372
+ >>> int(div)
373
+ 1
374
+ >>> int(mod)
375
+ 1
376
+ """
377
+ return with_assert(divmod(self.egglog_any_expr_value, other))
378
+
379
+ # TODO: Support modulo
380
+ @method(preserve=True)
381
+ def __pow__(self, other: object) -> AnyExpr:
382
+ """
383
+ >>> int(AnyExpr(3) ** 2)
384
+ 9
385
+ """
386
+ return with_assert(self.egglog_any_expr_value**other)
387
+
388
+ @method(preserve=True)
389
+ def __lshift__(self, other: object) -> AnyExpr:
390
+ """
391
+ >>> int(AnyExpr(1) << 2)
392
+ 4
393
+ """
394
+ return with_assert(self.egglog_any_expr_value << other)
395
+
396
+ @method(preserve=True)
397
+ def __rshift__(self, other: object) -> AnyExpr:
398
+ """
399
+ >>> int(AnyExpr(4) >> 2)
400
+ 1
401
+ """
402
+ return with_assert(self.egglog_any_expr_value >> other)
403
+
404
+ @method(preserve=True)
405
+ def __and__(self, other: object) -> AnyExpr:
406
+ """
407
+ >>> int(AnyExpr(6) & 3)
408
+ 2
409
+ """
410
+ return with_assert(self.egglog_any_expr_value & other)
411
+
412
+ @method(preserve=True)
413
+ def __xor__(self, other: object) -> AnyExpr:
414
+ """
415
+ >>> int(AnyExpr(6) ^ 3)
416
+ 5
417
+ """
418
+ return with_assert(self.egglog_any_expr_value ^ other)
419
+
420
+ @method(preserve=True)
421
+ def __or__(self, other: object) -> AnyExpr:
422
+ """
423
+ >>> int(AnyExpr(6) | 3)
424
+ 7
425
+ """
426
+ return with_assert(self.egglog_any_expr_value | other)
427
+
428
+ @method(preserve=True)
429
+ def __neg__(self) -> AnyExpr:
430
+ """
431
+ >>> int(-AnyExpr(3))
432
+ -3
433
+ """
434
+ return with_assert(-self.egglog_any_expr_value)
435
+
436
+ @method(preserve=True)
437
+ def __pos__(self) -> AnyExpr:
438
+ """
439
+ >>> int(+AnyExpr(3))
440
+ 3
441
+ """
442
+ return with_assert(+self.egglog_any_expr_value)
443
+
444
+ @method(preserve=True)
445
+ def __abs__(self) -> AnyExpr:
446
+ """
447
+ >>> int(abs(AnyExpr(-3)))
448
+ 3
449
+ """
450
+ return with_assert(abs(self.egglog_any_expr_value))
451
+
452
+ @method(preserve=True)
453
+ def __complex__(self) -> complex:
454
+ """
455
+ >>> complex(AnyExpr(3+4j))
456
+ (3+4j)
457
+ """
458
+ return any_eval(complex_(self))
459
+
460
+ @method(preserve=True)
461
+ def __int__(self) -> int:
462
+ """
463
+ >>> int(AnyExpr(42))
464
+ 42
465
+ """
466
+ return any_eval(int_(self))
467
+
468
+ @method(preserve=True)
469
+ def __float__(self) -> float:
470
+ """
471
+ >>> float(AnyExpr(3.14))
472
+ 3.14
473
+ """
474
+ return any_eval(float_(self))
475
+
476
+ @method(preserve=True)
477
+ def __index__(self) -> int:
478
+ """
479
+ >>> import operator
480
+ >>> operator.index(AnyExpr(42))
481
+ 42
482
+ """
483
+ return any_eval(index(self))
484
+
485
+ # TODO: support ndigits with optional int
486
+ @method(preserve=True)
487
+ def __round__(self) -> AnyExpr:
488
+ """
489
+ >>> int(round(AnyExpr(3.6)))
490
+ 4
491
+ """
492
+ return with_assert(round(self.egglog_any_expr_value))
493
+
494
+ @method(preserve=True)
495
+ def __trunc__(self) -> AnyExpr:
496
+ """
497
+ >>> import math
498
+ >>> int(math.trunc(AnyExpr(3.6)))
499
+ 3
500
+ """
501
+ return with_assert(math.trunc(self.egglog_any_expr_value))
502
+
503
+ @method(preserve=True)
504
+ def __floor__(self) -> AnyExpr:
505
+ """
506
+ >>> import math
507
+ >>> int(math.floor(AnyExpr(3.6)))
508
+ 3
509
+ """
510
+ return with_assert(math.floor(self.egglog_any_expr_value))
511
+
512
+ @method(preserve=True)
513
+ def __ceil__(self) -> AnyExpr:
514
+ """
515
+ >>> import math
516
+ >>> int(math.ceil(AnyExpr(3.4)))
517
+ 4
518
+ """
519
+ return with_assert(math.ceil(self.egglog_any_expr_value))
520
+
521
+ # TODO: https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers
522
+
523
+
524
class A(Expr):
    """
    Symbolic term for a Python value, or for an operation on Python values.

    ``A(obj)`` wraps a Python object held as a ``PyObject``. The operator methods
    below are declaration stubs with no Python body: ``any_expr_ruleset`` gives them
    evaluation rules and ``any_program_ruleset`` gives them source renderings.
    """

    def __init__(self, obj: object) -> None: ...

    __match_args__ = ("egglog_any_expr_value",)

    @method(preserve=True)  # type: ignore[prop-decorator]
    @property
    def egglog_any_expr_value(self) -> object:
        """
        Return the underlying Python object, if this term was constructed with one.

        Long method name so it doesn't conflict with any user-defined properties.

        Raises ``ExprValueError`` if this term is not of the form ``A(PyObject(obj))``.

        >>> A(10).egglog_any_expr_value
        10
        """
        match get_callable_args(self, A):
            case (PyObject(obj),):
                return obj
        raise ExprValueError(self, "A")

    # Rich comparisons: each builds a new symbolic `A` term.
    def __eq__(self, other: ALike) -> A: ...  # type: ignore[override]
    def __ne__(self, other: ALike) -> A: ...  # type: ignore[override]
    def __lt__(self, other: ALike) -> A: ...
    def __le__(self, other: ALike) -> A: ...
    def __gt__(self, other: ALike) -> A: ...
    def __ge__(self, other: ALike) -> A: ...
    # Attribute access and calls. `AnyExpr.__call__` passes a tuple term built with
    # `append` as `args` and a dict term built with `set_kwarg` as `kwargs`.
    def __getattr__(self, name: StringLike) -> A: ...
    def __call__(self, args: ALike = (), kwargs: ALike = {}) -> A: ...  # noqa: B006
    # Item access. `__setitem__`/`__delitem__` return None and mutate the receiver
    # term itself (`AnyExpr.__setitem__` reuses the mutated term afterwards).
    def __getitem__(self, key: ALike) -> A: ...
    def __setitem__(self, key: ALike, value: ALike) -> None: ...
    def __delitem__(self, key: ALike) -> None: ...
    # Numeric operators.
    def __add__(self, other: ALike) -> A: ...
    def __sub__(self, other: ALike) -> A: ...
    def __mul__(self, other: ALike) -> A: ...
    def __matmul__(self, other: ALike) -> A: ...
    def __truediv__(self, other: ALike) -> A: ...
    def __floordiv__(self, other: ALike) -> A: ...
    def __mod__(self, other: ALike) -> A: ...
    def __divmod__(self, other: ALike) -> A: ...
    def __pow__(self, other: ALike) -> A: ...
    def __lshift__(self, other: ALike) -> A: ...
    def __rshift__(self, other: ALike) -> A: ...
    def __and__(self, other: ALike) -> A: ...
    def __xor__(self, other: ALike) -> A: ...
    def __or__(self, other: ALike) -> A: ...
    # Unary operators and rounding.
    def __neg__(self) -> A: ...
    def __pos__(self) -> A: ...
    def __abs__(self) -> A: ...
    def __round__(self) -> A: ...
    def __trunc__(self) -> A: ...
    def __floor__(self) -> A: ...
    def __ceil__(self) -> A: ...
577
+
578
+
579
# Conversions into `AnyExpr`: an `A` term is wrapped directly; any other Python
# object is first wrapped as `A(PyObject(obj))`.
converter(A, AnyExpr, AnyExpr)
converter(object, AnyExpr, lambda x: AnyExpr(A(PyObject(x))))

# Conversions into `A` (all registered with cost=10): unwrap an `AnyExpr`, wrap a
# `PyObject`, or wrap any other Python object in a `PyObject`.
converter(AnyExpr, A, lambda a: a.egglog_any_expr_value, cost=10)
converter(PyObject, A, A, cost=10)
converter(object, A, lambda x: A(PyObject(x)), cost=10)

# Anything accepted where an `A` is expected: an `A` itself, or any Python object
# (converted by the converters above).
ALike: TypeAlias = A | object
587
+
588
+
589
@function()
def given(expr: ALike, condition: ALike) -> A:
    """
    ``expr`` under the assumption ``condition``.

    ``any_expr_ruleset`` rewrites ``given(A(x), c)`` to ``A(x)``, and
    ``any_program_ruleset`` renders it as ``expr`` with an ``assert condition`` statement.
    """
    ...
@function
def bytes_(expr: ALike) -> A:
    """Deferred ``bytes(expr)``."""
    ...
@function
def bool_(expr: ALike) -> A:
    """Deferred ``bool(expr)``."""
    ...
@function
def hasattr_(expr: ALike, name: StringLike) -> A:
    """Deferred ``hasattr(expr, name)``."""
    ...
@function
def getattr_eager(expr: ALike, name: StringLike) -> Unit:
    """
    Set if we should eagerly get the attribute.

    When this fact holds for ``(expr, name)``, ``AnyExpr.__getattr__`` evaluates the
    attribute immediately instead of returning it symbolically. No rule in this module
    produces it — TODO confirm where it is asserted.
    """


@function
def len_(expr: ALike) -> A:
    """Deferred ``len(expr)``."""
    ...
@function
def append(expr: ALike, item: ALike) -> A:
    """
    Appends an item to a tuple: deferred ``(*expr, item)``.
    """


@function
def set_kwarg(expr: ALike, key: StringLike, value: ALike) -> A:
    """
    Sets a value in a dict with a string key: deferred ``{**expr, key: value}``.
    """


@function
def contains(expr: ALike, item: ALike) -> A:
    """Deferred ``operator.contains(expr, item)``, i.e. ``item in expr``."""
    ...
@function
def complex_(expr: ALike) -> A:
    """Deferred ``complex(expr)``."""
    ...
@function
def int_(expr: ALike) -> A:
    """Deferred ``int(expr)``."""
    ...
@function
def float_(expr: ALike) -> A:
    """Deferred ``float(expr)``."""
    ...
@function
def index(expr: ALike) -> A:
    """Deferred ``operator.index(expr)``."""
    ...
@function
def slice_(start: ALike = None, stop: ALike = None, step: ALike = None) -> A:
    """Deferred ``slice(start, stop, step)``."""
    ...
@function
def list_(expr: ALike) -> A:
    """Deferred ``list(expr)``."""
    ...
@function
def not_(expr: ALike) -> A:
    """Logical negation; used for asserts. Rendered as ``not expr``; ``any_expr_ruleset`` has no evaluation rule for it."""
    ...
@function
def and_(left: ALike, right: ALike) -> A:
    """Conjunction; used to combine assert conditions. Rendered as ``left and right``."""
    ...
638
+
639
+
640
+ # Special-case containers so that expressions nested inside them are converted element by element instead of being wrapped as one opaque PyObject
641
# tuple -> chain of `append` calls starting from the empty tuple term.
converter(tuple, A, lambda items: reduce(append, items, A(())))
# list -> `list_` applied to the tuple form (itself converted by the rule above).
converter(list, A, lambda items: list_(tuple(items)))
# slice -> `slice_` over its three (possibly None) components.
converter(slice, A, lambda sl: slice_(sl.start, sl.stop, sl.step))
644
+
645
+
646
@ruleset
def any_expr_ruleset(x: PyObject, y: PyObject, z: PyObject, s: String, a: A):
    """
    Evaluation rules for ``A`` terms over concrete Python objects.

    Each rule matches an operation whose operands are ``A(PyObject)`` terms and rewrites it
    to an ``A(PyObject)`` term that calls the corresponding Python builtin/operator on them.
    """
    # Builtin conversions and rich comparisons.
    yield rewrite(bytes_(A(x))).to(A(PyObject(bytes)(x)))
    yield rewrite(bool_(A(x))).to(A(PyObject(bool)(x)))
    yield rewrite(A(x) == A(y)).to(A(PyObject(operator.eq)(x, y)))
    yield rewrite(A(x) != A(y)).to(A(PyObject(operator.ne)(x, y)))
    yield rewrite(A(x) < A(y)).to(A(PyObject(operator.lt)(x, y)))
    yield rewrite(A(x) <= A(y)).to(A(PyObject(operator.le)(x, y)))
    yield rewrite(A(x) > A(y)).to(A(PyObject(operator.gt)(x, y)))
    yield rewrite(A(x) >= A(y)).to(A(PyObject(operator.ge)(x, y)))
    # Attributes, length and calls.
    yield rewrite(A(x).__getattr__(s)).to(A(PyObject(getattr)(x, PyObject.from_string(s))))
    yield rewrite(hasattr_(A(x), s)).to(A(PyObject(hasattr)(x, PyObject.from_string(s))))
    yield rewrite(len_(A(x))).to(A(PyObject(len)(x)))
    yield rewrite(A(x)(y, z)).to(A(x.call_extended(y, z)))
    # Argument packing helpers used by `AnyExpr.__call__` and the tuple converter.
    yield rewrite(append(A(x), A(y))).to(A(PyObject(lambda t, v: (*t, v))(x, y)))
    yield rewrite(set_kwarg(A(x), s, A(y))).to(A(PyObject(lambda d, k, v: {**d, k: v})(x, PyObject.from_string(s), y)))
    yield rewrite(A(x)[A(y)]).to(A(PyObject(operator.getitem)(x, y)))
    # `setitem_any` is the term `A(x)` after a symbolic `[A(y)] = A(z)`. The lambda
    # performs the assignment on the wrapped object itself (operator.setitem returns
    # None, so `or obj` yields the mutated object).
    setitem_any = A(x)
    setitem_any[A(y)] = A(z)
    yield rewrite(setitem_any).to(A(PyObject(lambda obj, k, v: operator.setitem(obj, k, v) or obj)(x, y, z)))
    # Same pattern for deletion.
    delitem_any = A(x)
    del delitem_any[A(y)]
    yield rewrite(delitem_any).to(A(PyObject(lambda obj, k: operator.delitem(obj, k) or obj)(x, y)))
    yield rewrite(contains(A(x), A(y))).to(A(PyObject(operator.contains)(x, y)))
    # Binary numeric operators.
    yield rewrite(A(x) + A(y)).to(A(PyObject(operator.add)(x, y)))
    yield rewrite(A(x) - A(y)).to(A(PyObject(operator.sub)(x, y)))
    yield rewrite(A(x) * A(y)).to(A(PyObject(operator.mul)(x, y)))
    yield rewrite(A(x) @ A(y)).to(A(PyObject(operator.matmul)(x, y)))
    yield rewrite(A(x) / A(y)).to(A(PyObject(operator.truediv)(x, y)))
    yield rewrite(A(x) // A(y)).to(A(PyObject(operator.floordiv)(x, y)))
    yield rewrite(A(x) % A(y)).to(A(PyObject(operator.mod)(x, y)))
    yield rewrite(divmod(A(x), A(y))).to(A(PyObject(divmod)(x, y)))
    yield rewrite(A(x) ** A(y)).to(A(PyObject(operator.pow)(x, y)))
    yield rewrite(A(x) << A(y)).to(A(PyObject(operator.lshift)(x, y)))
    yield rewrite(A(x) >> A(y)).to(A(PyObject(operator.rshift)(x, y)))
    yield rewrite(A(x) & A(y)).to(A(PyObject(operator.and_)(x, y)))
    yield rewrite(A(x) ^ A(y)).to(A(PyObject(operator.xor)(x, y)))
    yield rewrite(A(x) | A(y)).to(A(PyObject(operator.or_)(x, y)))
    # Unary operators, numeric conversions and rounding.
    yield rewrite(-A(x)).to(A(PyObject(operator.neg)(x)))
    yield rewrite(+A(x)).to(A(PyObject(operator.pos)(x)))
    yield rewrite(abs(A(x))).to(A(PyObject(operator.abs)(x)))
    yield rewrite(complex_(A(x))).to(A(PyObject(complex)(x)))
    yield rewrite(int_(A(x))).to(A(PyObject(int)(x)))
    yield rewrite(float_(A(x))).to(A(PyObject(float)(x)))
    yield rewrite(index(A(x))).to(A(PyObject(operator.index)(x)))
    yield rewrite(round(A(x))).to(A(PyObject(round)(x)))
    yield rewrite(math.trunc(A(x))).to(A(PyObject(math.trunc)(x)))
    yield rewrite(math.floor(A(x))).to(A(PyObject(math.floor)(x)))
    yield rewrite(math.ceil(A(x))).to(A(PyObject(math.ceil)(x)))
    # Container constructors produced by the container converters.
    yield rewrite(list_(A(x))).to(A(PyObject(list)(x)))
    yield rewrite(slice_(A(x), A(y), A(z))).to(A(PyObject(slice)(x, y, z)))

    # Given: once the value is a concrete Python object, the condition is dropped.
    yield rewrite(given(A(x), a)).to(A(x))
700
+
701
+
702
# Schedule used by `any_eval` and `AnyExpr.__getattr__`: run `any_expr_ruleset` to saturation.
any_expr_schedule = any_expr_ruleset.saturate()
703
+
704
+
705
+ def any_eval(self: A) -> Any:
706
+ """
707
+ Evaluate the AnyExpr to get its underlying Python value.
708
+
709
+ Runs rules if it's not already resolved
710
+ """
711
+ global _LAST_ASSERT
712
+ egraph = _get_current_egraph()
713
+ # 1. First see if it's already a primitive value
714
+ try:
715
+ return self.egglog_any_expr_value
716
+ except ExprValueError:
717
+ pass
718
+ # 2. If not, try to extract it from the egraph
719
+ expr = egraph.extract(self)
720
+ try:
721
+ res = expr.egglog_any_expr_value
722
+ except ExprValueError:
723
+ # 3. If that isn't one, then try running the schedule to extract it
724
+ egraph.register(expr)
725
+ egraph.run(any_expr_schedule)
726
+ expr = egraph.extract(expr)
727
+ res = expr.egglog_any_expr_value
728
+ # Don't save hasattr asserts
729
+ if get_callable_fn(self) != hasattr_:
730
+ # If we are calling bool_ same as just asserting vlaues
731
+ match get_callable_args(self, bool_):
732
+ case (A() as inner,):
733
+ self = inner
734
+ if eq(expr).to(A(True)):
735
+ asserted = self
736
+ _LAST_ASSERT = with_assert(self).egglog_any_expr_value
737
+ elif eq(expr).to(A(False)):
738
+ match get_callable_args(self, A.__eq__):
739
+ case (A() as left, A() as right):
740
+ asserted = left != right
741
+ case _:
742
+ match get_callable_args(self, A.__ne__):
743
+ case (A() as left, A() as right):
744
+ asserted = left == right
745
+ case _:
746
+ asserted = not_(self)
747
+ else:
748
+ asserted = self == expr
749
+ # _LAST_ASSERT = (
750
+ # asserted if _LAST_ASSERT is None or eq(_LAST_ASSERT).to(asserted) else and_(_LAST_ASSERT, asserted)
751
+ # )
752
+ _LAST_ASSERT = given(asserted, _LAST_ASSERT) if _LAST_ASSERT is not None else asserted
753
+ return res
754
+
755
+
756
# Egraph installed by `set_any_expr_egraph`; None when no context is active.
_CURRENT_EGRAPH: None | EGraph = None
# Facts observed by `any_eval`, chained together with `given`. `with_assert` attaches
# them to new expressions while an egraph is active; `set_any_expr_egraph` resets it on exit.
_LAST_ASSERT: None | A = None
758
+
759
+
760
@contextlib.contextmanager
def set_any_expr_egraph(egraph: EGraph) -> Iterator[None]:
    """
    Install ``egraph`` as the current egraph for the duration of the ``with`` block.

    Contexts must not be nested: on entry there may be neither a current egraph nor
    any recorded asserts. On exit, normal or exceptional, both are reset to ``None``.
    """
    global _CURRENT_EGRAPH, _LAST_ASSERT
    # Refuse nesting and refuse to start with leftover asserts.
    assert _CURRENT_EGRAPH is None and _LAST_ASSERT is None
    _CURRENT_EGRAPH = egraph
    try:
        yield
    finally:
        _LAST_ASSERT = None
        _CURRENT_EGRAPH = None
774
+
775
+
776
def _get_current_egraph() -> EGraph:
    """
    Return the egraph installed by ``set_any_expr_egraph``, or a new ``EGraph()`` if none is active.

    Outside a context a new egraph is created on every call, so nothing is shared
    between such calls.
    """
    # Compare with None explicitly: an installed egraph must be used even if it
    # happens to be falsy (e.g. if EGraph defines __len__/__bool__).
    return _CURRENT_EGRAPH if _CURRENT_EGRAPH is not None else EGraph()
778
+
779
+
780
def with_assert(expr: A) -> AnyExpr:
    """
    Wrap ``expr`` in an ``AnyExpr``, attaching all asserts recorded so far.

    Inside ``set_any_expr_egraph`` with at least one recorded assert, the result is
    ``AnyExpr(given(expr, _LAST_ASSERT))``, which keeps new expressions consistent with
    values observed by earlier ``any_eval`` calls. Otherwise it is ``AnyExpr(expr)``.
    """
    # Explicit None check, consistent with `_LAST_ASSERT`: an active egraph must count
    # even if it happens to be falsy.
    if _CURRENT_EGRAPH is not None and _LAST_ASSERT is not None:
        return AnyExpr(given(expr, _LAST_ASSERT))
    return AnyExpr(expr)
797
+
798
+
799
@ruleset
def given_ruleset(x: A, y: A, z: A):
    """
    Simplifications for ``given`` conditions; every rule subsumes the matched term.

    NOTE(review): this ruleset is not part of ``any_expr_schedule`` or
    ``any_program_schedule`` in this module — confirm where, or whether, it is run.
    """
    # Push negation inside a conditioned value.
    yield rewrite(not_(given(x, y)), subsume=True).to(given(not_(x), y))
    # Flatten nested conditions into a conjunction.
    yield rewrite(given(given(x, y), z), subsume=True).to(given(x, and_(y, z)))
    # Drop duplicated conjuncts.
    yield rewrite(and_(x, x), subsume=True).to(x)
804
+
805
+
806
@function
def any_expr_program(x: AnyExpr) -> Program:
    r"""
    Convert an AnyExpr to a Program.

    ``any_program_ruleset`` rewrites ``any_expr_program(AnyExpr(a))`` to ``a_program(a)``.

    >>> any_expr_source(AnyExpr(42) == 10)
    '(42 == 10)'
    """
814
+
815
+
816
@function
def a_program(x: A) -> Program:
    """
    Convert an A to a Program.

    Has no body of its own; its meaning is given by the rewrites in ``any_program_ruleset``.
    """
821
+
822
+
823
def w(p: Program) -> Program:
    """Return ``p`` surrounded by parentheses, e.g. ``x + y`` becomes ``(x + y)``."""
    opened = Program("(") + p
    return opened + ")"
825
+
826
+
827
def ca(p: ProgramLike, *args: ProgramLike) -> Program:
    """
    Render a call expression ``p(arg0, arg1, ...)``.

    ``p`` and each argument may be a ``Program`` or anything convertible to one.
    With no arguments this renders ``p()`` (previously it raised ``IndexError``).
    """
    args_expr = Program("")
    if args:
        # Separate arguments with ", " and add no separator after the last one.
        # Converting first keeps `+ ", "` a Program concatenation regardless of the arg's type.
        for a in args[:-1]:
            args_expr += convert(a, Program) + ", "
        args_expr += args[-1]
    return convert(p, Program) + Program("(") + args_expr + Program(")")
833
+
834
+
835
# Values of these types have their `repr` emitted inline by `any_program_ruleset`;
# values of any other type have their `repr` `.assign()`-ed instead.
INLINE_TYPES = (
    int,
    str,
    float,
    bytes,
    bool,
    type(None),
    tuple,
    dict,
)
836
+
837
+
838
@ruleset
def any_program_ruleset(a: A, b: A, c: A, p: PyObject, s: String):
    """
    Lower ``A`` terms to Python source: rewrite ``a_program(term)`` into a ``Program``.

    Every rule subsumes the matched term. Binary and unary operators are parenthesized
    with ``w`` so the rendered source keeps the term's evaluation order; builtin calls
    are rendered with ``ca``.
    """
    yield rewrite(any_expr_program(AnyExpr(a)), subsume=True).to(a_program(a))

    # Python values: `repr` inline for INLINE_TYPES, otherwise the repr is `.assign()`-ed.
    # NOTE(review): negative numeric literals are inlined unparenthesized, so e.g.
    # `A(-2) ** A(2)` renders as `(-2 ** 2)`, which Python evaluates as -(2 ** 2).
    yield rewrite(a_program(A(p)), subsume=True).to(
        Program(PyObject(repr)(p).to_string()),
        PyObject(lambda x: isinstance(x, INLINE_TYPES))(p).to_bool() == Bool(True),
    )
    yield rewrite(a_program(A(p)), subsume=True).to(
        Program(PyObject(repr)(p).to_string()).assign(),
        PyObject(lambda x: isinstance(x, INLINE_TYPES))(p).to_bool() == Bool(False),
    )
    # `bytes_(a)` evaluates `bytes(a)` (see any_expr_ruleset); objects have no `.bytes()` method.
    yield rewrite(a_program(bytes_(a)), subsume=True).to(ca("bytes", a_program(a)))
    yield rewrite(a_program(bool_(a)), subsume=True).to(ca("bool", a_program(a)))
    yield rewrite(a_program(a == b), subsume=True).to(w(a_program(a) + " == " + a_program(b)))
    yield rewrite(a_program(a != b), subsume=True).to(w(a_program(a) + " != " + a_program(b)))
    yield rewrite(a_program(a < b), subsume=True).to(w(a_program(a) + " < " + a_program(b)))
    yield rewrite(a_program(a <= b), subsume=True).to(w(a_program(a) + " <= " + a_program(b)))
    yield rewrite(a_program(a > b), subsume=True).to(w(a_program(a) + " > " + a_program(b)))
    yield rewrite(a_program(a >= b), subsume=True).to(w(a_program(a) + " >= " + a_program(b)))
    yield rewrite(a_program(a.__getattr__(s)), subsume=True).to(a_program(a) + "." + s)
    yield rewrite(a_program(hasattr_(a, s)), subsume=True).to(
        ca("hasattr", a_program(a), PyObject(repr)(PyObject.from_string(s)).to_string())
    )
    yield rewrite(a_program(len_(a)), subsume=True).to(ca("len", a_program(a)))
    yield rewrite(a_program(a(b, c)), subsume=True).to(
        ca(a_program(a), "*" + a_program(b), "**" + a_program(c)).assign()
    )
    yield rewrite(a_program(append(a, b)), subsume=True).to(ca("", "*" + a_program(a), a_program(b)))
    yield rewrite(a_program(set_kwarg(a, s, b)), subsume=True).to(
        "{**" + a_program(a) + ", " + PyObject(repr)(PyObject.from_string(s)).to_string() + ": " + a_program(b) + "}"
    )
    yield rewrite(a_program(a[b]), subsume=True).to(a_program(a) + "[" + a_program(b) + "]")
    # Item assignment/deletion mutate the receiver: its program is `.assign()`-ed and the
    # mutation is emitted as a statement on it.
    assigned_a = a_program(a).assign()
    setitem_a = copy(a)
    setitem_a[b] = c
    yield rewrite(a_program(setitem_a), subsume=True).to(
        assigned_a.statement(assigned_a + "[" + a_program(b) + "] = " + a_program(c))
    )
    delitem_a = copy(a)
    del delitem_a[b]
    yield rewrite(a_program(delitem_a), subsume=True).to(
        assigned_a.statement("del " + assigned_a + "[" + a_program(b) + "]")
    )
    # `contains(a, b)` evaluates `operator.contains(a, b)`, which is `b in a`.
    yield rewrite(a_program(contains(a, b)), subsume=True).to(w(a_program(b) + " in " + a_program(a)))
    yield rewrite(a_program(a + b), subsume=True).to(w(a_program(a) + " + " + a_program(b)))
    yield rewrite(a_program(a - b), subsume=True).to(w(a_program(a) + " - " + a_program(b)))
    yield rewrite(a_program(a * b), subsume=True).to(w(a_program(a) + " * " + a_program(b)))
    yield rewrite(a_program(a @ b), subsume=True).to(w(a_program(a) + " @ " + a_program(b)))
    yield rewrite(a_program(a / b), subsume=True).to(w(a_program(a) + " / " + a_program(b)))
    yield rewrite(a_program(a // b), subsume=True).to(w(a_program(a) + " // " + a_program(b)))
    yield rewrite(a_program(a % b), subsume=True).to(w(a_program(a) + " % " + a_program(b)))
    yield rewrite(a_program(divmod(a, b)), subsume=True).to(ca("divmod", a_program(a), a_program(b)))
    yield rewrite(a_program(a**b), subsume=True).to(w(a_program(a) + " ** " + a_program(b)))
    yield rewrite(a_program(a << b), subsume=True).to(w(a_program(a) + " << " + a_program(b)))
    yield rewrite(a_program(a >> b), subsume=True).to(w(a_program(a) + " >> " + a_program(b)))
    yield rewrite(a_program(a & b), subsume=True).to(w(a_program(a) + " & " + a_program(b)))
    yield rewrite(a_program(a ^ b), subsume=True).to(w(a_program(a) + " ^ " + a_program(b)))
    yield rewrite(a_program(a | b), subsume=True).to(w(a_program(a) + " | " + a_program(b)))
    # Unary operators are parenthesized like the binary ones, so e.g. `(-a) ** b` is not
    # rendered as `-a ** b`, which Python parses as `-(a ** b)`.
    yield rewrite(a_program(-a), subsume=True).to(w("-" + a_program(a)))
    yield rewrite(a_program(+a), subsume=True).to(w("+" + a_program(a)))
    yield rewrite(a_program(abs(a)), subsume=True).to(ca("abs", a_program(a)))
    yield rewrite(a_program(complex_(a)), subsume=True).to(ca("complex", a_program(a)))
    yield rewrite(a_program(int_(a)), subsume=True).to(ca("int", a_program(a)))
    yield rewrite(a_program(float_(a)), subsume=True).to(ca("float", a_program(a)))
    yield rewrite(a_program(index(a)), subsume=True).to(ca("operator.index", a_program(a)))
    yield rewrite(a_program(round(a)), subsume=True).to(ca("round", a_program(a)))
    yield rewrite(a_program(math.trunc(a)), subsume=True).to(ca("math.trunc", a_program(a)))
    yield rewrite(a_program(math.floor(a)), subsume=True).to(ca("math.floor", a_program(a)))
    yield rewrite(a_program(math.ceil(a)), subsume=True).to(ca("math.ceil", a_program(a)))
    yield rewrite(a_program(list_(a)), subsume=True).to(ca("list", a_program(a)))
    yield rewrite(a_program(slice_(a, b, c)), subsume=True).to(ca("slice", a_program(a), a_program(b), a_program(c)))

    yield rewrite(a_program(not_(a)), subsume=True).to(w("not " + a_program(a)))
    yield rewrite(a_program(and_(a, b)), subsume=True).to(w(a_program(a) + " and " + a_program(b)))
    # Given: emit the condition as an `assert` statement attached to the value's program.
    yield rewrite(a_program(given(a, b)), subsume=True).to(a_program(a).statement("assert " + a_program(b)))
915
+
916
+
917
# Lowering rules to saturation, followed by program_gen's rules to saturation.
# NOTE(review): `any_expr_source` runs these two steps itself, in separate egraphs, and
# does not use this schedule.
any_program_schedule = any_program_ruleset.saturate() + program_gen_ruleset.saturate()
918
+
919
+
920
def any_expr_source(x: AnyExpr) -> str:
    """
    Render ``x`` as Python source text.

    Works in two passes, each in its own fresh egraph:

    1. ``a_program`` of the wrapped ``A`` term is saturated with ``any_program_ruleset``
       and the resulting ``Program`` is extracted.
    2. That program's ``compile()`` term is saturated with ``program_gen_ruleset``, and the
       string joining its statements and expression is extracted.
    """
    term = x.egglog_any_expr_value
    program = a_program(term)

    # Pass 1: lower the A term to a Program.
    lowering = EGraph()
    lowering.register(program)
    lowering.run(any_program_ruleset.saturate())
    lowered = lowering.extract(program)

    # Pass 2: generate source text from the lowered Program.
    codegen = EGraph()
    codegen.register(lowered.compile())
    codegen.run(program_gen_ruleset.saturate())
    source = join(lowered.statements, lowered.expr)
    return codegen.extract(source).value
942
+
943
+
944
if __name__ == "__main__":
    # Ad-hoc demo. Kept out of module import so that importing this module neither
    # prints anything nor defines a module-level `x` (which `import *` would re-export).
    x = AnyExpr([42])
    print(x[0] + 10)