egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
- egglog/bindings.pyi +887 -0
- egglog/builtins.py +1144 -0
- egglog/config.py +8 -0
- egglog/conversion.py +290 -0
- egglog/declarations.py +964 -0
- egglog/deconstruct.py +176 -0
- egglog/egraph.py +2247 -0
- egglog/egraph_state.py +978 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +64 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/MoA.ipynb +617 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/any_expr.py +947 -0
- egglog/exp/any_expr_example.ipynb +408 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +427 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +566 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +888 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +111 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35798 -0
- egglog/visualizer_widget.py +39 -0
- egglog-12.0.0.dist-info/METADATA +93 -0
- egglog-12.0.0.dist-info/RECORD +48 -0
- egglog-12.0.0.dist-info/WHEEL +5 -0
- egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
{
|
|
2
|
+
"cells": [
|
|
3
|
+
{
|
|
4
|
+
"cell_type": "code",
|
|
5
|
+
"execution_count": null,
|
|
6
|
+
"id": "6005ea18",
|
|
7
|
+
"metadata": {},
|
|
8
|
+
"outputs": [],
|
|
9
|
+
"source": []
|
|
10
|
+
},
|
|
11
|
+
{
|
|
12
|
+
"cell_type": "markdown",
|
|
13
|
+
"id": "7d1e4a46",
|
|
14
|
+
"metadata": {},
|
|
15
|
+
"source": [
|
|
16
|
+
"Work-in-progress example: using `any_expr` to trace a scikit-learn call and print the equivalent generated Python source.\n"
|
|
17
|
+
]
|
|
18
|
+
},
|
|
19
|
+
{
|
|
20
|
+
"cell_type": "code",
|
|
21
|
+
"execution_count": 1,
|
|
22
|
+
"id": "42cc576e",
|
|
23
|
+
"metadata": {},
|
|
24
|
+
"outputs": [
|
|
25
|
+
{
|
|
26
|
+
"name": "stdout",
|
|
27
|
+
"output_type": "stream",
|
|
28
|
+
"text": [
|
|
29
|
+
"_0 = list((*(), 42))\n",
|
|
30
|
+
"del _0[0]\n",
|
|
31
|
+
"assert (list((*(), 42)) != 10)\n",
|
|
32
|
+
"assert (not (_0 == 1))\n",
|
|
33
|
+
"-_0\n"
|
|
34
|
+
]
|
|
35
|
+
}
|
|
36
|
+
],
|
|
37
|
+
"source": [
|
|
38
|
+
"import os\n",
|
|
39
|
+
"\n",
|
|
40
|
+
"os.environ[\"SCIPY_ARRAY_API\"] = \"1\"\n",
|
|
41
|
+
"\n",
|
|
42
|
+
"from sklearn import set_config\n",
|
|
43
|
+
"from sklearn.datasets import make_classification\n",
|
|
44
|
+
"from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
|
|
45
|
+
"\n",
|
|
46
|
+
"from egglog import *\n",
|
|
47
|
+
"from egglog.exp import any_expr\n",
|
|
48
|
+
"from egglog.exp.any_expr import *\n",
|
|
49
|
+
"\n",
|
|
50
|
+
"set_config(array_api_dispatch=True)"
|
|
51
|
+
]
|
|
52
|
+
},
|
|
53
|
+
{
|
|
54
|
+
"cell_type": "code",
|
|
55
|
+
"execution_count": 2,
|
|
56
|
+
"id": "a7d3e251",
|
|
57
|
+
"metadata": {},
|
|
58
|
+
"outputs": [],
|
|
59
|
+
"source": [
|
|
60
|
+
"@ruleset\n",
|
|
61
|
+
"def extra_rules(x: A, y: A, z: A, o: PyObject):\n",
|
|
62
|
+
" yield rewrite(x.__array_namespace__(y, z).__name__, subsume=True).to(A(__name__))\n",
|
|
63
|
+
"\n",
|
|
64
|
+
" # File ~/p/egg-smol-python/.venv/lib/python3.13/site-packages/sklearn/utils/_unique.py:72, in _cached_unique(y, xp)\n",
|
|
65
|
+
" # 63 \"\"\"Return the unique values of y.\n",
|
|
66
|
+
" # 64\n",
|
|
67
|
+
" # 65 Use the cached values from dtype.metadata if present.\n",
|
|
68
|
+
" # (...) 69 Call `attach_unique` to attach the unique values to y.\n",
|
|
69
|
+
" # 70 \"\"\"\n",
|
|
70
|
+
" # 71 try:\n",
|
|
71
|
+
" # ---> 72 if y.dtype.metadata is not None and \"unique\" in y.dtype.metadata:\n",
|
|
72
|
+
"\n",
|
|
73
|
+
" yield rule(eq(A(None)).to(x.dtype.metadata)).then(getattr_eager(x.dtype, \"metadata\"))\n",
|
|
74
|
+
"\n",
|
|
75
|
+
" # File ~/p/egg-smol-python/.venv/lib/python3.13/site-packages/sklearn/utils/validation.py:1093, in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_all_finite, ensure_non_negative, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\n",
|
|
76
|
+
" # 1085 msg = (\n",
|
|
77
|
+
" # 1086 f\"Expected 2D array, got 1D array instead:\\narray={array}.\\n\"\n",
|
|
78
|
+
" # 1087 \"Reshape your data either using array.reshape(-1, 1) if \"\n",
|
|
79
|
+
" # 1088 \"your data has a single feature or array.reshape(1, -1) \"\n",
|
|
80
|
+
" # 1089 \"if it contains a single sample.\"\n",
|
|
81
|
+
" # 1090 )\n",
|
|
82
|
+
" # 1091 raise ValueError(msg)\n",
|
|
83
|
+
" # -> 1093 if dtype_numeric and hasattr(array.dtype, \"kind\") and array.dtype.kind in \"USV\":\n",
|
|
84
|
+
" # 1094 raise ValueError(\n",
|
|
85
|
+
" # 1095 \"dtype='numeric' is not compatible with arrays of bytes/strings.\"\n",
|
|
86
|
+
" # 1096 \"Convert your data to numeric values explicitly instead.\"\n",
|
|
87
|
+
" # 1097 )\n",
|
|
88
|
+
" # 1098 if not allow_nd and array.ndim >= 3:\n",
|
|
89
|
+
"\n",
|
|
90
|
+
" # TypeError: 'in <string>' requires string as left operand, not RuntimeExpr\n",
|
|
91
|
+
"    # Eagerly evaluate the `kind` attribute of the dtype so the `in` check receives a real string rather than a RuntimeExpr\n",
|
|
92
|
+
" yield rule(x.dtype.kind).then(getattr_eager(x.dtype, \"kind\"))\n",
|
|
93
|
+
"\n",
|
|
94
|
+
"\n",
|
|
95
|
+
"any_expr.any_expr_schedule = (extra_rules.saturate() + any_expr_ruleset).saturate()"
|
|
96
|
+
]
|
|
97
|
+
},
|
|
98
|
+
{
|
|
99
|
+
"cell_type": "code",
|
|
100
|
+
"execution_count": 3,
|
|
101
|
+
"id": "6080fd28",
|
|
102
|
+
"metadata": {},
|
|
103
|
+
"outputs": [],
|
|
104
|
+
"source": [
|
|
105
|
+
"lda = LinearDiscriminantAnalysis()\n",
|
|
106
|
+
"X_np, y_np = make_classification(random_state=0)\n",
|
|
107
|
+
"\n",
|
|
108
|
+
"egraph = EGraph()\n",
|
|
109
|
+
"with set_any_expr_egraph(egraph):\n",
|
|
110
|
+
" x = lda.fit_transform(AnyExpr(X_np), AnyExpr(y_np))"
|
|
111
|
+
]
|
|
112
|
+
},
|
|
113
|
+
{
|
|
114
|
+
"cell_type": "code",
|
|
115
|
+
"execution_count": 4,
|
|
116
|
+
"id": "c408cfcc",
|
|
117
|
+
"metadata": {},
|
|
118
|
+
"outputs": [],
|
|
119
|
+
"source": [
|
|
120
|
+
"# print(str(x))"
|
|
121
|
+
]
|
|
122
|
+
},
|
|
123
|
+
{
|
|
124
|
+
"cell_type": "code",
|
|
125
|
+
"execution_count": 5,
|
|
126
|
+
"id": "cd520b07",
|
|
127
|
+
"metadata": {},
|
|
128
|
+
"outputs": [
|
|
129
|
+
{
|
|
130
|
+
"name": "stdout",
|
|
131
|
+
"output_type": "stream",
|
|
132
|
+
"text": [
|
|
133
|
+
"_0 = array([[-0.03926799, 0.13191176, -0.21120598, ..., 1.97698901,\n",
|
|
134
|
+
" 1.02122474, -0.46931074],\n",
|
|
135
|
+
" [ 0.77416061, 0.10490717, -0.33281176, ..., 1.2678044 ,\n",
|
|
136
|
+
" 0.62251914, -1.49026539],\n",
|
|
137
|
+
" [-0.0148577 , 0.67057045, -0.21416666, ..., -0.10486202,\n",
|
|
138
|
+
" -0.10169727, -0.45130304],\n",
|
|
139
|
+
" ...,\n",
|
|
140
|
+
" [ 0.29673317, -0.49610233, -0.86404499, ..., -1.10453952,\n",
|
|
141
|
+
" 2.01406015, 0.69042902],\n",
|
|
142
|
+
" [ 0.08617684, 0.9836362 , 0.17124355, ..., 2.11564734,\n",
|
|
143
|
+
" 0.11273794, 1.20985013],\n",
|
|
144
|
+
" [-1.58249448, -1.42279491, -0.56430103, ..., 1.26661394,\n",
|
|
145
|
+
" -1.31771734, 1.61805427]], shape=(100, 20))\n",
|
|
146
|
+
"_1 = array([0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1,\n",
|
|
147
|
+
" 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0,\n",
|
|
148
|
+
" 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1,\n",
|
|
149
|
+
" 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1,\n",
|
|
150
|
+
" 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0])\n",
|
|
151
|
+
"_2 = _0.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
152
|
+
"_3 = <module 'numpy' from '/Users/saul/p/egg-smol-python/.venv/lib/python3.13/site-packages/numpy/__init__.py'>\n",
|
|
153
|
+
"assert (_2 == _3)\n",
|
|
154
|
+
"assert (not (_2.__name__ == 'array_api_strict'))\n",
|
|
155
|
+
"_4 = _0.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
156
|
+
"assert (_4 == _3)\n",
|
|
157
|
+
"assert (not (_4.__name__ == 'array_api_strict'))\n",
|
|
158
|
+
"assert (_2.float64 == _0.dtype)\n",
|
|
159
|
+
"assert (_2.float64 == _0.dtype)\n",
|
|
160
|
+
"assert (_4.__name__ == '__main__')\n",
|
|
161
|
+
"_5 = _4.asarray(*(*(), _0), **{**{**{**{}, 'dtype': None}, 'copy': None}, 'device': None})\n",
|
|
162
|
+
"assert (_5.dtype.kind == 'f')\n",
|
|
163
|
+
"assert (_5.dtype.kind == 'f')\n",
|
|
164
|
+
"assert (not (_5.ndim == 0))\n",
|
|
165
|
+
"assert (not (_5.ndim == 1))\n",
|
|
166
|
+
"assert (not (_5.ndim >= 3))\n",
|
|
167
|
+
"_6 = _5.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
168
|
+
"assert (_6 == _3)\n",
|
|
169
|
+
"assert (not (_6.__name__ == 'array_api_strict'))\n",
|
|
170
|
+
"_7 = _6.asarray(*(*(), _5), **{})\n",
|
|
171
|
+
"_8 = _6.isdtype(*(*(*(), _7.dtype), (*(*(), 'real floating'), 'complex floating')), **{})\n",
|
|
172
|
+
"assert _8\n",
|
|
173
|
+
"assert _8\n",
|
|
174
|
+
"_9 = _6.sum(*(*(), _7), **{})\n",
|
|
175
|
+
"_10 = _6.isfinite(*(*(), _9), **{})\n",
|
|
176
|
+
"assert _10\n",
|
|
177
|
+
"assert _10\n",
|
|
178
|
+
"assert (len(_5.shape) == 2)\n",
|
|
179
|
+
"assert (len(_5) == 100)\n",
|
|
180
|
+
"assert (_5.ndim == 2)\n",
|
|
181
|
+
"assert (_5.ndim == 2)\n",
|
|
182
|
+
"assert (not (_5.shape[1] < 1))\n",
|
|
183
|
+
"_11 = _1.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
184
|
+
"assert (_11 == _3)\n",
|
|
185
|
+
"assert (not (_11.__name__ == 'array_api_strict'))\n",
|
|
186
|
+
"_12 = _1.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
187
|
+
"assert (_12 == _3)\n",
|
|
188
|
+
"assert (not (_12.__name__ == 'array_api_strict'))\n",
|
|
189
|
+
"assert (_12.__name__ == '__main__')\n",
|
|
190
|
+
"_13 = _12.asarray(*(*(), _1), **{**{**{**{}, 'dtype': None}, 'copy': None}, 'device': None})\n",
|
|
191
|
+
"assert (_13.dtype.kind == 'i')\n",
|
|
192
|
+
"assert (_13.dtype.kind == 'i')\n",
|
|
193
|
+
"assert (not (_13.ndim >= 3))\n",
|
|
194
|
+
"assert (not (_13.ndim == 2))\n",
|
|
195
|
+
"assert (len(_13.shape) == 1)\n",
|
|
196
|
+
"assert (_11.__name__ == '__main__')\n",
|
|
197
|
+
"_14 = _11.reshape(*(*(*(), _13), (*(), -1)), **{})\n",
|
|
198
|
+
"_15 = _11.asarray(*(*(), _14), **{**{**{**{}, 'dtype': None}, 'copy': None}, 'device': None})\n",
|
|
199
|
+
"_16 = _15.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
200
|
+
"assert (_16 == _3)\n",
|
|
201
|
+
"assert (not (_16.__name__ == 'array_api_strict'))\n",
|
|
202
|
+
"_17 = _16.asarray(*(*(), _15), **{})\n",
|
|
203
|
+
"_18 = _16.isdtype(*(*(*(), _17.dtype), (*(*(), 'real floating'), 'complex floating')), **{})\n",
|
|
204
|
+
"assert (not _18)\n",
|
|
205
|
+
"assert (_15.dtype.kind == 'i')\n",
|
|
206
|
+
"assert (_15.dtype.kind == 'i')\n",
|
|
207
|
+
"assert (len(_5.shape) == 2)\n",
|
|
208
|
+
"assert (len(_5) == 100)\n",
|
|
209
|
+
"assert (len(_15.shape) == 1)\n",
|
|
210
|
+
"assert (len(_15) == 100)\n",
|
|
211
|
+
"assert (len(_5.shape) == 2)\n",
|
|
212
|
+
"_19 = _15.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
213
|
+
"assert (_19 == _3)\n",
|
|
214
|
+
"assert (not (_19.__name__ == 'array_api_strict'))\n",
|
|
215
|
+
"_20 = _15.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
216
|
+
"assert (_20 == _3)\n",
|
|
217
|
+
"assert (not (_20.__name__ == 'array_api_strict'))\n",
|
|
218
|
+
"_21 = _15.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
219
|
+
"assert (_21 == _3)\n",
|
|
220
|
+
"assert (not (_21.__name__ == 'array_api_strict'))\n",
|
|
221
|
+
"_22 = _15.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
222
|
+
"assert (_22 == _3)\n",
|
|
223
|
+
"assert (not (_22.__name__ == 'array_api_strict'))\n",
|
|
224
|
+
"assert (_22.__name__ == '__main__')\n",
|
|
225
|
+
"_23 = _22.asarray(*(*(), _15), **{**{**{**{}, 'dtype': None}, 'copy': None}, 'device': None})\n",
|
|
226
|
+
"assert (_23.dtype.kind == 'i')\n",
|
|
227
|
+
"assert (_23.dtype.kind == 'i')\n",
|
|
228
|
+
"assert (not (_23.ndim == 2))\n",
|
|
229
|
+
"_24 = _15.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
230
|
+
"assert (_24 == _3)\n",
|
|
231
|
+
"assert (not (_24.__name__ == 'array_api_strict'))\n",
|
|
232
|
+
"assert (_24.__name__ == '__main__')\n",
|
|
233
|
+
"_25 = _24.asarray(*(*(), _15), **{**{**{**{}, 'dtype': None}, 'copy': None}, 'device': None})\n",
|
|
234
|
+
"assert (_25.dtype.kind == 'i')\n",
|
|
235
|
+
"assert (_25.dtype.kind == 'i')\n",
|
|
236
|
+
"assert (_25.ndim == 1)\n",
|
|
237
|
+
"assert (_25.ndim == 1)\n",
|
|
238
|
+
"assert (len(_25.shape) == 1)\n",
|
|
239
|
+
"assert _25.shape[0]\n",
|
|
240
|
+
"assert _25.shape[0]\n",
|
|
241
|
+
"_26 = <class 'object'>\n",
|
|
242
|
+
"assert (not (_25.dtype == _26))\n",
|
|
243
|
+
"assert (not (_25.ndim == 2))\n",
|
|
244
|
+
"_27 = _20.isdtype(*(*(*(), _25.dtype), 'real floating'), **{})\n",
|
|
245
|
+
"assert (not _27)\n",
|
|
246
|
+
"assert (_25.dtype.metadata == None)\n",
|
|
247
|
+
"_28 = _25.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
248
|
+
"assert (_28 == _3)\n",
|
|
249
|
+
"assert (not (_28.__name__ == 'array_api_strict'))\n",
|
|
250
|
+
"_29 = _28.unique_values(*(*(), _25), **{})\n",
|
|
251
|
+
"assert (_25.shape[0] > 20)\n",
|
|
252
|
+
"assert (_25.shape[0] > 20)\n",
|
|
253
|
+
"assert (_25.shape[0] > _29.shape[0])\n",
|
|
254
|
+
"assert (_25.shape[0] > _29.shape[0])\n",
|
|
255
|
+
"assert (not (_29.shape[0] > round((0.5 * _25.shape[0]))))\n",
|
|
256
|
+
"assert (not (_29.shape[0] > 2))\n",
|
|
257
|
+
"assert (not (_25.ndim == 2))\n",
|
|
258
|
+
"_30 = _19.asarray(*(*(), _15), **{})\n",
|
|
259
|
+
"assert (_30.dtype.metadata == None)\n",
|
|
260
|
+
"_31 = _19.unique_values(*(*(), _30), **{})\n",
|
|
261
|
+
"_32 = _19.concat(*(*(), list((*(), _31))), **{})\n",
|
|
262
|
+
"_33 = _19.unique_values(*(*(), _32), **{})\n",
|
|
263
|
+
"assert (len(_5.shape) == 2)\n",
|
|
264
|
+
"assert (not (_5.shape[0] == _33.shape[0]))\n",
|
|
265
|
+
"_34 = _2.unique_counts(*(*(), _15), **{})\n",
|
|
266
|
+
"assert (len(_34) == 2)\n",
|
|
267
|
+
"assert (float(_15.shape[0]) == 100.0)\n",
|
|
268
|
+
"_35 = _2.astype(*(*(*(), _34[1]), _5.dtype), **{})\n",
|
|
269
|
+
"_36 = _2.any(*(*(), ((_35 / 100.0) < 0)), **{})\n",
|
|
270
|
+
"assert (not _36)\n",
|
|
271
|
+
"_37 = _2.sum(*(*(), (_35 / 100.0)), **{})\n",
|
|
272
|
+
"_38 = _2.abs(*(*(), (_37 - 1.0)), **{})\n",
|
|
273
|
+
"assert (not (_38 > 1e-05))\n",
|
|
274
|
+
"assert (not (_5.shape[1] < (_33.shape[0] - 1)))\n",
|
|
275
|
+
"_39 = _5.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
276
|
+
"assert (_39 == _3)\n",
|
|
277
|
+
"assert (not (_39.__name__ == 'array_api_strict'))\n",
|
|
278
|
+
"assert (len(_5.shape) == 2)\n",
|
|
279
|
+
"_40 = _5.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
280
|
+
"assert (_40 == _3)\n",
|
|
281
|
+
"assert (not (_40.__name__ == 'array_api_strict'))\n",
|
|
282
|
+
"_41 = _40.unique_inverse(*(*(), _15), **{})\n",
|
|
283
|
+
"assert (len(_41) == 2)\n",
|
|
284
|
+
"assert (operator.index(_41[0].shape[0]) == 2)\n",
|
|
285
|
+
"assert (len(_33) == 2)\n",
|
|
286
|
+
"_42 = _39.asarray(*(*(), (1.0 / (_5.shape[0] - _33.shape[0]))), **{**{**{}, 'dtype': _5.dtype}, 'device': _5.device})\n",
|
|
287
|
+
"_43 = _39.sqrt(*(*(), _42), **{})\n",
|
|
288
|
+
"_44 = _40.zeros(*(*(), (*(*(), _41[0].shape[0]), _5.shape[1])), **{**{**{}, 'device': _5.device}, 'dtype': _5.dtype})\n",
|
|
289
|
+
"_45 = _40.mean(*(*(), _5[(_41[1] == 0)]), **{**{}, 'axis': 0})\n",
|
|
290
|
+
"_44[(*(*(), 0), slice(None, None, None))] = _45\n",
|
|
291
|
+
"_46 = _40.mean(*(*(), _5[(_41[1] == 1)]), **{**{}, 'axis': 0})\n",
|
|
292
|
+
"_44[(*(*(), 1), slice(None, None, None))] = _46\n",
|
|
293
|
+
"_47 = _39.concat(*(*(), list((*(*(), (_5[(_15 == _33[0])] - _44[(*(*(), 0), slice(None, None, None))])), (_5[(_15 == _33[1])] - _44[(*(*(), 1), slice(None, None, None))])))), **{**{}, 'axis': 0})\n",
|
|
294
|
+
"_48 = _39.std(*(*(), _47), **{**{}, 'axis': 0})\n",
|
|
295
|
+
"_48[(_48 == 0)] = 1.0\n",
|
|
296
|
+
"_49 = _39.linalg.svd(*(*(), (_43 * (_47 / _48))), **{**{}, 'full_matrices': False})\n",
|
|
297
|
+
"assert (len(_49) == 3)\n",
|
|
298
|
+
"assert (not (_33.shape[0] == 1))\n",
|
|
299
|
+
"_50 = _39.sqrt(*(*(), ((_5.shape[0] * (_35 / 100.0)) * (1.0 / (_33.shape[0] - 1)))), **{})\n",
|
|
300
|
+
"_51 = _39.astype(*(*(*(), (_49[1] > 0.0001)), _39.int32), **{})\n",
|
|
301
|
+
"_52 = _39.sum(*(*(), _51), **{})\n",
|
|
302
|
+
"_53 = _39.linalg.svd(*(*(), ((_50 * (_44 - ((_35 / 100.0) @ _44)).T).T @ ((_49[2][(*(*(), slice(None, _52, None)), slice(None, None, None))] / _48).T / _49[1][slice(None, _52, None)]))), **{**{}, 'full_matrices': False})\n",
|
|
303
|
+
"assert (len(_53) == 3)\n",
|
|
304
|
+
"assert (not ((_33.shape[0] - 1) == 0))\n",
|
|
305
|
+
"assert (len(_33.shape) == 1)\n",
|
|
306
|
+
"assert ((1 * _33.shape[0]) == 2)\n",
|
|
307
|
+
"assert ((1 * _33.shape[0]) == 2)\n",
|
|
308
|
+
"_54 = _0.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
309
|
+
"assert (_54 == _3)\n",
|
|
310
|
+
"assert (not (_54.__name__ == 'array_api_strict'))\n",
|
|
311
|
+
"_55 = _0.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
312
|
+
"assert (_55 == _3)\n",
|
|
313
|
+
"assert (not (_55.__name__ == 'array_api_strict'))\n",
|
|
314
|
+
"assert (_0.dtype.kind == 'f')\n",
|
|
315
|
+
"assert (_0.dtype.kind == 'f')\n",
|
|
316
|
+
"assert (_55.__name__ == '__main__')\n",
|
|
317
|
+
"_56 = _55.asarray(*(*(), _0), **{**{**{**{}, 'dtype': None}, 'copy': None}, 'device': None})\n",
|
|
318
|
+
"assert (_56.dtype.kind == 'f')\n",
|
|
319
|
+
"assert (_56.dtype.kind == 'f')\n",
|
|
320
|
+
"assert (not (_56.ndim == 0))\n",
|
|
321
|
+
"assert (not (_56.ndim == 1))\n",
|
|
322
|
+
"assert (_56.dtype.kind == 'f')\n",
|
|
323
|
+
"assert (_56.dtype.kind == 'f')\n",
|
|
324
|
+
"assert (not (_56.ndim >= 3))\n",
|
|
325
|
+
"_57 = _56.__array_namespace__(*(), **{**{}, 'api_version': None})\n",
|
|
326
|
+
"assert (_57 == _3)\n",
|
|
327
|
+
"assert (not (_57.__name__ == 'array_api_strict'))\n",
|
|
328
|
+
"_58 = _57.asarray(*(*(), _56), **{})\n",
|
|
329
|
+
"_59 = _57.isdtype(*(*(*(), _58.dtype), (*(*(), 'real floating'), 'complex floating')), **{})\n",
|
|
330
|
+
"assert _59\n",
|
|
331
|
+
"assert _59\n",
|
|
332
|
+
"_60 = _57.sum(*(*(), _58), **{})\n",
|
|
333
|
+
"_61 = _57.isfinite(*(*(), _60), **{})\n",
|
|
334
|
+
"assert _61\n",
|
|
335
|
+
"assert _61\n",
|
|
336
|
+
"assert (len(_56.shape) == 2)\n",
|
|
337
|
+
"assert (len(_56) == 100)\n",
|
|
338
|
+
"assert (_56.ndim == 2)\n",
|
|
339
|
+
"assert (_56.ndim == 2)\n",
|
|
340
|
+
"assert (not (_56.shape[1] < 1))\n",
|
|
341
|
+
"assert (len(_0.shape) == 2)\n",
|
|
342
|
+
"assert (not (_0.shape[1] != _5.shape[1]))\n",
|
|
343
|
+
"_62 = _39.astype(*(*(*(), (_53[1] > (0.0001 * _53[1][0]))), _39.int32), **{})\n",
|
|
344
|
+
"_63 = _39.sum(*(*(), _62), **{})\n",
|
|
345
|
+
"((_56 - ((_35 / 100.0) @ _44)) @ (((_49[2][(*(*(), slice(None, _52, None)), slice(None, None, None))] / _48).T / _49[1][slice(None, _52, None)]) @ _53[2].T[(*(*(), slice(None, None, None)), slice(None, _63, None))]))[(*(*(), slice(None, None, None)), slice(None, (_33.shape[0] - 1), None))]\n"
|
|
346
|
+
]
|
|
347
|
+
}
|
|
348
|
+
],
|
|
349
|
+
"source": [
|
|
350
|
+
"print(any_expr_source(x))"
|
|
351
|
+
]
|
|
352
|
+
},
|
|
353
|
+
{
|
|
354
|
+
"cell_type": "code",
|
|
355
|
+
"execution_count": 6,
|
|
356
|
+
"id": "664abeee",
|
|
357
|
+
"metadata": {},
|
|
358
|
+
"outputs": [
|
|
359
|
+
{
|
|
360
|
+
"name": "stdout",
|
|
361
|
+
"output_type": "stream",
|
|
362
|
+
"text": [
|
|
363
|
+
"The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.\n"
|
|
364
|
+
]
|
|
365
|
+
}
|
|
366
|
+
],
|
|
367
|
+
"source": [
|
|
368
|
+
"# egraph = EGraph()\n",
|
|
369
|
+
"# # x = given(given(10, 100), 1000)\n",
|
|
370
|
+
"# egraph.register(x)\n",
|
|
371
|
+
"# egraph.run(given_ruleset.saturate())\n",
|
|
372
|
+
"# # egraph.display()\n",
|
|
373
|
+
"# print(str(egraph.extract(x))) # $, cost_model=greedy_dag_cost_model())))\n",
|
|
374
|
+
"\n",
|
|
375
|
+
"# a"
|
|
376
|
+
]
|
|
377
|
+
},
|
|
378
|
+
{
|
|
379
|
+
"cell_type": "code",
|
|
380
|
+
"execution_count": null,
|
|
381
|
+
"id": "b94d88c7",
|
|
382
|
+
"metadata": {},
|
|
383
|
+
"outputs": [],
|
|
384
|
+
"source": []
|
|
385
|
+
}
|
|
386
|
+
],
|
|
387
|
+
"metadata": {
|
|
388
|
+
"kernelspec": {
|
|
389
|
+
"display_name": "egglog",
|
|
390
|
+
"language": "python",
|
|
391
|
+
"name": "python3"
|
|
392
|
+
},
|
|
393
|
+
"language_info": {
|
|
394
|
+
"codemirror_mode": {
|
|
395
|
+
"name": "ipython",
|
|
396
|
+
"version": 3
|
|
397
|
+
},
|
|
398
|
+
"file_extension": ".py",
|
|
399
|
+
"mimetype": "text/x-python",
|
|
400
|
+
"name": "python",
|
|
401
|
+
"nbconvert_exporter": "python",
|
|
402
|
+
"pygments_lexer": "ipython3",
|
|
403
|
+
"version": "3.13.3"
|
|
404
|
+
}
|
|
405
|
+
},
|
|
406
|
+
"nbformat": 4,
|
|
407
|
+
"nbformat_minor": 5
|
|
408
|
+
}
|