egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +887 -0
  4. egglog/builtins.py +1144 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +290 -0
  7. egglog/declarations.py +964 -0
  8. egglog/deconstruct.py +176 -0
  9. egglog/egraph.py +2247 -0
  10. egglog/egraph_state.py +978 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +64 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/MoA.ipynb +617 -0
  26. egglog/exp/__init__.py +3 -0
  27. egglog/exp/any_expr.py +947 -0
  28. egglog/exp/any_expr_example.ipynb +408 -0
  29. egglog/exp/array_api.py +2019 -0
  30. egglog/exp/array_api_jit.py +51 -0
  31. egglog/exp/array_api_loopnest.py +74 -0
  32. egglog/exp/array_api_numba.py +69 -0
  33. egglog/exp/array_api_program_gen.py +510 -0
  34. egglog/exp/program_gen.py +427 -0
  35. egglog/exp/siu_examples.py +32 -0
  36. egglog/ipython_magic.py +41 -0
  37. egglog/pretty.py +566 -0
  38. egglog/py.typed +0 -0
  39. egglog/runtime.py +888 -0
  40. egglog/thunk.py +97 -0
  41. egglog/type_constraint_solver.py +111 -0
  42. egglog/visualizer.css +1 -0
  43. egglog/visualizer.js +35798 -0
  44. egglog/visualizer_widget.py +39 -0
  45. egglog-12.0.0.dist-info/METADATA +93 -0
  46. egglog-12.0.0.dist-info/RECORD +48 -0
  47. egglog-12.0.0.dist-info/WHEEL +5 -0
  48. egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
egglog/exp/MoA.ipynb ADDED
@@ -0,0 +1,617 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "922a695b",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Mathematics of Arrays in Egglog\n",
9
+ "\n",
10
+ "\n",
11
+ "This notebook shows that if we define array operations as higher-order functions, we can compose them and end up with a simpler algebra that uses only booleans, integers, and functions.\n",
12
+ "\n",
13
+ "We take as our input this MoA program, defined in [the PSI compiler](https://saulshanabrook.github.io/psi-compiler/src/):\n",
14
+ "\n",
15
+ "\n",
16
+ "```\n",
17
+ "main ()\n",
18
+ "\n",
19
+ "{\n",
20
+ " array Amts^3 <2 3 4>;\n",
21
+ " array Ams^3 <2 3 4>;\n",
22
+ " const array RAMY^3 <2 3 4>=<1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 \n",
23
+ "\t\t\t\t11 12>;\n",
24
+ " const array AMY^3 <2 3 4>=<9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9>;\n",
25
+ " Amts=<2> take (<2> drop (RAMY cat AMY));\n",
26
+ "}\n",
27
+ "```\n",
28
+ "\n",
29
+ "This result `Amts` is equivalent to `AMY`, since we are concatenating `RAMY` and `AMY` along the first axis, dropping the first 2 elements (which removes all of `RAMY`), and then taking the next 2 elements (which is all of `AMY`).\n",
30
+ "\n",
31
+ "Compiling it produces this C program, which copies `AMY` into `Amts`:\n",
32
+ "\n",
33
+ "```c\n",
34
+ "#include <stdlib.h>\n",
35
+ "#include \"moalib.e\"\n",
36
+ "\n",
37
+ "main()\n",
38
+ "\n",
39
+ "{\n",
40
+ " double *offset0;\n",
41
+ " int i0;\n",
42
+ " int i1;\n",
43
+ " int i2;\n",
44
+ " double *shift;\n",
45
+ " double _RAMY[]={1.000000, 2.000000, 3.000000, 4.000000, 5.000000,\n",
46
+ " 6.000000, 7.000000, 8.000000, 9.000000, 10.000000,\n",
47
+ " 11.000000, 12.000000, 1.000000, 2.000000, 3.000000,\n",
48
+ " 4.000000, 5.000000, 6.000000, 7.000000, 8.000000,\n",
49
+ " 9.000000, 10.000000, 11.000000, 12.000000};\n",
50
+ " double _AMY[]={9.000000, 9.000000, 9.000000, 9.000000, 9.000000,\n",
51
+ " 9.000000, 9.000000, 9.000000, 9.000000, 9.000000,\n",
52
+ " 9.000000, 9.000000, 9.000000, 9.000000, 9.000000,\n",
53
+ " 9.000000, 9.000000, 9.000000, 9.000000, 9.000000,\n",
54
+ " 9.000000, 9.000000, 9.000000, 9.000000};\n",
55
+ " double _Y[]={8.000000, 8.000000, 8.000000, 8.000000, 8.000000,\n",
56
+ " 8.000000, 8.000000, 8.000000, 8.000000, 8.000000,\n",
57
+ " 8.000000, 8.000000, 8.000000, 8.000000, 8.000000,\n",
58
+ " 8.000000, 8.000000, 8.000000, 8.000000, 8.000000,\n",
59
+ " 8.000000, 8.000000, 8.000000, 8.000000};\n",
60
+ " double _V[]={1.000000, 1.000000};\n",
61
+ " double _Amts[2*3*4];\n",
62
+ "\n",
63
+ "/*******\n",
64
+ "Amts=<2.000000> take (<2.000000> drop (RAMY cat AMY))\n",
65
+ "********/\n",
66
+ "\n",
67
+ " shift=_Amts+0*12+0*4+0;\n",
68
+ " offset0=_AMY+0*12+0*4+0;\n",
69
+ " for (i0=0; i0<2; i0++) {\n",
70
+ " for (i1=0; i1<3; i1++) {\n",
71
+ " for (i2=0; i2<4; i2++) {\n",
72
+ " *(shift)= *(offset0);\n",
73
+ " offset0+=1;\n",
74
+ " shift+=1;\n",
75
+ " }\n",
76
+ " }\n",
77
+ " }\n",
78
+ "```\n",
79
+ "\n",
80
+ "What we want to show here is not the full compilation down to C and loops, but just the fact that by defining each array operation as a higher-order function, we can compose them and end up with a simpler algebra that uses only booleans, integers, and functions. This algebra could then be compiled into loops. The hypothesis here is that we don't *lose* any information by erasing the `take`, `drop`, and `cat` operations and replacing them with their definitions in terms of functions.\n"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 27,
86
+ "id": "1b715c58",
87
+ "metadata": {},
88
+ "outputs": [
89
+ {
90
+ "data": {
91
+ "text/html": [
92
+ "<style>pre { line-height: 125%; }\n",
93
+ "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
94
+ "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
95
+ "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
96
+ "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
97
+ ".output_html .hll { background-color: #ffffcc }\n",
98
+ ".output_html { background: #f8f8f8; }\n",
99
+ ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n",
100
+ ".output_html .err { border: 1px solid #F00 } /* Error */\n",
101
+ ".output_html .k { color: #008000; font-weight: bold } /* Keyword */\n",
102
+ ".output_html .o { color: #666 } /* Operator */\n",
103
+ ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n",
104
+ ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n",
105
+ ".output_html .cp { color: #9C6500 } /* Comment.Preproc */\n",
106
+ ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n",
107
+ ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n",
108
+ ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n",
109
+ ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n",
110
+ ".output_html .ge { font-style: italic } /* Generic.Emph */\n",
111
+ ".output_html .ges { font-weight: bold; font-style: italic } /* Generic.EmphStrong */\n",
112
+ ".output_html .gr { color: #E40000 } /* Generic.Error */\n",
113
+ ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n",
114
+ ".output_html .gi { color: #008400 } /* Generic.Inserted */\n",
115
+ ".output_html .go { color: #717171 } /* Generic.Output */\n",
116
+ ".output_html .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n",
117
+ ".output_html .gs { font-weight: bold } /* Generic.Strong */\n",
118
+ ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n",
119
+ ".output_html .gt { color: #04D } /* Generic.Traceback */\n",
120
+ ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n",
121
+ ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n",
122
+ ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n",
123
+ ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n",
124
+ ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n",
125
+ ".output_html .kt { color: #B00040 } /* Keyword.Type */\n",
126
+ ".output_html .m { color: #666 } /* Literal.Number */\n",
127
+ ".output_html .s { color: #BA2121 } /* Literal.String */\n",
128
+ ".output_html .na { color: #687822 } /* Name.Attribute */\n",
129
+ ".output_html .nb { color: #008000 } /* Name.Builtin */\n",
130
+ ".output_html .nc { color: #00F; font-weight: bold } /* Name.Class */\n",
131
+ ".output_html .no { color: #800 } /* Name.Constant */\n",
132
+ ".output_html .nd { color: #A2F } /* Name.Decorator */\n",
133
+ ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n",
134
+ ".output_html .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */\n",
135
+ ".output_html .nf { color: #00F } /* Name.Function */\n",
136
+ ".output_html .nl { color: #767600 } /* Name.Label */\n",
137
+ ".output_html .nn { color: #00F; font-weight: bold } /* Name.Namespace */\n",
138
+ ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n",
139
+ ".output_html .nv { color: #19177C } /* Name.Variable */\n",
140
+ ".output_html .ow { color: #A2F; font-weight: bold } /* Operator.Word */\n",
141
+ ".output_html .w { color: #BBB } /* Text.Whitespace */\n",
142
+ ".output_html .mb { color: #666 } /* Literal.Number.Bin */\n",
143
+ ".output_html .mf { color: #666 } /* Literal.Number.Float */\n",
144
+ ".output_html .mh { color: #666 } /* Literal.Number.Hex */\n",
145
+ ".output_html .mi { color: #666 } /* Literal.Number.Integer */\n",
146
+ ".output_html .mo { color: #666 } /* Literal.Number.Oct */\n",
147
+ ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n",
148
+ ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n",
149
+ ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n",
150
+ ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n",
151
+ ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n",
152
+ ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n",
153
+ ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n",
154
+ ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n",
155
+ ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n",
156
+ ".output_html .sx { color: #008000 } /* Literal.String.Other */\n",
157
+ ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n",
158
+ ".output_html .s1 { color: #BA2121 } /* Literal.String.Single */\n",
159
+ ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n",
160
+ ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n",
161
+ ".output_html .fm { color: #00F } /* Name.Function.Magic */\n",
162
+ ".output_html .vc { color: #19177C } /* Name.Variable.Class */\n",
163
+ ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n",
164
+ ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n",
165
+ ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n",
166
+ ".output_html .il { color: #666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"n\">take</span><span class=\"p\">(</span><span class=\"n\">Int</span><span class=\"p\">(</span><span class=\"mi\">2</span><span class=\"p\">),</span> <span class=\"n\">drop</span><span class=\"p\">(</span><span class=\"n\">Int</span><span class=\"p\">(</span><span class=\"mi\">2</span><span class=\"p\">),</span> <span class=\"n\">cat</span><span class=\"p\">(</span><span class=\"n\">NDArray</span><span class=\"o\">.</span><span class=\"n\">from_memory</span><span class=\"p\">(</span><span class=\"n\">TupleInt</span><span class=\"o\">.</span><span class=\"n\">from_vec</span><span class=\"p\">(</span><span class=\"n\">Vec</span><span class=\"p\">(</span><span class=\"n\">Int</span><span class=\"p\">(</span><span class=\"mi\">2</span><span class=\"p\">),</span> <span class=\"n\">Int</span><span class=\"p\">(</span><span class=\"mi\">3</span><span class=\"p\">),</span> <span class=\"n\">Int</span><span class=\"p\">(</span><span class=\"mi\">4</span><span class=\"p\">))),</span> <span class=\"n\">RAMY</span><span class=\"p\">),</span> <span class=\"n\">NDArray</span><span class=\"o\">.</span><span class=\"n\">from_memory</span><span class=\"p\">(</span><span class=\"n\">TupleInt</span><span class=\"o\">.</span><span class=\"n\">from_vec</span><span class=\"p\">(</span><span class=\"n\">Vec</span><span class=\"p\">(</span><span class=\"n\">Int</span><span class=\"p\">(</span><span class=\"mi\">2</span><span class=\"p\">),</span> <span class=\"n\">Int</span><span class=\"p\">(</span><span class=\"mi\">3</span><span class=\"p\">),</span> <span class=\"n\">Int</span><span class=\"p\">(</span><span class=\"mi\">4</span><span class=\"p\">))),</span> <span class=\"n\">AMY</span><span class=\"p\">))))</span>\n",
167
+ "</pre></div>\n"
168
+ ],
169
+ "text/latex": [
170
+ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
171
+ "\\PY{n}{take}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{drop}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{cat}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{from\\PYZus{}memory}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{3}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{RAMY}\\PY{p}{)}\\PY{p}{,} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{from\\PYZus{}memory}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{3}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{AMY}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
172
+ "\\end{Verbatim}\n"
173
+ ],
174
+ "text/plain": [
175
+ "take(Int(2), drop(Int(2), cat(NDArray.from_memory(TupleInt.from_vec(Vec(Int(2), Int(3), Int(4))), RAMY), NDArray.from_memory(TupleInt.from_vec(Vec(Int(2), Int(3), Int(4))), AMY))))"
176
+ ]
177
+ },
178
+ "metadata": {},
179
+ "output_type": "display_data"
180
+ }
181
+ ],
182
+ "source": [
183
+ "from __future__ import annotations\n",
184
+ "\n",
185
+ "from collections.abc import Callable\n",
186
+ "\n",
187
+ "from egglog import *\n",
188
+ "\n",
189
+ "array_ruleset = ruleset(name=\"array_ruleset\")\n",
190
+ "\n",
191
+ "\n",
192
+ "class Boolean(Expr):\n",
193
+ " def __init__(self, val: BoolLike) -> None: ...\n",
194
+ " def if_bool(self, then: Int, else_: Int) -> Int: ...\n",
195
+ "\n",
196
+ "\n",
197
+ "class Int(Expr):\n",
198
+ " def __init__(self, val: i64Like) -> None: ...\n",
199
+ " def __eq__(self, other: Int) -> Boolean: ... # type: ignore[override]\n",
200
+ " def __lt__(self, other: Int) -> Boolean: ...\n",
201
+ " def __add__(self, other: Int) -> Int: ...\n",
202
+ " def __sub__(self, other: Int) -> Int: ...\n",
203
+ " def __mul__(self, other: Int) -> Int: ...\n",
204
+ "\n",
205
+ "\n",
206
+ "@array_ruleset.register\n",
207
+ "def _int(i: i64, j: i64, x: Int, y: Int):\n",
208
+ " yield rewrite(Int(i) + Int(j)).to(Int(i + j))\n",
209
+ " yield rewrite(Int(i) - Int(j)).to(Int(i - j))\n",
210
+ " yield rewrite(Int(i) * Int(j)).to(Int(i * j))\n",
211
+ " yield rewrite(Int(i) == Int(i)).to(Boolean(True))\n",
212
+ " yield rewrite(Int(i) == Int(j)).to(Boolean(False), i != j)\n",
213
+ " yield rewrite(Int(i) < Int(j)).to(Boolean(True), i < j)\n",
214
+ " yield rewrite(Int(i) < Int(j)).to(Boolean(False), i >= j)\n",
215
+ " yield rewrite(Boolean(True).if_bool(x, y)).to(x)\n",
216
+ " yield rewrite(Boolean(False).if_bool(x, y)).to(y)\n",
217
+ "\n",
218
+ "\n",
219
+ "@function\n",
220
+ "def vec_index(vec: Vec[Int], index: Int) -> Int: ...\n",
221
+ "\n",
222
+ "\n",
223
+ "@array_ruleset.register\n",
224
+ "def _vec_index(i: i64, xs: Vec[Int]):\n",
225
+ " yield rewrite(vec_index(xs, Int(i))).to(xs[i])\n",
226
+ "\n",
227
+ "\n",
228
+ "class TupleInt(Expr, ruleset=array_ruleset):\n",
229
+ " def __init__(self, length: Int, getitem_fn: Callable[[Int], Int]) -> None: ...\n",
230
+ " def __getitem__(self, index: Int) -> Int: ...\n",
231
+ "\n",
232
+ " @property\n",
233
+ " def length(self) -> Int: ...\n",
234
+ "\n",
235
+ " @classmethod\n",
236
+ " def from_vec(cls, xs: Vec[Int]) -> TupleInt:\n",
237
+ " return TupleInt(\n",
238
+ " Int(xs.length()),\n",
239
+ " lambda i: vec_index(xs, i),\n",
240
+ " )\n",
241
+ "\n",
242
+ "\n",
243
+ "@array_ruleset.register\n",
244
+ "def _tuple_int(l: Int, fn: Callable[[Int], Int], i: Int):\n",
245
+ " ti = TupleInt(l, fn)\n",
246
+ " yield rewrite(ti.length).to(l)\n",
247
+ " yield rewrite(ti[i]).to(fn(i))\n",
248
+ "\n",
249
+ "\n",
250
+ "class NDArray(Expr, ruleset=array_ruleset):\n",
251
+ " def __init__(self, shape: TupleInt, idx_fn: Callable[[TupleInt], Int]) -> None: ...\n",
252
+ "\n",
253
+ " @classmethod\n",
254
+ " def from_memory(cls, shape: TupleInt, values: TupleInt) -> NDArray:\n",
255
+ " # Only work on ndim = 3 for now\n",
256
+ " return NDArray(\n",
257
+ " shape,\n",
258
+ " lambda idx: values[\n",
259
+ " idx[Int(0)] * (shape[Int(1)] * shape[Int(2)]) + idx[Int(1)] * shape[Int(2)] + idx[Int(2)]\n",
260
+ " ],\n",
261
+ " )\n",
262
+ "\n",
263
+ " @property\n",
264
+ " def shape(self) -> TupleInt: ...\n",
265
+ "\n",
266
+ " def __getitem__(self, index: TupleInt) -> Int: ...\n",
267
+ "\n",
268
+ "\n",
269
+ "@array_ruleset.register\n",
270
+ "def _ndarray(shape: TupleInt, fn: Callable[[TupleInt], Int], idx: TupleInt):\n",
271
+ " nda = NDArray(shape, fn)\n",
272
+ " yield rewrite(nda.shape).to(shape)\n",
273
+ " yield rewrite(nda[idx]).to(fn(idx))\n",
274
+ "\n",
275
+ "\n",
276
+ "@function(subsume=True, ruleset=array_ruleset)\n",
277
+ "def cat(l: NDArray, r: NDArray) -> NDArray:\n",
278
+ " \"\"\"\n",
279
+ " Returns the concatenation of two arrays, they should have the same shape and the first dimension is added.\n",
280
+ " \"\"\"\n",
281
+ " return NDArray(\n",
282
+ " TupleInt(\n",
283
+ " l.shape.length,\n",
284
+ " lambda i: (i == Int(0)).if_bool(l.shape[Int(0)] + r.shape[Int(0)], l.shape[i]),\n",
285
+ " ),\n",
286
+ " lambda idx: (idx[Int(0)] < l.shape[Int(0)]).if_bool(\n",
287
+ " l[idx], r[TupleInt(r.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - l.shape[Int(0)], idx[i]))]\n",
288
+ " ),\n",
289
+ " )\n",
290
+ "\n",
291
+ "\n",
292
+ "@function(subsume=True, ruleset=array_ruleset)\n",
293
+ "def drop(x: Int, arr: NDArray) -> NDArray:\n",
294
+ " \"\"\"\n",
295
+ " Drops the first `x` elements off the front of the array `arr`.\n",
296
+ " \"\"\"\n",
297
+ " return NDArray(\n",
298
+ " TupleInt(\n",
299
+ " arr.shape.length,\n",
300
+ " lambda i: (i == Int(0)).if_bool(arr.shape[Int(0)] - x, arr.shape[i]),\n",
301
+ " ),\n",
302
+ " lambda idx: arr[\n",
303
+ " TupleInt(\n",
304
+ " arr.shape.length,\n",
305
+ " # Add x to the first index, so it skips the first x elements\n",
306
+ " lambda i: (i == Int(0)).if_bool(idx[Int(0)] + x, idx[i]),\n",
307
+ " )\n",
308
+ " ],\n",
309
+ " )\n",
310
+ "\n",
311
+ "\n",
312
+ "@function(subsume=True, ruleset=array_ruleset)\n",
313
+ "def take(x: Int, arr: NDArray) -> NDArray:\n",
314
+ " \"\"\"\n",
315
+ " Takes the first `x` elements off the front of the array `arr`.\n",
316
+ " \"\"\"\n",
317
+ " return NDArray(\n",
318
+ " TupleInt(\n",
319
+ " arr.shape.length,\n",
320
+ " lambda i: (i == Int(0)).if_bool(x, arr.shape[i]),\n",
321
+ " ),\n",
322
+ " lambda idx: arr[idx],\n",
323
+ " )\n",
324
+ "\n",
325
+ "\n",
326
+ "shape = TupleInt.from_vec(Vec(Int(2), Int(3), Int(4)))\n",
327
+ "RAMY = NDArray.from_memory(shape, constant(\"RAMY\", TupleInt))\n",
328
+ "AMY = NDArray.from_memory(shape, constant(\"AMY\", TupleInt))\n",
329
+ "Amts = take(Int(2), drop(Int(2), cat(RAMY, AMY)))\n",
330
+ "Amts"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 28,
336
+ "id": "1ada95b6",
337
+ "metadata": {},
338
+ "outputs": [
339
+ {
340
+ "name": "stdout",
341
+ "output_type": "stream",
342
+ "text": [
343
+ "Amts.shape.length()=Int(3)\n",
344
+ "Amts.shape[0]=Int(2)\n",
345
+ "Amts.shape[1]=Int(3)\n",
346
+ "Amts.shape[2]=Int(4)\n",
347
+ "Amts[i, j, k]=((i + Int(2)) < Int(2)).if_bool(RAMY[(((i + Int(2)) * Int(12)) + (j * Int(4))) + k], AMY[((((i + Int(2)) - Int(2)) * Int(12)) + (j * Int(4))) + k])\n",
348
+ "AMY[i, j, k]=AMY[((i * Int(12)) + (j * Int(4))) + k]\n"
349
+ ]
350
+ }
351
+ ],
352
+ "source": [
353
+ "egraph = EGraph()\n",
354
+ "ndim = egraph.let(\"ndim\", Amts.shape.length)\n",
355
+ "shape_1 = egraph.let(\"shape_1\", Amts.shape[Int(0)])\n",
356
+ "shape_2 = egraph.let(\"shape_2\", Amts.shape[Int(1)])\n",
357
+ "shape_3 = egraph.let(\"shape_3\", Amts.shape[Int(2)])\n",
358
+ "idxs = TupleInt.from_vec(Vec(constant(\"i\", Int), constant(\"j\", Int), constant(\"k\", Int)))\n",
359
+ "idxed = egraph.let(\"idxed\", Amts[idxs])\n",
360
+ "amy_idxed = egraph.let(\"amy_idxed\", AMY[idxs])\n",
361
+ "\n",
362
+ "egraph.run(array_ruleset.saturate())\n",
363
+ "print(f\"Amts.shape.length()={egraph.extract(ndim)}\")\n",
364
+ "print(f\"Amts.shape[0]={egraph.extract(shape_1)}\")\n",
365
+ "print(f\"Amts.shape[1]={egraph.extract(shape_2)}\")\n",
366
+ "print(f\"Amts.shape[2]={egraph.extract(shape_3)}\")\n",
367
+ "print(f\"Amts[i, j, k]={egraph.extract(idxed)}\")\n",
368
+ "print(f\"AMY[i, j, k]={egraph.extract(amy_idxed)}\")"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "markdown",
373
+ "id": "e3dfbd1f",
374
+ "metadata": {},
375
+ "source": [
376
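+ "We can see that `Amts` is equal to `AMY`, since they have the same shape and indexing them produces the same result: for any non-negative `i`, the condition `(i + Int(2)) < Int(2)` is false and `(i + Int(2)) - Int(2)` is just `i`, so `Amts[i, j, k]` reduces to `AMY[i, j, k]`.\n",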
377
+ "\n",
378
+ "With some basic range analysis we could make them simplify to the same expression in the e-graph as well."
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "markdown",
383
+ "id": "5a232786",
384
+ "metadata": {},
385
+ "source": [
386
+ "If we want, we can also see all the intermediate steps to get to the indexed result."
387
+ ]
388
+ },
389
+ {
390
+ "cell_type": "markdown",
391
+ "id": "326942be",
392
+ "metadata": {},
393
+ "source": []
394
+ },
395
+ {
396
+ "cell_type": "markdown",
397
+ "id": "a56c640a",
398
+ "metadata": {},
399
+ "source": []
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": 29,
404
+ "id": "c7b757ff",
405
+ "metadata": {},
406
+ "outputs": [
407
+ {
408
+ "name": "stdout",
409
+ "output_type": "stream",
410
+ "text": [
411
+ "take(\n",
412
+ " Int(2),\n",
413
+ " drop(\n",
414
+ " Int(2), cat(NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY), NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY))\n",
415
+ " ),\n",
416
+ ")[TupleInt.from_vec(Vec[Int](i, j, k))] \n",
417
+ "\n",
418
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
419
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
420
+ "_NDArray_3 = NDArray(\n",
421
+ " TupleInt(_NDArray_1.shape.length, lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
422
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
423
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
424
+ " ),\n",
425
+ ")\n",
426
+ "_NDArray_4 = NDArray(\n",
427
+ " TupleInt(_NDArray_3.shape.length, lambda i: (i == Int(0)).if_bool(_NDArray_3.shape[Int(0)] - Int(2), _NDArray_3.shape[i])),\n",
428
+ " lambda idx: _NDArray_3[TupleInt(_NDArray_3.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] + Int(2), idx[i]))],\n",
429
+ ")\n",
430
+ "NDArray(TupleInt(_NDArray_4.shape.length, lambda i: (i == Int(0)).if_bool(Int(2), _NDArray_4.shape[i])), lambda idx: _NDArray_4[idx])[TupleInt.from_vec(Vec[Int](i, j, k))] \n",
431
+ "\n",
432
+ "_TupleInt_1 = TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4)))\n",
433
+ "_TupleInt_2 = TupleInt(\n",
434
+ " _TupleInt_1.length,\n",
435
+ " lambda i: (i == Int(0)).if_bool(\n",
436
+ " NDArray.from_memory(_TupleInt_1, RAMY).shape[Int(0)] + NDArray.from_memory(_TupleInt_1, AMY).shape[Int(0)], NDArray.from_memory(_TupleInt_1, RAMY).shape[i]\n",
437
+ " ),\n",
438
+ ")\n",
439
+ "_NDArray_1 = NDArray(\n",
440
+ " _TupleInt_2,\n",
441
+ " lambda idx: (idx[Int(0)] < NDArray.from_memory(_TupleInt_1, RAMY).shape[Int(0)]).if_bool(\n",
442
+ " NDArray.from_memory(_TupleInt_1, RAMY)[idx],\n",
443
+ " NDArray.from_memory(_TupleInt_1, AMY)[\n",
444
+ " TupleInt(\n",
445
+ " NDArray.from_memory(_TupleInt_1, AMY).shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - NDArray.from_memory(_TupleInt_1, RAMY).shape[Int(0)], idx[i])\n",
446
+ " )\n",
447
+ " ],\n",
448
+ " ),\n",
449
+ ")\n",
450
+ "(lambda arr, idx: arr[idx])(\n",
451
+ " NDArray(\n",
452
+ " TupleInt(_TupleInt_2.length, lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] - Int(2), _NDArray_1.shape[i])),\n",
453
+ " lambda idx: _NDArray_1[TupleInt(_NDArray_1.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] + Int(2), idx[i]))],\n",
454
+ " ),\n",
455
+ " TupleInt.from_vec(Vec[Int](i, j, k)),\n",
456
+ ") \n",
457
+ "\n",
458
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
459
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
460
+ "_NDArray_3 = NDArray(\n",
461
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
462
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
463
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
464
+ " ),\n",
465
+ ")\n",
466
+ "NDArray(\n",
467
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_3.shape[Int(0)] - Int(2), _NDArray_3.shape[i])),\n",
468
+ " lambda idx: _NDArray_3[TupleInt(_NDArray_3.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] + Int(2), idx[i]))],\n",
469
+ ")[TupleInt.from_vec(Vec[Int](i, j, k))] \n",
470
+ "\n",
471
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
472
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
473
+ "(lambda arr, x, idx: arr[TupleInt(arr.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] + x, idx[i]))])(\n",
474
+ " NDArray(\n",
475
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
476
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
477
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
478
+ " ),\n",
479
+ " ),\n",
480
+ " Int(2),\n",
481
+ " TupleInt.from_vec(Vec[Int](i, j, k)),\n",
482
+ ") \n",
483
+ "\n",
484
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
485
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
486
+ "NDArray(\n",
487
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
488
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
489
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
490
+ " ),\n",
491
+ ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n",
492
+ "\n",
493
+ "(lambda l, r, idx: (idx[Int(0)] < l.shape[Int(0)]).if_bool(l[idx], r[TupleInt(r.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - l.shape[Int(0)], idx[i]))]))(\n",
494
+ " NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY),\n",
495
+ " NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY),\n",
496
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i])),\n",
497
+ ") \n",
498
+ "\n",
499
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
500
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
501
+ "NDArray(\n",
502
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
503
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
504
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
505
+ " ),\n",
506
+ ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n",
507
+ "\n",
508
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
509
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
510
+ "NDArray(\n",
511
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
512
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
513
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
514
+ " ),\n",
515
+ ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n",
516
+ "\n",
517
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
518
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
519
+ "NDArray(\n",
520
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
521
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
522
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
523
+ " ),\n",
524
+ ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n",
525
+ "\n",
526
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
527
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
528
+ "NDArray(\n",
529
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
530
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
531
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
532
+ " ),\n",
533
+ ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n",
534
+ "\n",
535
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
536
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
537
+ "NDArray(\n",
538
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
539
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
540
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
541
+ " ),\n",
542
+ ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n",
543
+ "\n",
544
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
545
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
546
+ "NDArray(\n",
547
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
548
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
549
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
550
+ " ),\n",
551
+ ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n",
552
+ "\n",
553
+ "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n",
554
+ "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n",
555
+ "NDArray(\n",
556
+ " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n",
557
+ " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n",
558
+ " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n",
559
+ " ),\n",
560
+ ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n",
561
+ "\n",
562
+ "((i + Int(2)) < Int(2)).if_bool(RAMY[(((i + Int(2)) * Int(12)) + (j * Int(4))) + k], AMY[((((i + Int(2)) - Int(2)) * Int(12)) + (j * Int(4))) + k]) \n",
563
+ "\n"
564
+ ]
565
+ },
566
+ {
567
+ "data": {
568
+ "application/vnd.jupyter.widget-view+json": {
569
+ "model_id": "9da93b4d1d6241819757834a6da521dd",
570
+ "version_major": 2,
571
+ "version_minor": 1
572
+ },
573
+ "text/plain": [
574
+ "VisualizerWidget(egraphs=['{\"nodes\":{\"primitive-i64-2\":{\"op\":\"2\",\"children\":[],\"eclass\":\"i64-2\",\"cost\":1.0,\"su…"
575
+ ]
576
+ },
577
+ "metadata": {},
578
+ "output_type": "display_data"
579
+ }
580
+ ],
581
+ "source": [
582
+ "egraph = EGraph()\n",
583
+ "idxed = egraph.let(\"idxed\", Amts[idxs])\n",
584
+ "egraph.saturate(array_ruleset, expr=idxed)"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": null,
590
+ "id": "2642b054",
591
+ "metadata": {},
592
+ "outputs": [],
593
+ "source": []
594
+ }
595
+ ],
596
+ "metadata": {
597
+ "kernelspec": {
598
+ "display_name": "egglog",
599
+ "language": "python",
600
+ "name": "python3"
601
+ },
602
+ "language_info": {
603
+ "codemirror_mode": {
604
+ "name": "ipython",
605
+ "version": 3
606
+ },
607
+ "file_extension": ".py",
608
+ "mimetype": "text/x-python",
609
+ "name": "python",
610
+ "nbconvert_exporter": "python",
611
+ "pygments_lexer": "ipython3",
612
+ "version": "3.13.3"
613
+ }
614
+ },
615
+ "nbformat": 4,
616
+ "nbformat_minor": 5
617
+ }
egglog/exp/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ Experimental interfaces built on egglog.
3
+ """