egglog 6.1.0__cp310-none-win_amd64.whl → 7.0.0__cp310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -1
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +2 -0
- egglog/builtins.py +1 -1
- egglog/conversion.py +172 -0
- egglog/declarations.py +329 -735
- egglog/egraph.py +531 -804
- egglog/egraph_state.py +417 -0
- egglog/exp/array_api.py +92 -80
- egglog/exp/array_api_numba.py +6 -1
- egglog/exp/siu_examples.py +35 -0
- egglog/pretty.py +418 -0
- egglog/runtime.py +196 -430
- egglog/thunk.py +72 -0
- egglog/type_constraint_solver.py +5 -2
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/METADATA +19 -19
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/RECORD +19 -14
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/WHEEL +0 -0
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/license_files/LICENSE +0 -0
egglog/egraph_state.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implement conversion to/from egglog.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import TYPE_CHECKING, overload
|
|
10
|
+
from weakref import WeakKeyDictionary
|
|
11
|
+
|
|
12
|
+
from typing_extensions import assert_never
|
|
13
|
+
|
|
14
|
+
from . import bindings
|
|
15
|
+
from .declarations import *
|
|
16
|
+
from .pretty import *
|
|
17
|
+
from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from collections.abc import Iterable
|
|
21
|
+
|
|
22
|
+
__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"]
|
|
23
|
+
|
|
24
|
+
# Create a global sort for python objects, so we can store them without an e-graph instance
|
|
25
|
+
# Needed when serializing commands to egg commands when creating modules
|
|
26
|
+
GLOBAL_PY_OBJECT_SORT = bindings.PyObjectSort()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
|
|
30
|
+
class EGraphState:
|
|
31
|
+
"""
|
|
32
|
+
State of the EGraph declarations and rulesets, so when we pop/push the stack we know what's defined.
|
|
33
|
+
|
|
34
|
+
Used for converting to/from egg and for pretty printing.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
egraph: bindings.EGraph
|
|
38
|
+
# The declarations we have added.
|
|
39
|
+
__egg_decls__: Declarations = field(default_factory=Declarations)
|
|
40
|
+
# Mapping of added rulesets to the added rules
|
|
41
|
+
rulesets: dict[str, set[RewriteOrRuleDecl]] = field(default_factory=dict)
|
|
42
|
+
|
|
43
|
+
# Bidirectional mapping between egg function names and python callable references.
|
|
44
|
+
# Note that there are possibly multiple callable references for a single egg function name, like `+`
|
|
45
|
+
# for both int and rational classes.
|
|
46
|
+
egg_fn_to_callable_refs: dict[str, set[CallableRef]] = field(
|
|
47
|
+
default_factory=lambda: defaultdict(set, {"!=": {FunctionRef("!=")}})
|
|
48
|
+
)
|
|
49
|
+
callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=lambda: {FunctionRef("!="): "!="})
|
|
50
|
+
|
|
51
|
+
# Bidirectional mapping between egg sort names and python type references.
|
|
52
|
+
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
|
|
53
|
+
|
|
54
|
+
# Cache of egg expressions for converting to egg
|
|
55
|
+
expr_to_egg_cache: WeakKeyDictionary[ExprDecl, bindings._Expr] = field(default_factory=WeakKeyDictionary)
|
|
56
|
+
|
|
57
|
+
def copy(self) -> EGraphState:
|
|
58
|
+
"""
|
|
59
|
+
Returns a copy of the state. The egraph reference is kept the same. Used for pushing/popping.
|
|
60
|
+
"""
|
|
61
|
+
return EGraphState(
|
|
62
|
+
egraph=self.egraph,
|
|
63
|
+
__egg_decls__=self.__egg_decls__.copy(),
|
|
64
|
+
rulesets={k: v.copy() for k, v in self.rulesets.items()},
|
|
65
|
+
egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self.egg_fn_to_callable_refs.items()}),
|
|
66
|
+
callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
|
|
67
|
+
type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
|
|
68
|
+
expr_to_egg_cache=self.expr_to_egg_cache.copy(),
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
|
|
72
|
+
match schedule:
|
|
73
|
+
case SaturateDecl(schedule):
|
|
74
|
+
return bindings.Saturate(self.schedule_to_egg(schedule))
|
|
75
|
+
case RepeatDecl(schedule, times):
|
|
76
|
+
return bindings.Repeat(times, self.schedule_to_egg(schedule))
|
|
77
|
+
case SequenceDecl(schedules):
|
|
78
|
+
return bindings.Sequence([self.schedule_to_egg(s) for s in schedules])
|
|
79
|
+
case RunDecl(ruleset_name, until):
|
|
80
|
+
self.ruleset_to_egg(ruleset_name)
|
|
81
|
+
config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
|
|
82
|
+
return bindings.Run(config)
|
|
83
|
+
case _:
|
|
84
|
+
assert_never(schedule)
|
|
85
|
+
|
|
86
|
+
def ruleset_to_egg(self, name: str) -> None:
|
|
87
|
+
"""
|
|
88
|
+
Registers a ruleset if it's not already registered.
|
|
89
|
+
"""
|
|
90
|
+
if name not in self.rulesets:
|
|
91
|
+
if name:
|
|
92
|
+
self.egraph.run_program(bindings.AddRuleset(name))
|
|
93
|
+
rules = self.rulesets[name] = set()
|
|
94
|
+
else:
|
|
95
|
+
rules = self.rulesets[name]
|
|
96
|
+
for rule in self.__egg_decls__._rulesets[name].rules:
|
|
97
|
+
if rule in rules:
|
|
98
|
+
continue
|
|
99
|
+
self.egraph.run_program(self.command_to_egg(rule, name))
|
|
100
|
+
rules.add(rule)
|
|
101
|
+
|
|
102
|
+
def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
|
|
103
|
+
match cmd:
|
|
104
|
+
case ActionCommandDecl(action):
|
|
105
|
+
return bindings.ActionCommand(self.action_to_egg(action))
|
|
106
|
+
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
|
|
107
|
+
self.type_ref_to_egg(tp)
|
|
108
|
+
rewrite = bindings.Rewrite(
|
|
109
|
+
self.expr_to_egg(lhs),
|
|
110
|
+
self.expr_to_egg(rhs),
|
|
111
|
+
[self.fact_to_egg(c) for c in conditions],
|
|
112
|
+
)
|
|
113
|
+
return (
|
|
114
|
+
bindings.RewriteCommand(ruleset, rewrite, cmd.subsume)
|
|
115
|
+
if isinstance(cmd, RewriteDecl)
|
|
116
|
+
else bindings.BiRewriteCommand(ruleset, rewrite)
|
|
117
|
+
)
|
|
118
|
+
case RuleDecl(head, body, name):
|
|
119
|
+
rule = bindings.Rule(
|
|
120
|
+
[self.action_to_egg(a) for a in head],
|
|
121
|
+
[self.fact_to_egg(f) for f in body],
|
|
122
|
+
)
|
|
123
|
+
return bindings.RuleCommand(name or "", ruleset, rule)
|
|
124
|
+
case _:
|
|
125
|
+
assert_never(cmd)
|
|
126
|
+
|
|
127
|
+
def action_to_egg(self, action: ActionDecl) -> bindings._Action:
|
|
128
|
+
match action:
|
|
129
|
+
case LetDecl(name, typed_expr):
|
|
130
|
+
return bindings.Let(name, self.typed_expr_to_egg(typed_expr))
|
|
131
|
+
case SetDecl(tp, call, rhs):
|
|
132
|
+
self.type_ref_to_egg(tp)
|
|
133
|
+
call_ = self.expr_to_egg(call)
|
|
134
|
+
return bindings.Set(call_.name, call_.args, self.expr_to_egg(rhs))
|
|
135
|
+
case ExprActionDecl(typed_expr):
|
|
136
|
+
return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
|
|
137
|
+
case ChangeDecl(tp, call, change):
|
|
138
|
+
self.type_ref_to_egg(tp)
|
|
139
|
+
call_ = self.expr_to_egg(call)
|
|
140
|
+
egg_change: bindings._Change
|
|
141
|
+
match change:
|
|
142
|
+
case "delete":
|
|
143
|
+
egg_change = bindings.Delete()
|
|
144
|
+
case "subsume":
|
|
145
|
+
egg_change = bindings.Subsume()
|
|
146
|
+
case _:
|
|
147
|
+
assert_never(change)
|
|
148
|
+
return bindings.Change(egg_change, call_.name, call_.args)
|
|
149
|
+
case UnionDecl(tp, lhs, rhs):
|
|
150
|
+
self.type_ref_to_egg(tp)
|
|
151
|
+
return bindings.Union(self.expr_to_egg(lhs), self.expr_to_egg(rhs))
|
|
152
|
+
case PanicDecl(name):
|
|
153
|
+
return bindings.Panic(name)
|
|
154
|
+
case _:
|
|
155
|
+
assert_never(action)
|
|
156
|
+
|
|
157
|
+
def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
|
|
158
|
+
match fact:
|
|
159
|
+
case EqDecl(tp, exprs):
|
|
160
|
+
self.type_ref_to_egg(tp)
|
|
161
|
+
return bindings.Eq([self.expr_to_egg(e) for e in exprs])
|
|
162
|
+
case ExprFactDecl(typed_expr):
|
|
163
|
+
return bindings.Fact(self.typed_expr_to_egg(typed_expr))
|
|
164
|
+
case _:
|
|
165
|
+
assert_never(fact)
|
|
166
|
+
|
|
167
|
+
def callable_ref_to_egg(self, ref: CallableRef) -> str:
|
|
168
|
+
"""
|
|
169
|
+
Returns the egg function name for a callable reference, registering it if it is not already registered.
|
|
170
|
+
"""
|
|
171
|
+
if ref in self.callable_ref_to_egg_fn:
|
|
172
|
+
return self.callable_ref_to_egg_fn[ref]
|
|
173
|
+
decl = self.__egg_decls__.get_callable_decl(ref)
|
|
174
|
+
self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _generate_callable_egg_name(ref)
|
|
175
|
+
self.egg_fn_to_callable_refs[egg_name].add(ref)
|
|
176
|
+
match decl:
|
|
177
|
+
case RelationDecl(arg_types, _, _):
|
|
178
|
+
self.egraph.run_program(bindings.Relation(egg_name, [self.type_ref_to_egg(a) for a in arg_types]))
|
|
179
|
+
case ConstantDecl(tp, _):
|
|
180
|
+
# Use function declaration instead of constant b/c constants cannot be extracted
|
|
181
|
+
# https://github.com/egraphs-good/egglog/issues/334
|
|
182
|
+
self.egraph.run_program(
|
|
183
|
+
bindings.Function(bindings.FunctionDecl(egg_name, bindings.Schema([], self.type_ref_to_egg(tp))))
|
|
184
|
+
)
|
|
185
|
+
case FunctionDecl():
|
|
186
|
+
if not decl.builtin:
|
|
187
|
+
egg_fn_decl = bindings.FunctionDecl(
|
|
188
|
+
egg_name,
|
|
189
|
+
bindings.Schema(
|
|
190
|
+
[self.type_ref_to_egg(a.to_just()) for a in decl.arg_types],
|
|
191
|
+
self.type_ref_to_egg(decl.semantic_return_type.to_just()),
|
|
192
|
+
),
|
|
193
|
+
self.expr_to_egg(decl.default) if decl.default else None,
|
|
194
|
+
self.expr_to_egg(decl.merge) if decl.merge else None,
|
|
195
|
+
[self.action_to_egg(a) for a in decl.on_merge],
|
|
196
|
+
decl.cost,
|
|
197
|
+
decl.unextractable,
|
|
198
|
+
)
|
|
199
|
+
self.egraph.run_program(bindings.Function(egg_fn_decl))
|
|
200
|
+
case _:
|
|
201
|
+
assert_never(decl)
|
|
202
|
+
return egg_name
|
|
203
|
+
|
|
204
|
+
def type_ref_to_egg(self, ref: JustTypeRef) -> str:
|
|
205
|
+
"""
|
|
206
|
+
Returns the egg sort name for a type reference, registering it if it is not already registered.
|
|
207
|
+
"""
|
|
208
|
+
try:
|
|
209
|
+
return self.type_ref_to_egg_sort[ref]
|
|
210
|
+
except KeyError:
|
|
211
|
+
pass
|
|
212
|
+
decl = self.__egg_decls__._classes[ref.name]
|
|
213
|
+
self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
|
|
214
|
+
if not decl.builtin or ref.args:
|
|
215
|
+
self.egraph.run_program(
|
|
216
|
+
bindings.Sort(
|
|
217
|
+
egg_name,
|
|
218
|
+
(
|
|
219
|
+
(
|
|
220
|
+
self.type_ref_to_egg(JustTypeRef(ref.name)),
|
|
221
|
+
[bindings.Var(self.type_ref_to_egg(a)) for a in ref.args],
|
|
222
|
+
)
|
|
223
|
+
if ref.args
|
|
224
|
+
else None
|
|
225
|
+
),
|
|
226
|
+
)
|
|
227
|
+
)
|
|
228
|
+
# For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
|
|
229
|
+
# these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
|
|
230
|
+
# even if you never use that function.
|
|
231
|
+
if decl.builtin:
|
|
232
|
+
for method in decl.class_methods:
|
|
233
|
+
self.callable_ref_to_egg(ClassMethodRef(ref.name, method))
|
|
234
|
+
|
|
235
|
+
return egg_name
|
|
236
|
+
|
|
237
|
+
def op_mapping(self) -> dict[str, str]:
|
|
238
|
+
"""
|
|
239
|
+
Create a mapping of egglog function name to Python function name, for use in the serialized format
|
|
240
|
+
for better visualization.
|
|
241
|
+
"""
|
|
242
|
+
return {
|
|
243
|
+
k: pretty_callable_ref(self.__egg_decls__, next(iter(v)))
|
|
244
|
+
for k, v in self.egg_fn_to_callable_refs.items()
|
|
245
|
+
if len(v) == 1
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
|
|
249
|
+
self.type_ref_to_egg(typed_expr_decl.tp)
|
|
250
|
+
return self.expr_to_egg(typed_expr_decl.expr)
|
|
251
|
+
|
|
252
|
+
@overload
|
|
253
|
+
def expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
|
|
254
|
+
|
|
255
|
+
@overload
|
|
256
|
+
def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
257
|
+
|
|
258
|
+
def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
|
|
259
|
+
"""
|
|
260
|
+
Convert an ExprDecl to an egg expression.
|
|
261
|
+
|
|
262
|
+
Cached using weakrefs to avoid memory leaks.
|
|
263
|
+
"""
|
|
264
|
+
try:
|
|
265
|
+
return self.expr_to_egg_cache[expr_decl]
|
|
266
|
+
except KeyError:
|
|
267
|
+
pass
|
|
268
|
+
|
|
269
|
+
res: bindings._Expr
|
|
270
|
+
match expr_decl:
|
|
271
|
+
case VarDecl(name):
|
|
272
|
+
res = bindings.Var(name)
|
|
273
|
+
case LitDecl(value):
|
|
274
|
+
l: bindings._Literal
|
|
275
|
+
match value:
|
|
276
|
+
case None:
|
|
277
|
+
l = bindings.Unit()
|
|
278
|
+
case bool(i):
|
|
279
|
+
l = bindings.Bool(i)
|
|
280
|
+
case int(i):
|
|
281
|
+
l = bindings.Int(i)
|
|
282
|
+
case float(f):
|
|
283
|
+
l = bindings.F64(f)
|
|
284
|
+
case str(s):
|
|
285
|
+
l = bindings.String(s)
|
|
286
|
+
case _:
|
|
287
|
+
assert_never(value)
|
|
288
|
+
res = bindings.Lit(l)
|
|
289
|
+
case CallDecl(ref, args, _):
|
|
290
|
+
egg_fn = self.callable_ref_to_egg(ref)
|
|
291
|
+
egg_args = [self.typed_expr_to_egg(a) for a in args]
|
|
292
|
+
res = bindings.Call(egg_fn, egg_args)
|
|
293
|
+
case PyObjectDecl(value):
|
|
294
|
+
res = GLOBAL_PY_OBJECT_SORT.store(value)
|
|
295
|
+
case _:
|
|
296
|
+
assert_never(expr_decl.expr)
|
|
297
|
+
|
|
298
|
+
self.expr_to_egg_cache[expr_decl] = res
|
|
299
|
+
return res
|
|
300
|
+
|
|
301
|
+
def exprs_from_egg(
|
|
302
|
+
self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef
|
|
303
|
+
) -> Iterable[TypedExprDecl]:
|
|
304
|
+
"""
|
|
305
|
+
Convert a list of egg terms of the given type into typed exprs.
|
|
306
|
+
"""
|
|
307
|
+
state = FromEggState(self, termdag)
|
|
308
|
+
return [state.from_expr(tp, term) for term in terms]
|
|
309
|
+
|
|
310
|
+
def _get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]:
|
|
311
|
+
"""
|
|
312
|
+
Given a class name, returns all possible registered types that it can be.
|
|
313
|
+
"""
|
|
314
|
+
return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _generate_type_egg_name(ref: JustTypeRef) -> str:
|
|
318
|
+
"""
|
|
319
|
+
Generates an egg sort name for this type reference by linearizing the type.
|
|
320
|
+
"""
|
|
321
|
+
name = ref.name
|
|
322
|
+
if not ref.args:
|
|
323
|
+
return name
|
|
324
|
+
return f"{name}_{'_'.join(map(_generate_type_egg_name, ref.args))}"
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def _generate_callable_egg_name(ref: CallableRef) -> str:
|
|
328
|
+
"""
|
|
329
|
+
Generates a valid egg function name for a callable reference.
|
|
330
|
+
"""
|
|
331
|
+
match ref:
|
|
332
|
+
case FunctionRef(name) | ConstantRef(name):
|
|
333
|
+
return name
|
|
334
|
+
case (
|
|
335
|
+
MethodRef(cls_name, name)
|
|
336
|
+
| ClassMethodRef(cls_name, name)
|
|
337
|
+
| ClassVariableRef(cls_name, name)
|
|
338
|
+
| PropertyRef(cls_name, name)
|
|
339
|
+
):
|
|
340
|
+
return f"{cls_name}_{name}"
|
|
341
|
+
case _:
|
|
342
|
+
assert_never(ref)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
@dataclass
|
|
346
|
+
class FromEggState:
|
|
347
|
+
"""
|
|
348
|
+
Dataclass containing state used when converting from an egg term to a typed expr.
|
|
349
|
+
"""
|
|
350
|
+
|
|
351
|
+
state: EGraphState
|
|
352
|
+
termdag: bindings.TermDag
|
|
353
|
+
# Cache of termdag ID to TypedExprDecl
|
|
354
|
+
cache: dict[int, TypedExprDecl] = field(default_factory=dict)
|
|
355
|
+
|
|
356
|
+
@property
|
|
357
|
+
def decls(self) -> Declarations:
|
|
358
|
+
return self.state.__egg_decls__
|
|
359
|
+
|
|
360
|
+
def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl:
|
|
361
|
+
"""
|
|
362
|
+
Convert an egg term to a typed expr.
|
|
363
|
+
"""
|
|
364
|
+
expr_decl: ExprDecl
|
|
365
|
+
if isinstance(term, bindings.TermVar):
|
|
366
|
+
expr_decl = VarDecl(term.name)
|
|
367
|
+
elif isinstance(term, bindings.TermLit):
|
|
368
|
+
value = term.value
|
|
369
|
+
expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
|
|
370
|
+
elif isinstance(term, bindings.TermApp):
|
|
371
|
+
if term.name == "py-object":
|
|
372
|
+
call = bindings.termdag_term_to_expr(self.termdag, term)
|
|
373
|
+
expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
|
|
374
|
+
else:
|
|
375
|
+
expr_decl = self.from_call(tp, term)
|
|
376
|
+
else:
|
|
377
|
+
assert_never(term)
|
|
378
|
+
return TypedExprDecl(tp, expr_decl)
|
|
379
|
+
|
|
380
|
+
def from_call(self, tp: JustTypeRef, term: bindings.TermApp) -> CallDecl:
|
|
381
|
+
"""
|
|
382
|
+
Convert a call to a CallDecl.
|
|
383
|
+
|
|
384
|
+
There could be multiple Python callable refs which match the call, so we need to find the correct one.
|
|
385
|
+
"""
|
|
386
|
+
# Find the first callable ref that matches the call
|
|
387
|
+
for callable_ref in self.state.egg_fn_to_callable_refs[term.name]:
|
|
388
|
+
# If this is a classmethod, we might need the type params that were bound for this type
|
|
389
|
+
# This could be multiple types if the classmethod is ambiguous, like map create.
|
|
390
|
+
possible_types: Iterable[JustTypeRef | None]
|
|
391
|
+
fn_decl = self.decls.get_callable_decl(callable_ref).to_function_decl()
|
|
392
|
+
if isinstance(callable_ref, ClassMethodRef):
|
|
393
|
+
possible_types = self.state._get_possible_types(callable_ref.class_name)
|
|
394
|
+
cls_name = callable_ref.class_name
|
|
395
|
+
else:
|
|
396
|
+
possible_types = [None]
|
|
397
|
+
cls_name = None
|
|
398
|
+
for possible_type in possible_types:
|
|
399
|
+
tcs = TypeConstraintSolver(self.decls)
|
|
400
|
+
if possible_type and possible_type.args:
|
|
401
|
+
tcs.bind_class(possible_type)
|
|
402
|
+
|
|
403
|
+
try:
|
|
404
|
+
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
405
|
+
fn_decl.arg_types, fn_decl.semantic_return_type, fn_decl.var_arg_type, tp, cls_name
|
|
406
|
+
)
|
|
407
|
+
except TypeConstraintError:
|
|
408
|
+
continue
|
|
409
|
+
args: list[TypedExprDecl] = []
|
|
410
|
+
for a, tp in zip(term.args, arg_types, strict=False):
|
|
411
|
+
try:
|
|
412
|
+
res = self.cache[a]
|
|
413
|
+
except KeyError:
|
|
414
|
+
res = self.cache[a] = self.from_expr(tp, self.termdag.nodes[a])
|
|
415
|
+
args.append(res)
|
|
416
|
+
return CallDecl(callable_ref, tuple(args), bound_tp_params)
|
|
417
|
+
raise ValueError(f"Could not find callable ref for call {term}")
|
egglog/exp/array_api.py
CHANGED
|
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
|
|
24
24
|
# Pretend that exprs are numbers b/c sklearn does isinstance checks
|
|
25
25
|
numbers.Integral.register(RuntimeExpr)
|
|
26
26
|
|
|
27
|
-
array_api_ruleset = ruleset()
|
|
27
|
+
array_api_ruleset = ruleset(name="array_api_ruleset")
|
|
28
28
|
array_api_schedule = array_api_ruleset.saturate()
|
|
29
29
|
|
|
30
30
|
|
|
@@ -36,10 +36,14 @@ class Boolean(Expr):
|
|
|
36
36
|
@property
|
|
37
37
|
def bool(self) -> Bool: ...
|
|
38
38
|
|
|
39
|
-
def __or__(self, other:
|
|
39
|
+
def __or__(self, other: BooleanLike) -> Boolean: ...
|
|
40
40
|
|
|
41
|
-
def __and__(self, other:
|
|
41
|
+
def __and__(self, other: BooleanLike) -> Boolean: ...
|
|
42
42
|
|
|
43
|
+
def if_int(self, true_value: Int, false_value: Int) -> Int: ...
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
BooleanLike = Boolean | bool
|
|
43
47
|
|
|
44
48
|
TRUE = constant("TRUE", Boolean)
|
|
45
49
|
FALSE = constant("FALSE", Boolean)
|
|
@@ -47,7 +51,7 @@ converter(bool, Boolean, lambda x: TRUE if x else FALSE)
|
|
|
47
51
|
|
|
48
52
|
|
|
49
53
|
@array_api_ruleset.register
|
|
50
|
-
def _bool(x: Boolean):
|
|
54
|
+
def _bool(x: Boolean, i: Int, j: Int):
|
|
51
55
|
return [
|
|
52
56
|
rule(eq(x).to(TRUE)).then(set_(x.bool).to(Bool(True))),
|
|
53
57
|
rule(eq(x).to(FALSE)).then(set_(x.bool).to(Bool(False))),
|
|
@@ -55,82 +59,8 @@ def _bool(x: Boolean):
|
|
|
55
59
|
rewrite(FALSE | x).to(x),
|
|
56
60
|
rewrite(TRUE & x).to(x),
|
|
57
61
|
rewrite(FALSE & x).to(FALSE),
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
class DType(Expr):
|
|
62
|
-
float64: ClassVar[DType]
|
|
63
|
-
float32: ClassVar[DType]
|
|
64
|
-
int64: ClassVar[DType]
|
|
65
|
-
int32: ClassVar[DType]
|
|
66
|
-
object: ClassVar[DType]
|
|
67
|
-
bool: ClassVar[DType]
|
|
68
|
-
|
|
69
|
-
def __eq__(self, other: DType) -> Boolean: # type: ignore[override]
|
|
70
|
-
...
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
float64 = DType.float64
|
|
74
|
-
float32 = DType.float32
|
|
75
|
-
int32 = DType.int32
|
|
76
|
-
int64 = DType.int64
|
|
77
|
-
|
|
78
|
-
_DTYPES = [float64, float32, int32, int64, DType.object]
|
|
79
|
-
|
|
80
|
-
converter(type, DType, lambda x: convert(np.dtype(x), DType))
|
|
81
|
-
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload]
|
|
82
|
-
array_api_ruleset.register(
|
|
83
|
-
*(rewrite(l == r).to(TRUE if l is r else FALSE) for l, r in itertools.product(_DTYPES, repeat=2))
|
|
84
|
-
)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
class IsDtypeKind(Expr):
|
|
88
|
-
NULL: ClassVar[IsDtypeKind]
|
|
89
|
-
|
|
90
|
-
@classmethod
|
|
91
|
-
def string(cls, s: StringLike) -> IsDtypeKind: ...
|
|
92
|
-
|
|
93
|
-
@classmethod
|
|
94
|
-
def dtype(cls, d: DType) -> IsDtypeKind: ...
|
|
95
|
-
|
|
96
|
-
@method(cost=10)
|
|
97
|
-
def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
# TODO: Make kind more generic to support tuples.
|
|
101
|
-
@function
|
|
102
|
-
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
|
|
106
|
-
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
|
|
107
|
-
converter(
|
|
108
|
-
tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
|
|
109
|
-
)
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
@array_api_ruleset.register
|
|
113
|
-
def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
|
|
114
|
-
return [
|
|
115
|
-
rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE),
|
|
116
|
-
rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE),
|
|
117
|
-
rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE),
|
|
118
|
-
rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE),
|
|
119
|
-
rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE),
|
|
120
|
-
rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE),
|
|
121
|
-
rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE),
|
|
122
|
-
rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
123
|
-
rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
124
|
-
rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
125
|
-
rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
126
|
-
rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
127
|
-
rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
128
|
-
rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
129
|
-
rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
130
|
-
rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE),
|
|
131
|
-
rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE),
|
|
132
|
-
rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)),
|
|
133
|
-
rewrite(k1 | IsDtypeKind.NULL).to(k1),
|
|
62
|
+
rewrite(TRUE.if_int(i, j)).to(i),
|
|
63
|
+
rewrite(FALSE.if_int(i, j)).to(j),
|
|
134
64
|
]
|
|
135
65
|
|
|
136
66
|
|
|
@@ -264,10 +194,13 @@ converter(int, Int, lambda x: Int(x))
|
|
|
264
194
|
|
|
265
195
|
|
|
266
196
|
class Float(Expr):
|
|
197
|
+
# Differentiate costs of three constructors so extraction is deterministic if all three are present
|
|
198
|
+
@method(cost=3)
|
|
267
199
|
def __init__(self, value: f64Like) -> None: ...
|
|
268
200
|
|
|
269
201
|
def abs(self) -> Float: ...
|
|
270
202
|
|
|
203
|
+
@method(cost=2)
|
|
271
204
|
@classmethod
|
|
272
205
|
def rational(cls, r: Rational) -> Float: ...
|
|
273
206
|
|
|
@@ -366,6 +299,85 @@ converter(type(None), OptionalInt, lambda _: OptionalInt.none)
|
|
|
366
299
|
converter(Int, OptionalInt, OptionalInt.some)
|
|
367
300
|
|
|
368
301
|
|
|
302
|
+
class DType(Expr):
|
|
303
|
+
float64: ClassVar[DType]
|
|
304
|
+
float32: ClassVar[DType]
|
|
305
|
+
int64: ClassVar[DType]
|
|
306
|
+
int32: ClassVar[DType]
|
|
307
|
+
object: ClassVar[DType]
|
|
308
|
+
bool: ClassVar[DType]
|
|
309
|
+
|
|
310
|
+
def __eq__(self, other: DType) -> Boolean: # type: ignore[override]
|
|
311
|
+
...
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
float64 = DType.float64
|
|
315
|
+
float32 = DType.float32
|
|
316
|
+
int32 = DType.int32
|
|
317
|
+
int64 = DType.int64
|
|
318
|
+
|
|
319
|
+
_DTYPES = [float64, float32, int32, int64, DType.object]
|
|
320
|
+
|
|
321
|
+
converter(type, DType, lambda x: convert(np.dtype(x), DType))
|
|
322
|
+
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload]
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
@array_api_ruleset.register
|
|
326
|
+
def _():
|
|
327
|
+
for l, r in itertools.product(_DTYPES, repeat=2):
|
|
328
|
+
yield rewrite(l == r).to(TRUE if l is r else FALSE)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
class IsDtypeKind(Expr):
|
|
332
|
+
NULL: ClassVar[IsDtypeKind]
|
|
333
|
+
|
|
334
|
+
@classmethod
|
|
335
|
+
def string(cls, s: StringLike) -> IsDtypeKind: ...
|
|
336
|
+
|
|
337
|
+
@classmethod
|
|
338
|
+
def dtype(cls, d: DType) -> IsDtypeKind: ...
|
|
339
|
+
|
|
340
|
+
@method(cost=10)
|
|
341
|
+
def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# TODO: Make kind more generic to support tuples.
|
|
345
|
+
@function
|
|
346
|
+
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
|
|
350
|
+
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
|
|
351
|
+
converter(
|
|
352
|
+
tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@array_api_ruleset.register
|
|
357
|
+
def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
|
|
358
|
+
return [
|
|
359
|
+
rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE),
|
|
360
|
+
rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE),
|
|
361
|
+
rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE),
|
|
362
|
+
rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE),
|
|
363
|
+
rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE),
|
|
364
|
+
rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE),
|
|
365
|
+
rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE),
|
|
366
|
+
rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
367
|
+
rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
368
|
+
rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE),
|
|
369
|
+
rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
370
|
+
rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
371
|
+
rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
372
|
+
rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
373
|
+
rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE),
|
|
374
|
+
rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE),
|
|
375
|
+
rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE),
|
|
376
|
+
rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)),
|
|
377
|
+
rewrite(k1 | IsDtypeKind.NULL).to(k1),
|
|
378
|
+
]
|
|
379
|
+
|
|
380
|
+
|
|
369
381
|
class Slice(Expr):
|
|
370
382
|
def __init__(
|
|
371
383
|
self,
|
egglog/exp/array_api_numba.py
CHANGED
|
@@ -31,7 +31,12 @@ def _std(y: NDArray, x: NDArray, i: Int):
|
|
|
31
31
|
axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
|
|
32
32
|
# https://numpy.org/doc/stable/reference/generated/numpy.std.html
|
|
33
33
|
# "std = sqrt(mean(x)), where x = abs(a - a.mean())**2."
|
|
34
|
-
yield rewrite(
|
|
34
|
+
yield rewrite(
|
|
35
|
+
std(x, axis),
|
|
36
|
+
subsume=True,
|
|
37
|
+
).to(
|
|
38
|
+
sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)),
|
|
39
|
+
)
|
|
35
40
|
|
|
36
41
|
|
|
37
42
|
# rewrite unique_counts to count each value one by one, since numba doesn't support np.unique(..., return_counts=True)
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import egglog
|
|
4
|
+
|
|
5
|
+
from .array_api import Int
|
|
6
|
+
|
|
7
|
+
# https://github.com/sklam/pyasir/blob/c363ff4f8f91177700ad4108dd5042b9b97d8289/pyasir/tests/test_fib.py
|
|
8
|
+
|
|
9
|
+
# In progress - should be able to re-create this
|
|
10
|
+
# @df.func
|
|
11
|
+
# def fib_ir(n: pyasir.Int64) -> pyasir.Int64:
|
|
12
|
+
# @df.switch(n <= 1)
|
|
13
|
+
# def swt(n):
|
|
14
|
+
# @df.case(1)
|
|
15
|
+
# def case0(n):
|
|
16
|
+
# return 1
|
|
17
|
+
|
|
18
|
+
# @df.case(0)
|
|
19
|
+
# def case1(n):
|
|
20
|
+
# return fib_ir(n - 1) + fib_ir(n - 2)
|
|
21
|
+
|
|
22
|
+
# yield case0
|
|
23
|
+
# yield case1
|
|
24
|
+
|
|
25
|
+
# r = swt(n)
|
|
26
|
+
# return r
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# With something like this:
|
|
30
|
+
@egglog.function
|
|
31
|
+
def fib(n: Int) -> Int:
|
|
32
|
+
return (n <= Int(1)).if_int(
|
|
33
|
+
Int(1),
|
|
34
|
+
fib(n - Int(1)) + fib(n - Int(2)),
|
|
35
|
+
)
|