egglog 7.1.0__cp311-none-win_amd64.whl → 7.2.0__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/declarations.py +66 -7
- egglog/egraph.py +141 -75
- egglog/egraph_state.py +51 -20
- egglog/examples/higher_order_functions.py +50 -0
- egglog/exp/array_api.py +4 -1
- egglog/pretty.py +15 -5
- egglog/runtime.py +9 -6
- egglog/thunk.py +16 -4
- egglog/type_constraint_solver.py +5 -4
- {egglog-7.1.0.dist-info → egglog-7.2.0.dist-info}/METADATA +19 -19
- {egglog-7.1.0.dist-info → egglog-7.2.0.dist-info}/RECORD +14 -13
- {egglog-7.1.0.dist-info → egglog-7.2.0.dist-info}/WHEEL +0 -0
- {egglog-7.1.0.dist-info → egglog-7.2.0.dist-info}/license_files/LICENSE +0 -0
|
Binary file
|
egglog/declarations.py
CHANGED
|
@@ -71,6 +71,8 @@ __all__ = [
|
|
|
71
71
|
"CommandDecl",
|
|
72
72
|
"SpecialFunctions",
|
|
73
73
|
"FunctionSignature",
|
|
74
|
+
"DefaultRewriteDecl",
|
|
75
|
+
"InitRef",
|
|
74
76
|
]
|
|
75
77
|
|
|
76
78
|
|
|
@@ -80,7 +82,13 @@ class DelayedDeclerations:
|
|
|
80
82
|
|
|
81
83
|
@property
|
|
82
84
|
def __egg_decls__(self) -> Declarations:
|
|
83
|
-
|
|
85
|
+
try:
|
|
86
|
+
return self.__egg_decls_thunk__()
|
|
87
|
+
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
|
|
88
|
+
# instead raise explicitly
|
|
89
|
+
except AttributeError as err:
|
|
90
|
+
msg = "Error resolving declerations"
|
|
91
|
+
raise RuntimeError(msg) from err
|
|
84
92
|
|
|
85
93
|
|
|
86
94
|
@runtime_checkable
|
|
@@ -113,6 +121,12 @@ class Declarations:
|
|
|
113
121
|
_classes: dict[str, ClassDecl] = field(default_factory=dict)
|
|
114
122
|
_rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
|
|
115
123
|
|
|
124
|
+
@property
|
|
125
|
+
def default_ruleset(self) -> RulesetDecl:
|
|
126
|
+
ruleset = self._rulesets[""]
|
|
127
|
+
assert isinstance(ruleset, RulesetDecl)
|
|
128
|
+
return ruleset
|
|
129
|
+
|
|
116
130
|
@classmethod
|
|
117
131
|
def create(cls, *others: DeclerationsLike) -> Declarations:
|
|
118
132
|
others = upcast_declerations(others)
|
|
@@ -127,7 +141,7 @@ class Declarations:
|
|
|
127
141
|
|
|
128
142
|
def copy(self) -> Declarations:
|
|
129
143
|
new = Declarations()
|
|
130
|
-
new
|
|
144
|
+
self.update_other(new)
|
|
131
145
|
return new
|
|
132
146
|
|
|
133
147
|
def update(self, *others: DeclerationsLike) -> None:
|
|
@@ -154,9 +168,13 @@ class Declarations:
|
|
|
154
168
|
other._functions |= self._functions
|
|
155
169
|
other._classes |= self._classes
|
|
156
170
|
other._constants |= self._constants
|
|
171
|
+
# Must combine rulesets bc the empty ruleset might be different, bc DefaultRewriteDecl
|
|
172
|
+
# is added to functions.
|
|
173
|
+
combined_default_rules: set[RewriteOrRuleDecl] = {*self.default_ruleset.rules, *other.default_ruleset.rules}
|
|
157
174
|
other._rulesets |= self._rulesets
|
|
175
|
+
other._rulesets[""] = RulesetDecl(list(combined_default_rules))
|
|
158
176
|
|
|
159
|
-
def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
|
|
177
|
+
def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
|
|
160
178
|
match ref:
|
|
161
179
|
case FunctionRef(name):
|
|
162
180
|
return self._functions[name]
|
|
@@ -170,8 +188,29 @@ class Declarations:
|
|
|
170
188
|
return self._classes[class_name].class_methods[name]
|
|
171
189
|
case PropertyRef(class_name, property_name):
|
|
172
190
|
return self._classes[class_name].properties[property_name]
|
|
191
|
+
case InitRef(class_name):
|
|
192
|
+
init_fn = self._classes[class_name].init
|
|
193
|
+
assert init_fn
|
|
194
|
+
return init_fn
|
|
173
195
|
assert_never(ref)
|
|
174
196
|
|
|
197
|
+
def set_function_decl(
|
|
198
|
+
self, ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef, decl: FunctionDecl
|
|
199
|
+
) -> None:
|
|
200
|
+
match ref:
|
|
201
|
+
case FunctionRef(name):
|
|
202
|
+
self._functions[name] = decl
|
|
203
|
+
case MethodRef(class_name, method_name):
|
|
204
|
+
self._classes[class_name].methods[method_name] = decl
|
|
205
|
+
case ClassMethodRef(class_name, name):
|
|
206
|
+
self._classes[class_name].class_methods[name] = decl
|
|
207
|
+
case PropertyRef(class_name, property_name):
|
|
208
|
+
self._classes[class_name].properties[property_name] = decl
|
|
209
|
+
case InitRef(class_name):
|
|
210
|
+
self._classes[class_name].init = decl
|
|
211
|
+
case _:
|
|
212
|
+
assert_never(ref)
|
|
213
|
+
|
|
175
214
|
def has_method(self, class_name: str, method_name: str) -> bool | None:
|
|
176
215
|
"""
|
|
177
216
|
Returns whether the given class has the given method, or None if we cant find the class.
|
|
@@ -183,12 +222,20 @@ class Declarations:
|
|
|
183
222
|
def get_class_decl(self, name: str) -> ClassDecl:
|
|
184
223
|
return self._classes[name]
|
|
185
224
|
|
|
225
|
+
def get_paramaterized_class(self, name: str) -> TypeRefWithVars:
|
|
226
|
+
"""
|
|
227
|
+
Returns a class reference with type parameters, if the class is paramaterized.
|
|
228
|
+
"""
|
|
229
|
+
type_vars = self._classes[name].type_vars
|
|
230
|
+
return TypeRefWithVars(name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
231
|
+
|
|
186
232
|
|
|
187
233
|
@dataclass
|
|
188
234
|
class ClassDecl:
|
|
189
235
|
egg_name: str | None = None
|
|
190
236
|
type_vars: tuple[str, ...] = ()
|
|
191
237
|
builtin: bool = False
|
|
238
|
+
init: FunctionDecl | None = None
|
|
192
239
|
class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
193
240
|
# These have to be seperate from class_methods so that printing them can be done easily
|
|
194
241
|
class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
@@ -293,6 +340,11 @@ class ClassMethodRef:
|
|
|
293
340
|
method_name: str
|
|
294
341
|
|
|
295
342
|
|
|
343
|
+
@dataclass(frozen=True)
|
|
344
|
+
class InitRef:
|
|
345
|
+
class_name: str
|
|
346
|
+
|
|
347
|
+
|
|
296
348
|
@dataclass(frozen=True)
|
|
297
349
|
class ClassVariableRef:
|
|
298
350
|
class_name: str
|
|
@@ -305,7 +357,9 @@ class PropertyRef:
|
|
|
305
357
|
property_name: str
|
|
306
358
|
|
|
307
359
|
|
|
308
|
-
CallableRef: TypeAlias =
|
|
360
|
+
CallableRef: TypeAlias = (
|
|
361
|
+
FunctionRef | ConstantRef | MethodRef | ClassMethodRef | InitRef | ClassVariableRef | PropertyRef
|
|
362
|
+
)
|
|
309
363
|
|
|
310
364
|
|
|
311
365
|
##
|
|
@@ -378,7 +432,6 @@ class FunctionSignature:
|
|
|
378
432
|
@dataclass(frozen=True)
|
|
379
433
|
class FunctionDecl:
|
|
380
434
|
signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
|
|
381
|
-
|
|
382
435
|
# Egg params
|
|
383
436
|
builtin: bool = False
|
|
384
437
|
egg_name: str | None = None
|
|
@@ -458,7 +511,7 @@ class CallDecl:
|
|
|
458
511
|
bound_tp_params: tuple[JustTypeRef, ...] | None = None
|
|
459
512
|
|
|
460
513
|
def __post_init__(self) -> None:
|
|
461
|
-
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
|
|
514
|
+
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
|
|
462
515
|
msg = "Cannot bind type parameters to a non-class method callable."
|
|
463
516
|
raise ValueError(msg)
|
|
464
517
|
|
|
@@ -629,7 +682,13 @@ class RuleDecl:
|
|
|
629
682
|
name: str | None
|
|
630
683
|
|
|
631
684
|
|
|
632
|
-
|
|
685
|
+
@dataclass(frozen=True)
|
|
686
|
+
class DefaultRewriteDecl:
|
|
687
|
+
ref: CallableRef
|
|
688
|
+
expr: ExprDecl
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl
|
|
633
692
|
|
|
634
693
|
|
|
635
694
|
@dataclass(frozen=True)
|
egglog/egraph.py
CHANGED
|
@@ -352,7 +352,7 @@ class _BaseModule:
|
|
|
352
352
|
This is the same as defining a nullary function with a high cost.
|
|
353
353
|
# TODO: Rename as declare to match eggglog?
|
|
354
354
|
"""
|
|
355
|
-
return constant(name, tp, egg_name)
|
|
355
|
+
return constant(name, tp, egg_name=egg_name)
|
|
356
356
|
|
|
357
357
|
def register(
|
|
358
358
|
self,
|
|
@@ -452,6 +452,7 @@ class _ExprMetaclass(type):
|
|
|
452
452
|
namespace: dict[str, Any],
|
|
453
453
|
egg_sort: str | None = None,
|
|
454
454
|
builtin: bool = False,
|
|
455
|
+
ruleset: Ruleset | None = None,
|
|
455
456
|
) -> RuntimeClass | type:
|
|
456
457
|
# If this is the Expr subclass, just return the class
|
|
457
458
|
if not bases:
|
|
@@ -467,7 +468,14 @@ class _ExprMetaclass(type):
|
|
|
467
468
|
# Otherwise, f_locals returns a copy
|
|
468
469
|
# https://peps.python.org/pep-0667/
|
|
469
470
|
decls_thunk = Thunk.fn(
|
|
470
|
-
_generate_class_decls,
|
|
471
|
+
_generate_class_decls,
|
|
472
|
+
namespace,
|
|
473
|
+
prev_frame,
|
|
474
|
+
builtin,
|
|
475
|
+
egg_sort,
|
|
476
|
+
name,
|
|
477
|
+
ruleset,
|
|
478
|
+
fallback=Declarations,
|
|
471
479
|
)
|
|
472
480
|
return RuntimeClass(decls_thunk, TypeRefWithVars(name))
|
|
473
481
|
|
|
@@ -475,8 +483,13 @@ class _ExprMetaclass(type):
|
|
|
475
483
|
return isinstance(instance, RuntimeExpr)
|
|
476
484
|
|
|
477
485
|
|
|
478
|
-
def _generate_class_decls(
|
|
479
|
-
namespace: dict[str, Any],
|
|
486
|
+
def _generate_class_decls(
|
|
487
|
+
namespace: dict[str, Any],
|
|
488
|
+
frame: FrameType,
|
|
489
|
+
builtin: bool,
|
|
490
|
+
egg_sort: str | None,
|
|
491
|
+
cls_name: str,
|
|
492
|
+
ruleset: Ruleset | None,
|
|
480
493
|
) -> Declarations:
|
|
481
494
|
"""
|
|
482
495
|
Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
|
|
@@ -498,9 +511,9 @@ def _generate_class_decls( # noqa: C901
|
|
|
498
511
|
for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
|
|
499
512
|
if getattr(v, "__origin__", None) == ClassVar:
|
|
500
513
|
(inner_tp,) = v.__args__
|
|
501
|
-
type_ref = resolve_type_annotation(decls, inner_tp)
|
|
502
|
-
cls_decl.class_variables[k] = ConstantDecl(type_ref)
|
|
503
|
-
|
|
514
|
+
type_ref = resolve_type_annotation(decls, inner_tp)
|
|
515
|
+
cls_decl.class_variables[k] = ConstantDecl(type_ref.to_just())
|
|
516
|
+
_add_default_rewrite(decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset)
|
|
504
517
|
else:
|
|
505
518
|
msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
|
|
506
519
|
raise NotImplementedError(msg)
|
|
@@ -510,13 +523,15 @@ def _generate_class_decls( # noqa: C901
|
|
|
510
523
|
##
|
|
511
524
|
|
|
512
525
|
# The type ref of self is paramterized by the type vars
|
|
513
|
-
|
|
526
|
+
TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
514
527
|
|
|
515
528
|
# Get all the methods from the class
|
|
516
529
|
filtered_namespace: list[tuple[str, Any]] = [
|
|
517
530
|
(k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
|
|
518
531
|
]
|
|
519
532
|
|
|
533
|
+
# all methods we should try adding default functions for
|
|
534
|
+
default_function_refs: dict[ClassMethodRef | MethodRef | PropertyRef, Callable] = {}
|
|
520
535
|
# Then register each of its methods
|
|
521
536
|
for method_name, method in filtered_namespace:
|
|
522
537
|
is_init = method_name == "__init__"
|
|
@@ -536,43 +551,26 @@ def _generate_class_decls( # noqa: C901
|
|
|
536
551
|
continue
|
|
537
552
|
locals = frame.f_locals
|
|
538
553
|
|
|
539
|
-
|
|
540
|
-
special_function_name: SpecialFunctions | None = (
|
|
541
|
-
"fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None # noqa: B023
|
|
542
|
-
)
|
|
543
|
-
if special_function_name:
|
|
544
|
-
return FunctionDecl(
|
|
545
|
-
special_function_name,
|
|
546
|
-
builtin=True,
|
|
547
|
-
egg_name=egg_fn, # noqa: B023
|
|
548
|
-
)
|
|
549
|
-
|
|
550
|
-
return _fn_decl(
|
|
551
|
-
decls,
|
|
552
|
-
egg_fn, # noqa: B023
|
|
553
|
-
fn,
|
|
554
|
-
locals, # noqa: B023
|
|
555
|
-
default, # noqa: B023
|
|
556
|
-
cost, # noqa: B023
|
|
557
|
-
merge, # noqa: B023
|
|
558
|
-
on_merge, # noqa: B023
|
|
559
|
-
mutates, # noqa: B023
|
|
560
|
-
builtin,
|
|
561
|
-
first,
|
|
562
|
-
is_init, # noqa: B023
|
|
563
|
-
unextractable, # noqa: B023
|
|
564
|
-
)
|
|
565
|
-
|
|
554
|
+
ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
|
|
566
555
|
match fn:
|
|
567
556
|
case classmethod():
|
|
568
|
-
|
|
557
|
+
ref = ClassMethodRef(cls_name, method_name)
|
|
558
|
+
fn = fn.__func__
|
|
569
559
|
case property():
|
|
570
|
-
|
|
560
|
+
ref = PropertyRef(cls_name, method_name)
|
|
561
|
+
fn = fn.fget
|
|
571
562
|
case _:
|
|
572
|
-
if is_init
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
563
|
+
ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
|
|
564
|
+
|
|
565
|
+
_fn_decl(decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, unextractable)
|
|
566
|
+
|
|
567
|
+
if not builtin and not isinstance(ref, InitRef) and not mutates:
|
|
568
|
+
default_function_refs[ref] = fn
|
|
569
|
+
|
|
570
|
+
# Add all rewrite methods at the end so that all methods are registered first and can be accessed
|
|
571
|
+
# in the bodies
|
|
572
|
+
for ref, fn in default_function_refs.items():
|
|
573
|
+
_add_default_rewrite_function(decls, ref, fn, ruleset)
|
|
576
574
|
|
|
577
575
|
return decls
|
|
578
576
|
|
|
@@ -591,6 +589,7 @@ def function(
|
|
|
591
589
|
mutates_first_arg: bool = False,
|
|
592
590
|
unextractable: bool = False,
|
|
593
591
|
builtin: bool = False,
|
|
592
|
+
ruleset: Ruleset | None = None,
|
|
594
593
|
) -> Callable[[CALLABLE], CALLABLE]: ...
|
|
595
594
|
|
|
596
595
|
|
|
@@ -604,6 +603,7 @@ def function(
|
|
|
604
603
|
on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
|
|
605
604
|
mutates_first_arg: bool = False,
|
|
606
605
|
unextractable: bool = False,
|
|
606
|
+
ruleset: Ruleset | None = None,
|
|
607
607
|
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
|
|
608
608
|
|
|
609
609
|
|
|
@@ -634,15 +634,17 @@ class _FunctionConstructor:
|
|
|
634
634
|
merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
|
|
635
635
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
636
636
|
unextractable: bool = False
|
|
637
|
+
ruleset: Ruleset | None = None
|
|
637
638
|
|
|
638
639
|
def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
|
|
639
640
|
return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
|
|
640
641
|
|
|
641
642
|
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
|
|
642
643
|
decls = Declarations()
|
|
643
|
-
|
|
644
|
+
_fn_decl(
|
|
644
645
|
decls,
|
|
645
646
|
self.egg_fn,
|
|
647
|
+
(ref := FunctionRef(fn.__name__)),
|
|
646
648
|
fn,
|
|
647
649
|
self.hint_locals,
|
|
648
650
|
self.default,
|
|
@@ -653,12 +655,14 @@ class _FunctionConstructor:
|
|
|
653
655
|
self.builtin,
|
|
654
656
|
unextractable=self.unextractable,
|
|
655
657
|
)
|
|
658
|
+
_add_default_rewrite_function(decls, ref, fn, self.ruleset)
|
|
656
659
|
return decls
|
|
657
660
|
|
|
658
661
|
|
|
659
662
|
def _fn_decl(
|
|
660
663
|
decls: Declarations,
|
|
661
664
|
egg_name: str | None,
|
|
665
|
+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
|
|
662
666
|
fn: object,
|
|
663
667
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
664
668
|
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
|
|
@@ -669,14 +673,23 @@ def _fn_decl(
|
|
|
669
673
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
|
|
670
674
|
mutates_first_arg: bool,
|
|
671
675
|
is_builtin: bool,
|
|
672
|
-
# The first arg is either cls, for a classmethod, a self type, or none for a function
|
|
673
|
-
first_arg: Literal["cls"] | TypeOrVarRef | None = None,
|
|
674
|
-
is_init: bool = False,
|
|
675
676
|
unextractable: bool = False,
|
|
676
|
-
) ->
|
|
677
|
+
) -> None:
|
|
678
|
+
"""
|
|
679
|
+
Sets the function decl for the function object.
|
|
680
|
+
"""
|
|
677
681
|
if not isinstance(fn, FunctionType):
|
|
678
682
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
679
683
|
|
|
684
|
+
# partial function creation and calling are handled with a special case in the type checker, so don't
|
|
685
|
+
# use the normal logic
|
|
686
|
+
special_function_name: SpecialFunctions | None = (
|
|
687
|
+
"fn-partial" if egg_name == "unstable-fn" else "fn-app" if egg_name == "unstable-app" else None
|
|
688
|
+
)
|
|
689
|
+
if special_function_name:
|
|
690
|
+
decls.set_function_decl(ref, FunctionDecl(special_function_name, builtin=True, egg_name=egg_name))
|
|
691
|
+
return
|
|
692
|
+
|
|
680
693
|
hint_globals = fn.__globals__.copy()
|
|
681
694
|
# Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
|
|
682
695
|
# https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
|
|
@@ -687,8 +700,8 @@ def _fn_decl(
|
|
|
687
700
|
|
|
688
701
|
params = list(signature(fn).parameters.values())
|
|
689
702
|
|
|
690
|
-
# If this is an init function, or a classmethod,
|
|
691
|
-
if
|
|
703
|
+
# If this is an init function, or a classmethod, the first arg is not used
|
|
704
|
+
if isinstance(ref, ClassMethodRef | InitRef):
|
|
692
705
|
params = params[1:]
|
|
693
706
|
|
|
694
707
|
if _last_param_variable(params):
|
|
@@ -698,9 +711,8 @@ def _fn_decl(
|
|
|
698
711
|
else:
|
|
699
712
|
var_arg_type = None
|
|
700
713
|
arg_types = tuple(
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
if i == 0 and isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars) and not is_init
|
|
714
|
+
decls.get_paramaterized_class(ref.class_name)
|
|
715
|
+
if i == 0 and isinstance(ref, MethodRef | PropertyRef)
|
|
704
716
|
else resolve_type_annotation(decls, hints[t.name])
|
|
705
717
|
for i, t in enumerate(params)
|
|
706
718
|
)
|
|
@@ -713,17 +725,15 @@ def _fn_decl(
|
|
|
713
725
|
|
|
714
726
|
decls.update(*arg_defaults)
|
|
715
727
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
else:
|
|
726
|
-
return_type = resolve_type_annotation(decls, hints["return"])
|
|
728
|
+
return_type = (
|
|
729
|
+
decls.get_paramaterized_class(ref.class_name)
|
|
730
|
+
if isinstance(ref, InitRef)
|
|
731
|
+
else arg_types[0]
|
|
732
|
+
if mutates_first_arg
|
|
733
|
+
else resolve_type_annotation(decls, hints["return"])
|
|
734
|
+
)
|
|
735
|
+
|
|
736
|
+
arg_names = tuple(t.name for t in params)
|
|
727
737
|
|
|
728
738
|
decls |= default
|
|
729
739
|
merged = (
|
|
@@ -747,12 +757,12 @@ def _fn_decl(
|
|
|
747
757
|
)
|
|
748
758
|
)
|
|
749
759
|
decls.update(*merge_action)
|
|
750
|
-
|
|
760
|
+
decl = FunctionDecl(
|
|
751
761
|
FunctionSignature(
|
|
752
762
|
return_type=None if mutates_first_arg else return_type,
|
|
753
763
|
var_arg_type=var_arg_type,
|
|
754
764
|
arg_types=arg_types,
|
|
755
|
-
arg_names=
|
|
765
|
+
arg_names=arg_names,
|
|
756
766
|
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
757
767
|
),
|
|
758
768
|
cost=cost,
|
|
@@ -763,6 +773,7 @@ def _fn_decl(
|
|
|
763
773
|
default=None if default is None else default.__egg_typed_expr__.expr,
|
|
764
774
|
on_merge=tuple(a.action for a in merge_action),
|
|
765
775
|
)
|
|
776
|
+
decls.set_function_decl(ref, decl)
|
|
766
777
|
|
|
767
778
|
|
|
768
779
|
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
|
|
@@ -804,19 +815,73 @@ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Dec
|
|
|
804
815
|
return decls
|
|
805
816
|
|
|
806
817
|
|
|
807
|
-
def constant(
|
|
818
|
+
def constant(
|
|
819
|
+
name: str,
|
|
820
|
+
tp: type[EXPR],
|
|
821
|
+
default_replacement: EXPR | None = None,
|
|
822
|
+
/,
|
|
823
|
+
*,
|
|
824
|
+
egg_name: str | None = None,
|
|
825
|
+
ruleset: Ruleset | None = None,
|
|
826
|
+
) -> EXPR:
|
|
808
827
|
"""
|
|
809
828
|
A "constant" is implemented as the instantiation of a value that takes no args.
|
|
810
829
|
This creates a function with `name` and return type `tp` and returns a value of it being called.
|
|
811
830
|
"""
|
|
812
|
-
return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
|
|
831
|
+
return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
|
|
813
832
|
|
|
814
833
|
|
|
815
|
-
def _constant_thunk(
|
|
834
|
+
def _constant_thunk(
|
|
835
|
+
name: str, tp: type, egg_name: str | None, default_replacement: object, ruleset: Ruleset | None
|
|
836
|
+
) -> tuple[Declarations, TypedExprDecl]:
|
|
816
837
|
decls = Declarations()
|
|
817
|
-
type_ref = resolve_type_annotation(decls, tp)
|
|
818
|
-
|
|
819
|
-
|
|
838
|
+
type_ref = resolve_type_annotation(decls, tp)
|
|
839
|
+
callable_ref = ConstantRef(name)
|
|
840
|
+
decls._constants[name] = ConstantDecl(type_ref.to_just(), egg_name)
|
|
841
|
+
_add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset)
|
|
842
|
+
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
def _add_default_rewrite_function(decls: Declarations, ref: CallableRef, fn: Callable, ruleset: Ruleset | None) -> None:
|
|
846
|
+
"""
|
|
847
|
+
Adds a default rewrite for a function, by calling the functions with vars and adding it if it is not None.
|
|
848
|
+
"""
|
|
849
|
+
callable_decl = decls.get_callable_decl(ref)
|
|
850
|
+
assert isinstance(callable_decl, FunctionDecl)
|
|
851
|
+
signature = callable_decl.signature
|
|
852
|
+
assert isinstance(signature, FunctionSignature)
|
|
853
|
+
|
|
854
|
+
var_args: list[object] = [
|
|
855
|
+
RuntimeExpr.__from_value__(decls, TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name))))
|
|
856
|
+
for name, tp in zip(signature.arg_names, signature.arg_types, strict=False)
|
|
857
|
+
]
|
|
858
|
+
# If this is a classmethod, add the class as the first arg
|
|
859
|
+
if isinstance(ref, ClassMethodRef):
|
|
860
|
+
tp = decls.get_paramaterized_class(ref.class_name)
|
|
861
|
+
var_args.insert(0, RuntimeClass(Thunk.value(decls), tp))
|
|
862
|
+
_add_default_rewrite(decls, ref, signature.semantic_return_type, fn(*var_args), ruleset)
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
def _add_default_rewrite(
|
|
866
|
+
decls: Declarations, ref: CallableRef, type_ref: TypeOrVarRef, default_rewrite: object, ruleset: Ruleset | None
|
|
867
|
+
) -> None:
|
|
868
|
+
"""
|
|
869
|
+
Adds a default rewrite for the callable, if the default rewrite is not None
|
|
870
|
+
|
|
871
|
+
Will add it to the ruleset if it is passed in, or add it to the default ruleset on the passed in decls if not.
|
|
872
|
+
"""
|
|
873
|
+
if default_rewrite is None:
|
|
874
|
+
return
|
|
875
|
+
resolved_value = resolve_literal(type_ref, default_rewrite)
|
|
876
|
+
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
|
|
877
|
+
if ruleset:
|
|
878
|
+
ruleset_decls = ruleset._current_egg_decls
|
|
879
|
+
ruleset_decl = ruleset.__egg_ruleset__
|
|
880
|
+
else:
|
|
881
|
+
ruleset_decls = decls
|
|
882
|
+
ruleset_decl = decls.default_ruleset
|
|
883
|
+
ruleset_decl.rules.append(rewrite_decl)
|
|
884
|
+
ruleset_decls |= resolved_value
|
|
820
885
|
|
|
821
886
|
|
|
822
887
|
def _last_param_variable(params: list[Parameter]) -> bool:
|
|
@@ -991,6 +1056,7 @@ class EGraph(_BaseModule):
|
|
|
991
1056
|
action = let(name, expr)
|
|
992
1057
|
self.register(action)
|
|
993
1058
|
runtime_expr = to_runtime_expr(expr)
|
|
1059
|
+
self._add_decls(runtime_expr)
|
|
994
1060
|
return cast(
|
|
995
1061
|
EXPR,
|
|
996
1062
|
RuntimeExpr.__from_value__(
|
|
@@ -1016,7 +1082,8 @@ class EGraph(_BaseModule):
|
|
|
1016
1082
|
self._add_decls(runtime_expr, schedule)
|
|
1017
1083
|
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
1018
1084
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1019
|
-
|
|
1085
|
+
# Must also register type
|
|
1086
|
+
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1020
1087
|
self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
|
|
1021
1088
|
extract_report = self._egraph.extract_report()
|
|
1022
1089
|
if not isinstance(extract_report, bindings.Best):
|
|
@@ -1121,8 +1188,7 @@ class EGraph(_BaseModule):
|
|
|
1121
1188
|
return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
1122
1189
|
|
|
1123
1190
|
def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
|
|
1124
|
-
self._state.
|
|
1125
|
-
expr = self._state.expr_to_egg(typed_expr.expr)
|
|
1191
|
+
expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1126
1192
|
self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
|
|
1127
1193
|
extract_report = self._egraph.extract_report()
|
|
1128
1194
|
if not extract_report:
|
|
@@ -1181,7 +1247,7 @@ class EGraph(_BaseModule):
|
|
|
1181
1247
|
runtime_expr = to_runtime_expr(expr)
|
|
1182
1248
|
self._add_decls(runtime_expr)
|
|
1183
1249
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1184
|
-
egg_expr = self._state.
|
|
1250
|
+
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1185
1251
|
match typed_expr.tp:
|
|
1186
1252
|
case JustTypeRef("i64"):
|
|
1187
1253
|
return self._egraph.eval_i64(egg_expr)
|
|
@@ -1878,7 +1944,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
|
|
|
1878
1944
|
if "Callable" not in globals:
|
|
1879
1945
|
globals["Callable"] = Callable
|
|
1880
1946
|
hints = get_type_hints(gen, globals, frame.f_locals)
|
|
1881
|
-
args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1947
|
+
args = [_var(_rule_var_name(p.name), hints[p.name]) for p in signature(gen).parameters.values()]
|
|
1882
1948
|
return list(gen(*args)) # type: ignore[misc]
|
|
1883
1949
|
|
|
1884
1950
|
|
egglog/egraph_state.py
CHANGED
|
@@ -19,7 +19,7 @@ from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from collections.abc import Iterable
|
|
21
21
|
|
|
22
|
-
__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"]
|
|
22
|
+
__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT", "_rule_var_name"]
|
|
23
23
|
|
|
24
24
|
# Create a global sort for python objects, so we can store them without an e-graph instance
|
|
25
25
|
# Needed when serializing commands to egg commands when creating modules
|
|
@@ -98,7 +98,8 @@ class EGraphState:
|
|
|
98
98
|
for rule in rules:
|
|
99
99
|
if rule in added_rules:
|
|
100
100
|
continue
|
|
101
|
-
self.
|
|
101
|
+
cmd = self.command_to_egg(rule, name)
|
|
102
|
+
self.egraph.run_program(cmd)
|
|
102
103
|
added_rules.add(rule)
|
|
103
104
|
case CombinedRulesetDecl(rulesets):
|
|
104
105
|
if name in self.rulesets:
|
|
@@ -115,8 +116,8 @@ class EGraphState:
|
|
|
115
116
|
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
|
|
116
117
|
self.type_ref_to_egg(tp)
|
|
117
118
|
rewrite = bindings.Rewrite(
|
|
118
|
-
self.
|
|
119
|
-
self.
|
|
119
|
+
self._expr_to_egg(lhs),
|
|
120
|
+
self._expr_to_egg(rhs),
|
|
120
121
|
[self.fact_to_egg(c) for c in conditions],
|
|
121
122
|
)
|
|
122
123
|
return (
|
|
@@ -130,6 +131,16 @@ class EGraphState:
|
|
|
130
131
|
[self.fact_to_egg(f) for f in body],
|
|
131
132
|
)
|
|
132
133
|
return bindings.RuleCommand(name or "", ruleset, rule)
|
|
134
|
+
case DefaultRewriteDecl(ref, expr):
|
|
135
|
+
decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
|
|
136
|
+
sig = decl.signature
|
|
137
|
+
assert isinstance(sig, FunctionSignature)
|
|
138
|
+
args = tuple(
|
|
139
|
+
TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name)))
|
|
140
|
+
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
|
|
141
|
+
)
|
|
142
|
+
rewrite_decl = RewriteDecl(sig.semantic_return_type.to_just(), CallDecl(ref, args), expr, (), False)
|
|
143
|
+
return self.command_to_egg(rewrite_decl, ruleset)
|
|
133
144
|
case _:
|
|
134
145
|
assert_never(cmd)
|
|
135
146
|
|
|
@@ -139,13 +150,13 @@ class EGraphState:
|
|
|
139
150
|
return bindings.Let(name, self.typed_expr_to_egg(typed_expr))
|
|
140
151
|
case SetDecl(tp, call, rhs):
|
|
141
152
|
self.type_ref_to_egg(tp)
|
|
142
|
-
call_ = self.
|
|
143
|
-
return bindings.Set(call_.name, call_.args, self.
|
|
153
|
+
call_ = self._expr_to_egg(call)
|
|
154
|
+
return bindings.Set(call_.name, call_.args, self._expr_to_egg(rhs))
|
|
144
155
|
case ExprActionDecl(typed_expr):
|
|
145
156
|
return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
|
|
146
157
|
case ChangeDecl(tp, call, change):
|
|
147
158
|
self.type_ref_to_egg(tp)
|
|
148
|
-
call_ = self.
|
|
159
|
+
call_ = self._expr_to_egg(call)
|
|
149
160
|
egg_change: bindings._Change
|
|
150
161
|
match change:
|
|
151
162
|
case "delete":
|
|
@@ -157,7 +168,7 @@ class EGraphState:
|
|
|
157
168
|
return bindings.Change(egg_change, call_.name, call_.args)
|
|
158
169
|
case UnionDecl(tp, lhs, rhs):
|
|
159
170
|
self.type_ref_to_egg(tp)
|
|
160
|
-
return bindings.Union(self.
|
|
171
|
+
return bindings.Union(self._expr_to_egg(lhs), self._expr_to_egg(rhs))
|
|
161
172
|
case PanicDecl(name):
|
|
162
173
|
return bindings.Panic(name)
|
|
163
174
|
case _:
|
|
@@ -167,7 +178,7 @@ class EGraphState:
|
|
|
167
178
|
match fact:
|
|
168
179
|
case EqDecl(tp, exprs):
|
|
169
180
|
self.type_ref_to_egg(tp)
|
|
170
|
-
return bindings.Eq([self.
|
|
181
|
+
return bindings.Eq([self._expr_to_egg(e) for e in exprs])
|
|
171
182
|
case ExprFactDecl(typed_expr):
|
|
172
183
|
return bindings.Fact(self.typed_expr_to_egg(typed_expr))
|
|
173
184
|
case _:
|
|
@@ -201,8 +212,8 @@ class EGraphState:
|
|
|
201
212
|
[self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
|
|
202
213
|
self.type_ref_to_egg(signature.semantic_return_type.to_just()),
|
|
203
214
|
),
|
|
204
|
-
self.
|
|
205
|
-
self.
|
|
215
|
+
self._expr_to_egg(decl.default) if decl.default else None,
|
|
216
|
+
self._expr_to_egg(decl.merge) if decl.merge else None,
|
|
206
217
|
[self.action_to_egg(a) for a in decl.on_merge],
|
|
207
218
|
decl.cost,
|
|
208
219
|
decl.unextractable,
|
|
@@ -245,6 +256,8 @@ class EGraphState:
|
|
|
245
256
|
if decl.builtin:
|
|
246
257
|
for method in decl.class_methods:
|
|
247
258
|
self.callable_ref_to_egg(ClassMethodRef(ref.name, method))
|
|
259
|
+
if decl.init:
|
|
260
|
+
self.callable_ref_to_egg(InitRef(ref.name))
|
|
248
261
|
|
|
249
262
|
return egg_name
|
|
250
263
|
|
|
@@ -261,15 +274,15 @@ class EGraphState:
|
|
|
261
274
|
|
|
262
275
|
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
|
|
263
276
|
self.type_ref_to_egg(typed_expr_decl.tp)
|
|
264
|
-
return self.
|
|
277
|
+
return self._expr_to_egg(typed_expr_decl.expr)
|
|
265
278
|
|
|
266
279
|
@overload
|
|
267
|
-
def
|
|
280
|
+
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
|
|
268
281
|
|
|
269
282
|
@overload
|
|
270
|
-
def
|
|
283
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
271
284
|
|
|
272
|
-
def
|
|
285
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
|
|
273
286
|
"""
|
|
274
287
|
Convert an ExprDecl to an egg expression.
|
|
275
288
|
|
|
@@ -307,7 +320,7 @@ class EGraphState:
|
|
|
307
320
|
case PyObjectDecl(value):
|
|
308
321
|
res = GLOBAL_PY_OBJECT_SORT.store(value)
|
|
309
322
|
case PartialCallDecl(call_decl):
|
|
310
|
-
egg_fn_call = self.
|
|
323
|
+
egg_fn_call = self._expr_to_egg(call_decl)
|
|
311
324
|
res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
|
|
312
325
|
case _:
|
|
313
326
|
assert_never(expr_decl.expr)
|
|
@@ -355,6 +368,8 @@ def _generate_callable_egg_name(ref: CallableRef) -> str:
|
|
|
355
368
|
| PropertyRef(cls_name, name)
|
|
356
369
|
):
|
|
357
370
|
return f"{cls_name}_{name}"
|
|
371
|
+
case InitRef(cls_name):
|
|
372
|
+
return f"{cls_name}___init__"
|
|
358
373
|
case _:
|
|
359
374
|
assert_never(ref)
|
|
360
375
|
|
|
@@ -427,8 +442,11 @@ class FromEggState:
|
|
|
427
442
|
possible_types: Iterable[JustTypeRef | None]
|
|
428
443
|
signature = self.decls.get_callable_decl(callable_ref).to_function_decl().signature
|
|
429
444
|
assert isinstance(signature, FunctionSignature)
|
|
430
|
-
if isinstance(callable_ref, ClassMethodRef):
|
|
431
|
-
|
|
445
|
+
if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef):
|
|
446
|
+
# Need OR in case we have class method whose class whas never added as a sort, which would happen
|
|
447
|
+
# if the class method didn't return that type and no other function did. In this case, we don't need
|
|
448
|
+
# to care about the type vars and we we don't need to bind any possible type.
|
|
449
|
+
possible_types = self.state._get_possible_types(callable_ref.class_name) or [None]
|
|
432
450
|
cls_name = callable_ref.class_name
|
|
433
451
|
else:
|
|
434
452
|
possible_types = [None]
|
|
@@ -437,7 +455,6 @@ class FromEggState:
|
|
|
437
455
|
tcs = TypeConstraintSolver(self.decls)
|
|
438
456
|
if possible_type and possible_type.args:
|
|
439
457
|
tcs.bind_class(possible_type)
|
|
440
|
-
|
|
441
458
|
try:
|
|
442
459
|
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
443
460
|
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
|
|
@@ -445,7 +462,14 @@ class FromEggState:
|
|
|
445
462
|
except TypeConstraintError:
|
|
446
463
|
continue
|
|
447
464
|
args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
|
|
448
|
-
|
|
465
|
+
|
|
466
|
+
return CallDecl(
|
|
467
|
+
callable_ref,
|
|
468
|
+
args,
|
|
469
|
+
# Don't include bound type params if this is just a method, we only needed them for type resolution
|
|
470
|
+
# but dont need to store them
|
|
471
|
+
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
|
|
472
|
+
)
|
|
449
473
|
raise ValueError(f"Could not find callable ref for call {term}")
|
|
450
474
|
|
|
451
475
|
def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
|
|
@@ -454,3 +478,10 @@ class FromEggState:
|
|
|
454
478
|
except KeyError:
|
|
455
479
|
res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id])
|
|
456
480
|
return res
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def _rule_var_name(s: str) -> str:
|
|
484
|
+
"""
|
|
485
|
+
Create a hidden variable name, for rewrites, so that let bindings or function won't conflict with it
|
|
486
|
+
"""
|
|
487
|
+
return f"__var__{s}"
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
"""
|
|
3
|
+
Higher Order Functions
|
|
4
|
+
======================
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
from egglog import *
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Math(Expr):
|
|
18
|
+
def __init__(self, i: i64Like) -> None: ...
|
|
19
|
+
|
|
20
|
+
def __add__(self, other: Math) -> Math: ...
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MathList(Expr):
|
|
24
|
+
def __init__(self) -> None: ...
|
|
25
|
+
|
|
26
|
+
def append(self, i: Math) -> MathList: ...
|
|
27
|
+
|
|
28
|
+
def map(self, f: Callable[[Math], Math]) -> MathList: ...
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@ruleset
|
|
32
|
+
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]): # noqa: ANN201
|
|
33
|
+
yield rewrite(Math(i) + Math(j)).to(Math(i + j))
|
|
34
|
+
yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
|
|
35
|
+
yield rewrite(MathList().map(f)).to(MathList())
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@function(ruleset=math_ruleset)
|
|
39
|
+
def increment_by_one(x: Math) -> Math:
|
|
40
|
+
return x + Math(1)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
egraph = EGraph()
|
|
44
|
+
x = egraph.let("x", MathList().append(Math(1)).append(Math(2)))
|
|
45
|
+
y = egraph.let("y", x.map(increment_by_one))
|
|
46
|
+
egraph.run(math_ruleset.saturate())
|
|
47
|
+
|
|
48
|
+
egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
|
|
49
|
+
|
|
50
|
+
egraph
|
egglog/exp/array_api.py
CHANGED
|
@@ -882,7 +882,10 @@ converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
|
|
|
882
882
|
|
|
883
883
|
@function
|
|
884
884
|
def asarray(
|
|
885
|
-
a: NDArray,
|
|
885
|
+
a: NDArray,
|
|
886
|
+
dtype: OptionalDType = OptionalDType.none,
|
|
887
|
+
copy: OptionalBool = OptionalBool.none,
|
|
888
|
+
device: OptionalDevice = OptionalDevice.none,
|
|
886
889
|
) -> NDArray: ...
|
|
887
890
|
|
|
888
891
|
|
egglog/pretty.py
CHANGED
|
@@ -166,6 +166,8 @@ class TraverseContext:
|
|
|
166
166
|
pass
|
|
167
167
|
case EqDecl(_, decls) | SequenceDecl(decls) | RulesetDecl(decls):
|
|
168
168
|
for de in decls:
|
|
169
|
+
if isinstance(de, DefaultRewriteDecl):
|
|
170
|
+
continue
|
|
169
171
|
self(de)
|
|
170
172
|
case CallDecl(_, exprs, _):
|
|
171
173
|
for e in exprs:
|
|
@@ -178,6 +180,8 @@ class TraverseContext:
|
|
|
178
180
|
self(c)
|
|
179
181
|
case CombinedRulesetDecl(_):
|
|
180
182
|
pass
|
|
183
|
+
case DefaultRewriteDecl():
|
|
184
|
+
pass
|
|
181
185
|
case _:
|
|
182
186
|
assert_never(decl)
|
|
183
187
|
|
|
@@ -276,7 +280,7 @@ class PrettyContext:
|
|
|
276
280
|
case RulesetDecl(rules):
|
|
277
281
|
if ruleset_name:
|
|
278
282
|
return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
|
|
279
|
-
args = ", ".join(
|
|
283
|
+
args = ", ".join(self(r) for r in rules if not isinstance(r, DefaultRewriteDecl))
|
|
280
284
|
return f"ruleset({args})", "ruleset"
|
|
281
285
|
case CombinedRulesetDecl(rulesets):
|
|
282
286
|
if ruleset_name:
|
|
@@ -298,6 +302,9 @@ class PrettyContext:
|
|
|
298
302
|
return ruleset_str, "schedule"
|
|
299
303
|
args = ", ".join(map(self, until))
|
|
300
304
|
return f"run({ruleset_str}, {args})", "schedule"
|
|
305
|
+
case DefaultRewriteDecl():
|
|
306
|
+
msg = "default rewrites should not be pretty printed"
|
|
307
|
+
raise TypeError(msg)
|
|
301
308
|
assert_never(decl)
|
|
302
309
|
|
|
303
310
|
def _call(
|
|
@@ -370,10 +377,8 @@ class PrettyContext:
|
|
|
370
377
|
case FunctionRef(name):
|
|
371
378
|
return name, args
|
|
372
379
|
case ClassMethodRef(class_name, method_name):
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
fn_str += f".{method_name}"
|
|
376
|
-
return fn_str, args
|
|
380
|
+
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
381
|
+
return f"{tp_ref}.{method_name}", args
|
|
377
382
|
case MethodRef(_class_name, method_name):
|
|
378
383
|
slf, *args = args
|
|
379
384
|
slf = self(slf, parens=True)
|
|
@@ -400,6 +405,9 @@ class PrettyContext:
|
|
|
400
405
|
return f"{class_name}.{variable_name}"
|
|
401
406
|
case PropertyRef(_class_name, property_name):
|
|
402
407
|
return f"{self(args[0], parens=True)}.{property_name}"
|
|
408
|
+
case InitRef(class_name):
|
|
409
|
+
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
410
|
+
return str(tp_ref), args
|
|
403
411
|
assert_never(ref)
|
|
404
412
|
|
|
405
413
|
def _generate_name(self, typ: str) -> str:
|
|
@@ -434,6 +442,8 @@ def _pretty_callable(ref: CallableRef) -> str:
|
|
|
434
442
|
| PropertyRef(class_name, method_name)
|
|
435
443
|
):
|
|
436
444
|
return f"{class_name}.{method_name}"
|
|
445
|
+
case InitRef(class_name):
|
|
446
|
+
return class_name
|
|
437
447
|
case ConstantRef(_):
|
|
438
448
|
msg = "Constants should not be callable"
|
|
439
449
|
raise NotImplementedError(msg)
|
egglog/runtime.py
CHANGED
|
@@ -163,10 +163,8 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
163
163
|
return RuntimeExpr.__from_value__(
|
|
164
164
|
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
|
|
165
165
|
)
|
|
166
|
-
|
|
167
|
-
return
|
|
168
|
-
Thunk.value(self.__egg_decls__), ClassMethodRef(name, "__init__"), self.__egg_tp__.to_just()
|
|
169
|
-
)(*args, **kwargs) # type: ignore[arg-type]
|
|
166
|
+
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), InitRef(name), self.__egg_tp__.to_just())
|
|
167
|
+
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
170
168
|
|
|
171
169
|
def __dir__(self) -> list[str]:
|
|
172
170
|
cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
|
|
@@ -277,7 +275,10 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
277
275
|
|
|
278
276
|
# Turn all keyword args into positional args
|
|
279
277
|
py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
|
|
280
|
-
|
|
278
|
+
try:
|
|
279
|
+
bound = py_signature.bind(*args, **kwargs)
|
|
280
|
+
except TypeError as err:
|
|
281
|
+
raise TypeError(f"Failed to call {self} with args {args} and kwargs {kwargs}") from err
|
|
281
282
|
del kwargs
|
|
282
283
|
bound.apply_defaults()
|
|
283
284
|
assert not bound.kwargs
|
|
@@ -310,7 +311,9 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
310
311
|
return_tp = tcs.infer_return_type(
|
|
311
312
|
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, arg_types, cls_name
|
|
312
313
|
)
|
|
313
|
-
bound_params =
|
|
314
|
+
bound_params = (
|
|
315
|
+
cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else None
|
|
316
|
+
)
|
|
314
317
|
# If we were using unstable-app to call a funciton, add that function back as the first arg.
|
|
315
318
|
if function_value:
|
|
316
319
|
arg_exprs = (function_value, *arg_exprs)
|
egglog/thunk.py
CHANGED
|
@@ -21,7 +21,7 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
21
21
|
Cached delayed function call.
|
|
22
22
|
"""
|
|
23
23
|
|
|
24
|
-
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving[T]
|
|
24
|
+
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving[T] | Error
|
|
25
25
|
|
|
26
26
|
@classmethod
|
|
27
27
|
def fn(
|
|
@@ -44,14 +44,21 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
44
44
|
return value
|
|
45
45
|
case Unresolved(fn, args, fallback):
|
|
46
46
|
self.state = Resolving(fallback)
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
47
|
+
try:
|
|
48
|
+
res = fn(*args)
|
|
49
|
+
except Exception as e:
|
|
50
|
+
self.state = Error(e)
|
|
51
|
+
raise
|
|
52
|
+
else:
|
|
53
|
+
self.state = Resolved(res)
|
|
54
|
+
return res
|
|
50
55
|
case Resolving(fallback):
|
|
51
56
|
if fallback is None:
|
|
52
57
|
msg = "Recursively resolving thunk without fallback"
|
|
53
58
|
raise ValueError(msg)
|
|
54
59
|
return fallback()
|
|
60
|
+
case Error(e):
|
|
61
|
+
raise e
|
|
55
62
|
|
|
56
63
|
|
|
57
64
|
@dataclass
|
|
@@ -69,3 +76,8 @@ class Unresolved(Generic[T, Unpack[TS]]):
|
|
|
69
76
|
@dataclass
|
|
70
77
|
class Resolving(Generic[T]):
|
|
71
78
|
fallback: Callable[[], T] | None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
|
|
82
|
+
class Error:
|
|
83
|
+
e: Exception
|
egglog/type_constraint_solver.py
CHANGED
|
@@ -79,9 +79,10 @@ class TypeConstraintSolver:
|
|
|
79
79
|
Also returns the bound type params if the class name is passed in.
|
|
80
80
|
"""
|
|
81
81
|
self._infer_typevars(fn_return, return_, cls_name)
|
|
82
|
-
arg_types = (
|
|
83
|
-
|
|
84
|
-
|
|
82
|
+
arg_types: Iterable[JustTypeRef] = [self._subtitute_typevars(a, cls_name) for a in fn_args]
|
|
83
|
+
if fn_var_args:
|
|
84
|
+
# Need to be generator so it can be infinite for variable args
|
|
85
|
+
arg_types = chain(arg_types, repeat(self._subtitute_typevars(fn_var_args, cls_name)))
|
|
85
86
|
bound_typevars = (
|
|
86
87
|
tuple(
|
|
87
88
|
v
|
|
@@ -132,8 +133,8 @@ class TypeConstraintSolver:
|
|
|
132
133
|
def _subtitute_typevars(self, tp: TypeOrVarRef, cls_name: str | None) -> JustTypeRef:
|
|
133
134
|
match tp:
|
|
134
135
|
case ClassTypeVarRef(name):
|
|
136
|
+
assert cls_name is not None
|
|
135
137
|
try:
|
|
136
|
-
assert cls_name is not None
|
|
137
138
|
return self._cls_typevar_index_to_type[cls_name][name]
|
|
138
139
|
except KeyError as e:
|
|
139
140
|
raise TypeConstraintError(f"Not enough bound typevars for {tp}") from e
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: egglog
|
|
3
|
-
Version: 7.
|
|
3
|
+
Version: 7.2.0
|
|
4
4
|
Classifier: Environment :: MacOS X
|
|
5
5
|
Classifier: Environment :: Win32 (MS Windows)
|
|
6
6
|
Classifier: Intended Audience :: Developers
|
|
@@ -19,22 +19,6 @@ Classifier: Typing :: Typed
|
|
|
19
19
|
Requires-Dist: typing-extensions
|
|
20
20
|
Requires-Dist: black
|
|
21
21
|
Requires-Dist: graphviz
|
|
22
|
-
Requires-Dist: scikit-learn; extra == 'array'
|
|
23
|
-
Requires-Dist: array_api_compat; extra == 'array'
|
|
24
|
-
Requires-Dist: numba==0.59.1; extra == 'array'
|
|
25
|
-
Requires-Dist: llvmlite==0.42.0; extra == 'array'
|
|
26
|
-
Requires-Dist: pre-commit; extra == 'dev'
|
|
27
|
-
Requires-Dist: ruff; extra == 'dev'
|
|
28
|
-
Requires-Dist: mypy; extra == 'dev'
|
|
29
|
-
Requires-Dist: anywidget[dev]; extra == 'dev'
|
|
30
|
-
Requires-Dist: egglog[docs,test]; extra == 'dev'
|
|
31
|
-
Requires-Dist: pytest; extra == 'test'
|
|
32
|
-
Requires-Dist: mypy; extra == 'test'
|
|
33
|
-
Requires-Dist: syrupy; extra == 'test'
|
|
34
|
-
Requires-Dist: egglog[array]; extra == 'test'
|
|
35
|
-
Requires-Dist: pytest-codspeed; extra == 'test'
|
|
36
|
-
Requires-Dist: pytest-benchmark; extra == 'test'
|
|
37
|
-
Requires-Dist: pytest-xdist; extra == 'test'
|
|
38
22
|
Requires-Dist: pydata-sphinx-theme; extra == 'docs'
|
|
39
23
|
Requires-Dist: myst-nb; extra == 'docs'
|
|
40
24
|
Requires-Dist: sphinx-autodoc-typehints; extra == 'docs'
|
|
@@ -47,10 +31,26 @@ Requires-Dist: egglog[array]; extra == 'docs'
|
|
|
47
31
|
Requires-Dist: line-profiler; extra == 'docs'
|
|
48
32
|
Requires-Dist: sphinxcontrib-mermaid; extra == 'docs'
|
|
49
33
|
Requires-Dist: ablog; extra == 'docs'
|
|
34
|
+
Requires-Dist: pytest; extra == 'test'
|
|
35
|
+
Requires-Dist: mypy; extra == 'test'
|
|
36
|
+
Requires-Dist: syrupy; extra == 'test'
|
|
37
|
+
Requires-Dist: egglog[array]; extra == 'test'
|
|
38
|
+
Requires-Dist: pytest-codspeed; extra == 'test'
|
|
39
|
+
Requires-Dist: pytest-benchmark; extra == 'test'
|
|
40
|
+
Requires-Dist: pytest-xdist; extra == 'test'
|
|
41
|
+
Requires-Dist: scikit-learn; extra == 'array'
|
|
42
|
+
Requires-Dist: array_api_compat; extra == 'array'
|
|
43
|
+
Requires-Dist: numba==0.59.1; extra == 'array'
|
|
44
|
+
Requires-Dist: llvmlite==0.42.0; extra == 'array'
|
|
45
|
+
Requires-Dist: pre-commit; extra == 'dev'
|
|
46
|
+
Requires-Dist: ruff; extra == 'dev'
|
|
47
|
+
Requires-Dist: mypy; extra == 'dev'
|
|
48
|
+
Requires-Dist: anywidget[dev]; extra == 'dev'
|
|
49
|
+
Requires-Dist: egglog[docs,test]; extra == 'dev'
|
|
50
|
+
Provides-Extra: docs
|
|
51
|
+
Provides-Extra: test
|
|
50
52
|
Provides-Extra: array
|
|
51
53
|
Provides-Extra: dev
|
|
52
|
-
Provides-Extra: test
|
|
53
|
-
Provides-Extra: docs
|
|
54
54
|
License-File: LICENSE
|
|
55
55
|
Summary: e-graphs in Python built around the the egglog rust library
|
|
56
56
|
License: MIT
|
|
@@ -1,16 +1,17 @@
|
|
|
1
|
-
egglog-7.
|
|
2
|
-
egglog-7.
|
|
3
|
-
egglog-7.
|
|
1
|
+
egglog-7.2.0.dist-info/METADATA,sha256=O3Uq5oIIHWZ-bG2XVILykl-0dYyuNuoR9Dxj38v6bXs,3829
|
|
2
|
+
egglog-7.2.0.dist-info/WHEEL,sha256=PI_yinHuPssCo943lUdZTaSBEwIUzDPKgEICR1imaRE,96
|
|
3
|
+
egglog-7.2.0.dist-info/license_files/LICENSE,sha256=TfaboMVZ81Q6OUaKjU7z6uVjSlcGKclLYcOpgDbm9_s,1091
|
|
4
4
|
egglog/bindings.pyi,sha256=iFdtYHqPjuVt46qh1IE81cOI-lbGgPISKQB-3ERNzhI,11770
|
|
5
5
|
egglog/builtins.py,sha256=p5oZLDleCgQC8G2Zp2EvmqsfVnpgmARqUHqosKSSKnQ,13120
|
|
6
6
|
egglog/config.py,sha256=mALVaxh7zmGrbuyzaVKVmYKcu1lF703QsKJw8AF7gSM,176
|
|
7
7
|
egglog/conversion.py,sha256=4JhocGd1_nwmFMVNCzDDwj6aWfhBFTX2hm_7Xc_DiUM,6328
|
|
8
|
-
egglog/declarations.py,sha256=
|
|
9
|
-
egglog/egraph.py,sha256=
|
|
10
|
-
egglog/egraph_state.py,sha256=
|
|
8
|
+
egglog/declarations.py,sha256=WpZ1yyw4vlWm3Xi_otPP2IU51IqNcDmJA5y20RcBFBQ,18583
|
|
9
|
+
egglog/egraph.py,sha256=Q3ZJhtzzWcsukuarT0vT48-DkfHe17adLLUHn-pILPE,68657
|
|
10
|
+
egglog/egraph_state.py,sha256=kXCNt4zenG1meRv5T6B8xI5oH_NreA06JMhmqv_F-og,21802
|
|
11
11
|
egglog/examples/bool.py,sha256=pWZTjfXR1cFy3KcihLBU5AF5rn83ImORlhUUJ1YiAXc,733
|
|
12
12
|
egglog/examples/eqsat_basic.py,sha256=ORXFYYEDsEZK2IPhHtoFsd-LdjMiQi1nn7kix4Nam0s,1011
|
|
13
13
|
egglog/examples/fib.py,sha256=wAn-PjazxgHDkXAU4o2xTk_GtM_iGL0biV66vWM1st4,520
|
|
14
|
+
egglog/examples/higher_order_functions.py,sha256=CBZqnqVdYGgztHw5QWtcnUhzs1nnBcw4O-gjSzAxjRc,1203
|
|
14
15
|
egglog/examples/lambda_.py,sha256=hQBOaSw_yorNcbkQVu2EhgSc0IZNWIny7asaOlcUk9s,8496
|
|
15
16
|
egglog/examples/matrix.py,sha256=_zmjgfFr2O_LjTcsTD-45_38Y_M1sP3AV39K6oFxAdw,5136
|
|
16
17
|
egglog/examples/ndarrays.py,sha256=T-wwef-n-3LDSjaO35zA8AZH5DXFFqq0XBSCQKEXV6E,4186
|
|
@@ -18,7 +19,7 @@ egglog/examples/README.rst,sha256=QrbfmivODBvUvmY3-dHarcbC6bEvwoqAfTDhiI-aJxU,23
|
|
|
18
19
|
egglog/examples/resolution.py,sha256=sKkbRI_v9XkQM0DriacKLINqKKDqYGFhvMCAS9tZbTA,2203
|
|
19
20
|
egglog/examples/schedule_demo.py,sha256=iJtIbcLaZ7zK8UalY0z7KAKMqYjQx0MKTsNF24lKtik,652
|
|
20
21
|
egglog/examples/__init__.py,sha256=KuhaJFOyz_rpUvEqZubsgLnv6rhQNE_AVFXA6bUnpdY,34
|
|
21
|
-
egglog/exp/array_api.py,sha256=
|
|
22
|
+
egglog/exp/array_api.py,sha256=b7MoUuH2sIW3KnUX7fGPaflwlCCskgrt1GsidVT-KG4,41474
|
|
22
23
|
egglog/exp/array_api_jit.py,sha256=HIZzd0G17u-u_F4vfRdhoYvRo-ETx5HFO3RBcOfLcco,1287
|
|
23
24
|
egglog/exp/array_api_numba.py,sha256=OjONBLdFRukx4vKiNK_rBxjMzsxbWpMEdD7JcGHJjmY,2924
|
|
24
25
|
egglog/exp/array_api_program_gen.py,sha256=crgUYXXNhQdfTq31FSIpWLIfzNsgQD8ngg3OosCtIgg,19680
|
|
@@ -27,13 +28,13 @@ egglog/exp/siu_examples.py,sha256=KZmpkSCgbL4uqHhx2Jh9Adz1_cDbkIlyXPa1Kh_0O_8,78
|
|
|
27
28
|
egglog/exp/__init__.py,sha256=G9zeKUcPBgIhgUg1meC86OfZVFETYIugyHWseTcCe_0,52
|
|
28
29
|
egglog/graphviz_widget.py,sha256=YtI7LCFWHihDQ1qLvdj2SVYEcuZLSooFUYheunOTxdc,1339
|
|
29
30
|
egglog/ipython_magic.py,sha256=VA19xAb6Sz7-IxlJBbnZW_gVFDqaYNnvdMB9QitndjE,1254
|
|
30
|
-
egglog/pretty.py,sha256=
|
|
31
|
+
egglog/pretty.py,sha256=s5XBIB7kMrxV25aO-F4NW5Jli2DOoc72wwswFYxJPjM,19561
|
|
31
32
|
egglog/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
|
-
egglog/runtime.py,sha256=
|
|
33
|
-
egglog/thunk.py,sha256=
|
|
34
|
-
egglog/type_constraint_solver.py,sha256=
|
|
33
|
+
egglog/runtime.py,sha256=s32xdoD_mIdDRi_PJQ5E_kSNyhUkT310osH38kZfCSk,23174
|
|
34
|
+
egglog/thunk.py,sha256=bm1zZUzwey0DijtMZkSulyfiWjCAg_1UmXaSG8qvvjs,2170
|
|
35
|
+
egglog/type_constraint_solver.py,sha256=MW22gwryNOqsYaB__dCCjVvdnB1Km5fovoW2IgUXaVo,5836
|
|
35
36
|
egglog/widget.css,sha256=WJS2M1wQdujhSTCakMa_ZXuoTPre1Uy1lPcvBE1LZQU,102
|
|
36
37
|
egglog/widget.js,sha256=UNOER3sYZ-bS7Qhw9S6qtpR81FuHa5DzXQaWWtQq334,2021
|
|
37
38
|
egglog/__init__.py,sha256=iUrVe5fb0XFyMCS3CwTjwhKEtU8KpIkdTpJpnUhm8o0,247
|
|
38
|
-
egglog/bindings.cp311-win_amd64.pyd,sha256=
|
|
39
|
-
egglog-7.
|
|
39
|
+
egglog/bindings.cp311-win_amd64.pyd,sha256=7EuZkPO4dxLNKrmTWuEZdHrX0eYDKEBwfdYSpgkD-94,4591104
|
|
40
|
+
egglog-7.2.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|