egglog 7.2.0__cp312-none-win_amd64.whl → 8.0.1__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +107 -53
- egglog/builtins.py +49 -6
- egglog/conversion.py +32 -9
- egglog/declarations.py +82 -4
- egglog/egraph.py +260 -179
- egglog/egraph_state.py +149 -66
- egglog/examples/higher_order_functions.py +4 -9
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +4 -8
- egglog/exp/array_api_loopnest.py +149 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +62 -25
- egglog/exp/program_gen.py +23 -17
- egglog/functionalize.py +91 -0
- egglog/ipython_magic.py +1 -1
- egglog/pretty.py +88 -44
- egglog/runtime.py +53 -40
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35774 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/METADATA +33 -32
- egglog-8.0.1.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.1.dist-info/licenses}/LICENSE +0 -0
egglog/declarations.py
CHANGED
|
@@ -13,10 +13,11 @@ from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_c
|
|
|
13
13
|
from typing_extensions import Self, assert_never
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
|
-
from collections.abc import Callable, Iterable
|
|
16
|
+
from collections.abc import Callable, Iterable, Mapping
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
__all__ = [
|
|
20
|
+
"replace_typed_expr",
|
|
20
21
|
"Declarations",
|
|
21
22
|
"DeclerationsLike",
|
|
22
23
|
"DelayedDeclerations",
|
|
@@ -29,6 +30,7 @@ __all__ = [
|
|
|
29
30
|
"MethodRef",
|
|
30
31
|
"ClassMethodRef",
|
|
31
32
|
"FunctionRef",
|
|
33
|
+
"UnnamedFunctionRef",
|
|
32
34
|
"ConstantRef",
|
|
33
35
|
"ClassVariableRef",
|
|
34
36
|
"PropertyRef",
|
|
@@ -73,6 +75,7 @@ __all__ = [
|
|
|
73
75
|
"FunctionSignature",
|
|
74
76
|
"DefaultRewriteDecl",
|
|
75
77
|
"InitRef",
|
|
78
|
+
"HasDeclerations",
|
|
76
79
|
]
|
|
77
80
|
|
|
78
81
|
|
|
@@ -82,12 +85,13 @@ class DelayedDeclerations:
|
|
|
82
85
|
|
|
83
86
|
@property
|
|
84
87
|
def __egg_decls__(self) -> Declarations:
|
|
88
|
+
thunk = self.__egg_decls_thunk__
|
|
85
89
|
try:
|
|
86
|
-
return
|
|
90
|
+
return thunk()
|
|
87
91
|
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
|
|
88
92
|
# instead raise explicitly
|
|
89
93
|
except AttributeError as err:
|
|
90
|
-
msg = "
|
|
94
|
+
msg = f"Cannot resolve declerations for {self}"
|
|
91
95
|
raise RuntimeError(msg) from err
|
|
92
96
|
|
|
93
97
|
|
|
@@ -116,6 +120,7 @@ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[D
|
|
|
116
120
|
|
|
117
121
|
@dataclass
|
|
118
122
|
class Declarations:
|
|
123
|
+
_unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set)
|
|
119
124
|
_functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
|
|
120
125
|
_constants: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
121
126
|
_classes: dict[str, ClassDecl] = field(default_factory=dict)
|
|
@@ -192,6 +197,8 @@ class Declarations:
|
|
|
192
197
|
init_fn = self._classes[class_name].init
|
|
193
198
|
assert init_fn
|
|
194
199
|
return init_fn
|
|
200
|
+
case UnnamedFunctionRef():
|
|
201
|
+
return ref.to_function_decl()
|
|
195
202
|
assert_never(ref)
|
|
196
203
|
|
|
197
204
|
def set_function_decl(
|
|
@@ -318,6 +325,37 @@ TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
|
|
|
318
325
|
##
|
|
319
326
|
|
|
320
327
|
|
|
328
|
+
@dataclass(frozen=True)
|
|
329
|
+
class UnnamedFunctionRef:
|
|
330
|
+
"""
|
|
331
|
+
A reference to a function that doesn't have a name, but does have a body.
|
|
332
|
+
"""
|
|
333
|
+
|
|
334
|
+
# tuple of var arg names and their types
|
|
335
|
+
args: tuple[TypedExprDecl, ...]
|
|
336
|
+
res: TypedExprDecl
|
|
337
|
+
|
|
338
|
+
def to_function_decl(self) -> FunctionDecl:
|
|
339
|
+
arg_types = []
|
|
340
|
+
arg_names = []
|
|
341
|
+
for a in self.args:
|
|
342
|
+
arg_types.append(a.tp.to_var())
|
|
343
|
+
assert isinstance(a.expr, VarDecl)
|
|
344
|
+
arg_names.append(a.expr.name)
|
|
345
|
+
return FunctionDecl(
|
|
346
|
+
FunctionSignature(
|
|
347
|
+
arg_types=tuple(arg_types),
|
|
348
|
+
arg_names=tuple(arg_names),
|
|
349
|
+
arg_defaults=(None,) * len(self.args),
|
|
350
|
+
return_type=self.res.tp.to_var(),
|
|
351
|
+
),
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
@property
|
|
355
|
+
def egg_name(self) -> None | str:
|
|
356
|
+
return None
|
|
357
|
+
|
|
358
|
+
|
|
321
359
|
@dataclass(frozen=True)
|
|
322
360
|
class FunctionRef:
|
|
323
361
|
name: str
|
|
@@ -358,7 +396,14 @@ class PropertyRef:
|
|
|
358
396
|
|
|
359
397
|
|
|
360
398
|
CallableRef: TypeAlias = (
|
|
361
|
-
FunctionRef
|
|
399
|
+
FunctionRef
|
|
400
|
+
| ConstantRef
|
|
401
|
+
| MethodRef
|
|
402
|
+
| ClassMethodRef
|
|
403
|
+
| InitRef
|
|
404
|
+
| ClassVariableRef
|
|
405
|
+
| PropertyRef
|
|
406
|
+
| UnnamedFunctionRef
|
|
362
407
|
)
|
|
363
408
|
|
|
364
409
|
|
|
@@ -455,6 +500,8 @@ CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
|
|
|
455
500
|
@dataclass(frozen=True)
|
|
456
501
|
class VarDecl:
|
|
457
502
|
name: str
|
|
503
|
+
# Differentiate between let bound vars and vars created in rules so that they won't shadow in egglog, by adding a prefix
|
|
504
|
+
is_let: bool
|
|
458
505
|
|
|
459
506
|
|
|
460
507
|
@dataclass(frozen=True)
|
|
@@ -561,6 +608,37 @@ class TypedExprDecl:
|
|
|
561
608
|
return l
|
|
562
609
|
|
|
563
610
|
|
|
611
|
+
def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl:
|
|
612
|
+
"""
|
|
613
|
+
Replace all the typed expressions in the given typed expression with the replacements.
|
|
614
|
+
"""
|
|
615
|
+
# keep track of the traversed expressions for memoization
|
|
616
|
+
traversed: dict[TypedExprDecl, TypedExprDecl] = {}
|
|
617
|
+
|
|
618
|
+
def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl:
|
|
619
|
+
if typed_expr in traversed:
|
|
620
|
+
return traversed[typed_expr]
|
|
621
|
+
if typed_expr in replacements:
|
|
622
|
+
res = replacements[typed_expr]
|
|
623
|
+
else:
|
|
624
|
+
match typed_expr.expr:
|
|
625
|
+
case CallDecl(callable, args, bound_tp_params) | PartialCallDecl(
|
|
626
|
+
CallDecl(callable, args, bound_tp_params)
|
|
627
|
+
):
|
|
628
|
+
new_args = tuple(_inner(a) for a in args)
|
|
629
|
+
call_decl = CallDecl(callable, new_args, bound_tp_params)
|
|
630
|
+
res = TypedExprDecl(
|
|
631
|
+
typed_expr.tp,
|
|
632
|
+
call_decl if isinstance(typed_expr.expr, CallDecl) else PartialCallDecl(call_decl),
|
|
633
|
+
)
|
|
634
|
+
case _:
|
|
635
|
+
res = typed_expr
|
|
636
|
+
traversed[typed_expr] = res
|
|
637
|
+
return res
|
|
638
|
+
|
|
639
|
+
return _inner(typed_expr)
|
|
640
|
+
|
|
641
|
+
|
|
564
642
|
##
|
|
565
643
|
# Schedules
|
|
566
644
|
##
|