egglog 7.1.0__cp312-none-win_amd64.whl → 8.0.0__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +146 -8
- egglog/egraph.py +337 -203
- egglog/egraph_state.py +171 -64
- egglog/examples/higher_order_functions.py +45 -0
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +97 -43
- egglog/runtime.py +60 -44
- egglog/thunk.py +44 -20
- egglog/type_constraint_solver.py +5 -4
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/METADATA +31 -30
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.1.0.dist-info/RECORD +0 -39
- {egglog-7.1.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
egglog/declarations.py
CHANGED
|
@@ -13,10 +13,11 @@ from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_c
|
|
|
13
13
|
from typing_extensions import Self, assert_never
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
|
-
from collections.abc import Callable, Iterable
|
|
16
|
+
from collections.abc import Callable, Iterable, Mapping
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
__all__ = [
|
|
20
|
+
"replace_typed_expr",
|
|
20
21
|
"Declarations",
|
|
21
22
|
"DeclerationsLike",
|
|
22
23
|
"DelayedDeclerations",
|
|
@@ -29,6 +30,7 @@ __all__ = [
|
|
|
29
30
|
"MethodRef",
|
|
30
31
|
"ClassMethodRef",
|
|
31
32
|
"FunctionRef",
|
|
33
|
+
"UnnamedFunctionRef",
|
|
32
34
|
"ConstantRef",
|
|
33
35
|
"ClassVariableRef",
|
|
34
36
|
"PropertyRef",
|
|
@@ -71,6 +73,9 @@ __all__ = [
|
|
|
71
73
|
"CommandDecl",
|
|
72
74
|
"SpecialFunctions",
|
|
73
75
|
"FunctionSignature",
|
|
76
|
+
"DefaultRewriteDecl",
|
|
77
|
+
"InitRef",
|
|
78
|
+
"HasDeclerations",
|
|
74
79
|
]
|
|
75
80
|
|
|
76
81
|
|
|
@@ -80,7 +85,14 @@ class DelayedDeclerations:
|
|
|
80
85
|
|
|
81
86
|
@property
|
|
82
87
|
def __egg_decls__(self) -> Declarations:
|
|
83
|
-
|
|
88
|
+
thunk = self.__egg_decls_thunk__
|
|
89
|
+
try:
|
|
90
|
+
return thunk()
|
|
91
|
+
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
|
|
92
|
+
# instead raise explicitly
|
|
93
|
+
except AttributeError as err:
|
|
94
|
+
msg = f"Cannot resolve declerations for {self}"
|
|
95
|
+
raise RuntimeError(msg) from err
|
|
84
96
|
|
|
85
97
|
|
|
86
98
|
@runtime_checkable
|
|
@@ -108,11 +120,18 @@ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[D
|
|
|
108
120
|
|
|
109
121
|
@dataclass
|
|
110
122
|
class Declarations:
|
|
123
|
+
_unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set)
|
|
111
124
|
_functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
|
|
112
125
|
_constants: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
113
126
|
_classes: dict[str, ClassDecl] = field(default_factory=dict)
|
|
114
127
|
_rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
|
|
115
128
|
|
|
129
|
+
@property
|
|
130
|
+
def default_ruleset(self) -> RulesetDecl:
|
|
131
|
+
ruleset = self._rulesets[""]
|
|
132
|
+
assert isinstance(ruleset, RulesetDecl)
|
|
133
|
+
return ruleset
|
|
134
|
+
|
|
116
135
|
@classmethod
|
|
117
136
|
def create(cls, *others: DeclerationsLike) -> Declarations:
|
|
118
137
|
others = upcast_declerations(others)
|
|
@@ -127,7 +146,7 @@ class Declarations:
|
|
|
127
146
|
|
|
128
147
|
def copy(self) -> Declarations:
|
|
129
148
|
new = Declarations()
|
|
130
|
-
new
|
|
149
|
+
self.update_other(new)
|
|
131
150
|
return new
|
|
132
151
|
|
|
133
152
|
def update(self, *others: DeclerationsLike) -> None:
|
|
@@ -154,9 +173,13 @@ class Declarations:
|
|
|
154
173
|
other._functions |= self._functions
|
|
155
174
|
other._classes |= self._classes
|
|
156
175
|
other._constants |= self._constants
|
|
176
|
+
# Must combine rulesets bc the empty ruleset might be different, bc DefaultRewriteDecl
|
|
177
|
+
# is added to functions.
|
|
178
|
+
combined_default_rules: set[RewriteOrRuleDecl] = {*self.default_ruleset.rules, *other.default_ruleset.rules}
|
|
157
179
|
other._rulesets |= self._rulesets
|
|
180
|
+
other._rulesets[""] = RulesetDecl(list(combined_default_rules))
|
|
158
181
|
|
|
159
|
-
def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
|
|
182
|
+
def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
|
|
160
183
|
match ref:
|
|
161
184
|
case FunctionRef(name):
|
|
162
185
|
return self._functions[name]
|
|
@@ -170,8 +193,31 @@ class Declarations:
|
|
|
170
193
|
return self._classes[class_name].class_methods[name]
|
|
171
194
|
case PropertyRef(class_name, property_name):
|
|
172
195
|
return self._classes[class_name].properties[property_name]
|
|
196
|
+
case InitRef(class_name):
|
|
197
|
+
init_fn = self._classes[class_name].init
|
|
198
|
+
assert init_fn
|
|
199
|
+
return init_fn
|
|
200
|
+
case UnnamedFunctionRef():
|
|
201
|
+
return ref.to_function_decl()
|
|
173
202
|
assert_never(ref)
|
|
174
203
|
|
|
204
|
+
def set_function_decl(
|
|
205
|
+
self, ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef, decl: FunctionDecl
|
|
206
|
+
) -> None:
|
|
207
|
+
match ref:
|
|
208
|
+
case FunctionRef(name):
|
|
209
|
+
self._functions[name] = decl
|
|
210
|
+
case MethodRef(class_name, method_name):
|
|
211
|
+
self._classes[class_name].methods[method_name] = decl
|
|
212
|
+
case ClassMethodRef(class_name, name):
|
|
213
|
+
self._classes[class_name].class_methods[name] = decl
|
|
214
|
+
case PropertyRef(class_name, property_name):
|
|
215
|
+
self._classes[class_name].properties[property_name] = decl
|
|
216
|
+
case InitRef(class_name):
|
|
217
|
+
self._classes[class_name].init = decl
|
|
218
|
+
case _:
|
|
219
|
+
assert_never(ref)
|
|
220
|
+
|
|
175
221
|
def has_method(self, class_name: str, method_name: str) -> bool | None:
|
|
176
222
|
"""
|
|
177
223
|
Returns whether the given class has the given method, or None if we cant find the class.
|
|
@@ -183,12 +229,20 @@ class Declarations:
|
|
|
183
229
|
def get_class_decl(self, name: str) -> ClassDecl:
|
|
184
230
|
return self._classes[name]
|
|
185
231
|
|
|
232
|
+
def get_paramaterized_class(self, name: str) -> TypeRefWithVars:
|
|
233
|
+
"""
|
|
234
|
+
Returns a class reference with type parameters, if the class is paramaterized.
|
|
235
|
+
"""
|
|
236
|
+
type_vars = self._classes[name].type_vars
|
|
237
|
+
return TypeRefWithVars(name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
238
|
+
|
|
186
239
|
|
|
187
240
|
@dataclass
|
|
188
241
|
class ClassDecl:
|
|
189
242
|
egg_name: str | None = None
|
|
190
243
|
type_vars: tuple[str, ...] = ()
|
|
191
244
|
builtin: bool = False
|
|
245
|
+
init: FunctionDecl | None = None
|
|
192
246
|
class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
193
247
|
# These have to be seperate from class_methods so that printing them can be done easily
|
|
194
248
|
class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
@@ -271,6 +325,37 @@ TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
|
|
|
271
325
|
##
|
|
272
326
|
|
|
273
327
|
|
|
328
|
+
@dataclass(frozen=True)
|
|
329
|
+
class UnnamedFunctionRef:
|
|
330
|
+
"""
|
|
331
|
+
A reference to a function that doesn't have a name, but does have a body.
|
|
332
|
+
"""
|
|
333
|
+
|
|
334
|
+
# tuple of var arg names and their types
|
|
335
|
+
args: tuple[TypedExprDecl, ...]
|
|
336
|
+
res: TypedExprDecl
|
|
337
|
+
|
|
338
|
+
def to_function_decl(self) -> FunctionDecl:
|
|
339
|
+
arg_types = []
|
|
340
|
+
arg_names = []
|
|
341
|
+
for a in self.args:
|
|
342
|
+
arg_types.append(a.tp.to_var())
|
|
343
|
+
assert isinstance(a.expr, VarDecl)
|
|
344
|
+
arg_names.append(a.expr.name)
|
|
345
|
+
return FunctionDecl(
|
|
346
|
+
FunctionSignature(
|
|
347
|
+
arg_types=tuple(arg_types),
|
|
348
|
+
arg_names=tuple(arg_names),
|
|
349
|
+
arg_defaults=(None,) * len(self.args),
|
|
350
|
+
return_type=self.res.tp.to_var(),
|
|
351
|
+
),
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
@property
|
|
355
|
+
def egg_name(self) -> None | str:
|
|
356
|
+
return None
|
|
357
|
+
|
|
358
|
+
|
|
274
359
|
@dataclass(frozen=True)
|
|
275
360
|
class FunctionRef:
|
|
276
361
|
name: str
|
|
@@ -293,6 +378,11 @@ class ClassMethodRef:
|
|
|
293
378
|
method_name: str
|
|
294
379
|
|
|
295
380
|
|
|
381
|
+
@dataclass(frozen=True)
|
|
382
|
+
class InitRef:
|
|
383
|
+
class_name: str
|
|
384
|
+
|
|
385
|
+
|
|
296
386
|
@dataclass(frozen=True)
|
|
297
387
|
class ClassVariableRef:
|
|
298
388
|
class_name: str
|
|
@@ -305,7 +395,16 @@ class PropertyRef:
|
|
|
305
395
|
property_name: str
|
|
306
396
|
|
|
307
397
|
|
|
308
|
-
CallableRef: TypeAlias =
|
|
398
|
+
CallableRef: TypeAlias = (
|
|
399
|
+
FunctionRef
|
|
400
|
+
| ConstantRef
|
|
401
|
+
| MethodRef
|
|
402
|
+
| ClassMethodRef
|
|
403
|
+
| InitRef
|
|
404
|
+
| ClassVariableRef
|
|
405
|
+
| PropertyRef
|
|
406
|
+
| UnnamedFunctionRef
|
|
407
|
+
)
|
|
309
408
|
|
|
310
409
|
|
|
311
410
|
##
|
|
@@ -378,7 +477,6 @@ class FunctionSignature:
|
|
|
378
477
|
@dataclass(frozen=True)
|
|
379
478
|
class FunctionDecl:
|
|
380
479
|
signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
|
|
381
|
-
|
|
382
480
|
# Egg params
|
|
383
481
|
builtin: bool = False
|
|
384
482
|
egg_name: str | None = None
|
|
@@ -402,6 +500,8 @@ CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
|
|
|
402
500
|
@dataclass(frozen=True)
|
|
403
501
|
class VarDecl:
|
|
404
502
|
name: str
|
|
503
|
+
# Differentiate between let bound vars and vars created in rules so that they won't shadow in egglog, by adding a prefix
|
|
504
|
+
is_let: bool
|
|
405
505
|
|
|
406
506
|
|
|
407
507
|
@dataclass(frozen=True)
|
|
@@ -458,7 +558,7 @@ class CallDecl:
|
|
|
458
558
|
bound_tp_params: tuple[JustTypeRef, ...] | None = None
|
|
459
559
|
|
|
460
560
|
def __post_init__(self) -> None:
|
|
461
|
-
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
|
|
561
|
+
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
|
|
462
562
|
msg = "Cannot bind type parameters to a non-class method callable."
|
|
463
563
|
raise ValueError(msg)
|
|
464
564
|
|
|
@@ -508,6 +608,38 @@ class TypedExprDecl:
|
|
|
508
608
|
return l
|
|
509
609
|
|
|
510
610
|
|
|
611
|
+
def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl:
|
|
612
|
+
"""
|
|
613
|
+
Replace all the typed expressions in the given typed expression with the replacements.
|
|
614
|
+
"""
|
|
615
|
+
# keep track of the traversed expressions for memoization
|
|
616
|
+
traversed: dict[TypedExprDecl, TypedExprDecl] = {}
|
|
617
|
+
|
|
618
|
+
def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl:
|
|
619
|
+
if typed_expr in traversed:
|
|
620
|
+
return traversed[typed_expr]
|
|
621
|
+
if typed_expr in replacements:
|
|
622
|
+
res = replacements[typed_expr]
|
|
623
|
+
else:
|
|
624
|
+
match typed_expr.expr:
|
|
625
|
+
case (
|
|
626
|
+
CallDecl(callable, args, bound_tp_params)
|
|
627
|
+
| PartialCallDecl(CallDecl(callable, args, bound_tp_params))
|
|
628
|
+
):
|
|
629
|
+
new_args = tuple(_inner(a) for a in args)
|
|
630
|
+
call_decl = CallDecl(callable, new_args, bound_tp_params)
|
|
631
|
+
res = TypedExprDecl(
|
|
632
|
+
typed_expr.tp,
|
|
633
|
+
call_decl if isinstance(typed_expr.expr, CallDecl) else PartialCallDecl(call_decl),
|
|
634
|
+
)
|
|
635
|
+
case _:
|
|
636
|
+
res = typed_expr
|
|
637
|
+
traversed[typed_expr] = res
|
|
638
|
+
return res
|
|
639
|
+
|
|
640
|
+
return _inner(typed_expr)
|
|
641
|
+
|
|
642
|
+
|
|
511
643
|
##
|
|
512
644
|
# Schedules
|
|
513
645
|
##
|
|
@@ -629,7 +761,13 @@ class RuleDecl:
|
|
|
629
761
|
name: str | None
|
|
630
762
|
|
|
631
763
|
|
|
632
|
-
|
|
764
|
+
@dataclass(frozen=True)
|
|
765
|
+
class DefaultRewriteDecl:
|
|
766
|
+
ref: CallableRef
|
|
767
|
+
expr: ExprDecl
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl
|
|
633
771
|
|
|
634
772
|
|
|
635
773
|
@dataclass(frozen=True)
|