egglog 7.1.0__cp311-none-win_amd64.whl → 8.0.0__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +146 -8
- egglog/egraph.py +337 -203
- egglog/egraph_state.py +171 -64
- egglog/examples/higher_order_functions.py +45 -0
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +97 -43
- egglog/runtime.py +60 -44
- egglog/thunk.py +44 -20
- egglog/type_constraint_solver.py +5 -4
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/METADATA +31 -30
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.1.0.dist-info/RECORD +0 -39
- {egglog-7.1.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
egglog/egraph.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import contextlib
|
|
3
4
|
import inspect
|
|
4
5
|
import pathlib
|
|
5
6
|
import tempfile
|
|
6
7
|
from abc import abstractmethod
|
|
7
|
-
from collections.abc import Callable, Iterable
|
|
8
|
+
from collections.abc import Callable, Generator, Iterable
|
|
8
9
|
from contextvars import ContextVar, Token
|
|
9
10
|
from dataclasses import InitVar, dataclass, field
|
|
10
11
|
from inspect import Parameter, currentframe, signature
|
|
@@ -37,8 +38,6 @@ from .runtime import *
|
|
|
37
38
|
from .thunk import *
|
|
38
39
|
|
|
39
40
|
if TYPE_CHECKING:
|
|
40
|
-
import ipywidgets
|
|
41
|
-
|
|
42
41
|
from .builtins import Bool, PyObject, String, f64, i64
|
|
43
42
|
|
|
44
43
|
|
|
@@ -352,7 +351,7 @@ class _BaseModule:
|
|
|
352
351
|
This is the same as defining a nullary function with a high cost.
|
|
353
352
|
# TODO: Rename as declare to match eggglog?
|
|
354
353
|
"""
|
|
355
|
-
return constant(name, tp, egg_name)
|
|
354
|
+
return constant(name, tp, egg_name=egg_name)
|
|
356
355
|
|
|
357
356
|
def register(
|
|
358
357
|
self,
|
|
@@ -452,6 +451,7 @@ class _ExprMetaclass(type):
|
|
|
452
451
|
namespace: dict[str, Any],
|
|
453
452
|
egg_sort: str | None = None,
|
|
454
453
|
builtin: bool = False,
|
|
454
|
+
ruleset: Ruleset | None = None,
|
|
455
455
|
) -> RuntimeClass | type:
|
|
456
456
|
# If this is the Expr subclass, just return the class
|
|
457
457
|
if not bases:
|
|
@@ -463,20 +463,37 @@ class _ExprMetaclass(type):
|
|
|
463
463
|
prev_frame = frame.f_back
|
|
464
464
|
assert prev_frame
|
|
465
465
|
|
|
466
|
+
# Pass in an instance of the class so that when we are generating the decls
|
|
467
|
+
# we can update them eagerly so that we can access the methods in the class body
|
|
468
|
+
runtime_cls = RuntimeClass(None, TypeRefWithVars(name)) # type: ignore[arg-type]
|
|
469
|
+
|
|
466
470
|
# Store frame so that we can get live access to updated locals/globals
|
|
467
471
|
# Otherwise, f_locals returns a copy
|
|
468
472
|
# https://peps.python.org/pep-0667/
|
|
469
|
-
|
|
470
|
-
_generate_class_decls,
|
|
473
|
+
runtime_cls.__egg_decls_thunk__ = Thunk.fn(
|
|
474
|
+
_generate_class_decls,
|
|
475
|
+
namespace,
|
|
476
|
+
prev_frame,
|
|
477
|
+
builtin,
|
|
478
|
+
egg_sort,
|
|
479
|
+
name,
|
|
480
|
+
ruleset,
|
|
481
|
+
runtime_cls,
|
|
471
482
|
)
|
|
472
|
-
return
|
|
483
|
+
return runtime_cls
|
|
473
484
|
|
|
474
485
|
def __instancecheck__(cls, instance: object) -> bool:
|
|
475
486
|
return isinstance(instance, RuntimeExpr)
|
|
476
487
|
|
|
477
488
|
|
|
478
489
|
def _generate_class_decls( # noqa: C901
|
|
479
|
-
namespace: dict[str, Any],
|
|
490
|
+
namespace: dict[str, Any],
|
|
491
|
+
frame: FrameType,
|
|
492
|
+
builtin: bool,
|
|
493
|
+
egg_sort: str | None,
|
|
494
|
+
cls_name: str,
|
|
495
|
+
ruleset: Ruleset | None,
|
|
496
|
+
runtime_cls: RuntimeClass,
|
|
480
497
|
) -> Declarations:
|
|
481
498
|
"""
|
|
482
499
|
Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
|
|
@@ -489,6 +506,8 @@ def _generate_class_decls( # noqa: C901
|
|
|
489
506
|
del parameters
|
|
490
507
|
cls_decl = ClassDecl(egg_sort, type_vars, builtin)
|
|
491
508
|
decls = Declarations(_classes={cls_name: cls_decl})
|
|
509
|
+
# Update class think eagerly when resolving so that lookups work in methods
|
|
510
|
+
runtime_cls.__egg_decls_thunk__ = Thunk.value(decls)
|
|
492
511
|
|
|
493
512
|
##
|
|
494
513
|
# Register class variables
|
|
@@ -498,9 +517,9 @@ def _generate_class_decls( # noqa: C901
|
|
|
498
517
|
for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
|
|
499
518
|
if getattr(v, "__origin__", None) == ClassVar:
|
|
500
519
|
(inner_tp,) = v.__args__
|
|
501
|
-
type_ref = resolve_type_annotation(decls, inner_tp)
|
|
502
|
-
cls_decl.class_variables[k] = ConstantDecl(type_ref)
|
|
503
|
-
|
|
520
|
+
type_ref = resolve_type_annotation(decls, inner_tp)
|
|
521
|
+
cls_decl.class_variables[k] = ConstantDecl(type_ref.to_just())
|
|
522
|
+
_add_default_rewrite(decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset)
|
|
504
523
|
else:
|
|
505
524
|
msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
|
|
506
525
|
raise NotImplementedError(msg)
|
|
@@ -509,14 +528,13 @@ def _generate_class_decls( # noqa: C901
|
|
|
509
528
|
# Register methods, classmethods, preserved methods, and properties
|
|
510
529
|
##
|
|
511
530
|
|
|
512
|
-
# The type ref of self is paramterized by the type vars
|
|
513
|
-
slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
514
|
-
|
|
515
531
|
# Get all the methods from the class
|
|
516
532
|
filtered_namespace: list[tuple[str, Any]] = [
|
|
517
533
|
(k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
|
|
518
534
|
]
|
|
519
535
|
|
|
536
|
+
# all methods we should try adding default functions for
|
|
537
|
+
add_default_funcs: list[Callable[[], None]] = []
|
|
520
538
|
# Then register each of its methods
|
|
521
539
|
for method_name, method in filtered_namespace:
|
|
522
540
|
is_init = method_name == "__init__"
|
|
@@ -535,44 +553,35 @@ def _generate_class_decls( # noqa: C901
|
|
|
535
553
|
cls_decl.preserved_methods[method_name] = fn
|
|
536
554
|
continue
|
|
537
555
|
locals = frame.f_locals
|
|
538
|
-
|
|
539
|
-
def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
|
|
540
|
-
special_function_name: SpecialFunctions | None = (
|
|
541
|
-
"fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None # noqa: B023
|
|
542
|
-
)
|
|
543
|
-
if special_function_name:
|
|
544
|
-
return FunctionDecl(
|
|
545
|
-
special_function_name,
|
|
546
|
-
builtin=True,
|
|
547
|
-
egg_name=egg_fn, # noqa: B023
|
|
548
|
-
)
|
|
549
|
-
|
|
550
|
-
return _fn_decl(
|
|
551
|
-
decls,
|
|
552
|
-
egg_fn, # noqa: B023
|
|
553
|
-
fn,
|
|
554
|
-
locals, # noqa: B023
|
|
555
|
-
default, # noqa: B023
|
|
556
|
-
cost, # noqa: B023
|
|
557
|
-
merge, # noqa: B023
|
|
558
|
-
on_merge, # noqa: B023
|
|
559
|
-
mutates, # noqa: B023
|
|
560
|
-
builtin,
|
|
561
|
-
first,
|
|
562
|
-
is_init, # noqa: B023
|
|
563
|
-
unextractable, # noqa: B023
|
|
564
|
-
)
|
|
565
|
-
|
|
556
|
+
ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
|
|
566
557
|
match fn:
|
|
567
558
|
case classmethod():
|
|
568
|
-
|
|
559
|
+
ref = ClassMethodRef(cls_name, method_name)
|
|
560
|
+
fn = fn.__func__
|
|
569
561
|
case property():
|
|
570
|
-
|
|
562
|
+
ref = PropertyRef(cls_name, method_name)
|
|
563
|
+
fn = fn.fget
|
|
571
564
|
case _:
|
|
572
|
-
if is_init
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
565
|
+
ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
|
|
566
|
+
special_function_name: SpecialFunctions | None = (
|
|
567
|
+
"fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None
|
|
568
|
+
)
|
|
569
|
+
if special_function_name:
|
|
570
|
+
decl = FunctionDecl(special_function_name, builtin=True, egg_name=egg_fn)
|
|
571
|
+
decls.set_function_decl(ref, decl)
|
|
572
|
+
continue
|
|
573
|
+
|
|
574
|
+
_, add_rewrite = _fn_decl(
|
|
575
|
+
decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, ruleset, unextractable
|
|
576
|
+
)
|
|
577
|
+
|
|
578
|
+
if not builtin and not isinstance(ref, InitRef) and not mutates:
|
|
579
|
+
add_default_funcs.append(add_rewrite)
|
|
580
|
+
|
|
581
|
+
# Add all rewrite methods at the end so that all methods are registered first and can be accessed
|
|
582
|
+
# in the bodies
|
|
583
|
+
for add_rewrite in add_default_funcs:
|
|
584
|
+
add_rewrite()
|
|
576
585
|
|
|
577
586
|
return decls
|
|
578
587
|
|
|
@@ -591,6 +600,8 @@ def function(
|
|
|
591
600
|
mutates_first_arg: bool = False,
|
|
592
601
|
unextractable: bool = False,
|
|
593
602
|
builtin: bool = False,
|
|
603
|
+
ruleset: Ruleset | None = None,
|
|
604
|
+
use_body_as_name: bool = False,
|
|
594
605
|
) -> Callable[[CALLABLE], CALLABLE]: ...
|
|
595
606
|
|
|
596
607
|
|
|
@@ -604,6 +615,8 @@ def function(
|
|
|
604
615
|
on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
|
|
605
616
|
mutates_first_arg: bool = False,
|
|
606
617
|
unextractable: bool = False,
|
|
618
|
+
ruleset: Ruleset | None = None,
|
|
619
|
+
use_body_as_name: bool = False,
|
|
607
620
|
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
|
|
608
621
|
|
|
609
622
|
|
|
@@ -634,15 +647,19 @@ class _FunctionConstructor:
|
|
|
634
647
|
merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
|
|
635
648
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
|
|
636
649
|
unextractable: bool = False
|
|
650
|
+
ruleset: Ruleset | None = None
|
|
651
|
+
use_body_as_name: bool = False
|
|
637
652
|
|
|
638
653
|
def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
|
|
639
|
-
return RuntimeFunction(Thunk.fn(self.create_decls, fn)
|
|
654
|
+
return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
|
|
640
655
|
|
|
641
|
-
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
|
|
656
|
+
def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, CallableRef]:
|
|
642
657
|
decls = Declarations()
|
|
643
|
-
|
|
658
|
+
ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
|
|
659
|
+
ref, add_rewrite = _fn_decl(
|
|
644
660
|
decls,
|
|
645
661
|
self.egg_fn,
|
|
662
|
+
ref,
|
|
646
663
|
fn,
|
|
647
664
|
self.hint_locals,
|
|
648
665
|
self.default,
|
|
@@ -651,14 +668,18 @@ class _FunctionConstructor:
|
|
|
651
668
|
self.on_merge,
|
|
652
669
|
self.mutates_first_arg,
|
|
653
670
|
self.builtin,
|
|
671
|
+
self.ruleset,
|
|
654
672
|
unextractable=self.unextractable,
|
|
655
673
|
)
|
|
656
|
-
|
|
674
|
+
add_rewrite()
|
|
675
|
+
return decls, ref
|
|
657
676
|
|
|
658
677
|
|
|
659
678
|
def _fn_decl(
|
|
660
679
|
decls: Declarations,
|
|
661
680
|
egg_name: str | None,
|
|
681
|
+
# If ref is Callable, then generate the ref from the function name
|
|
682
|
+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
|
|
662
683
|
fn: object,
|
|
663
684
|
# Pass in the locals, retrieved from the frame when wrapping,
|
|
664
685
|
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
|
|
@@ -669,11 +690,12 @@ def _fn_decl(
|
|
|
669
690
|
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
|
|
670
691
|
mutates_first_arg: bool,
|
|
671
692
|
is_builtin: bool,
|
|
672
|
-
|
|
673
|
-
first_arg: Literal["cls"] | TypeOrVarRef | None = None,
|
|
674
|
-
is_init: bool = False,
|
|
693
|
+
ruleset: Ruleset | None = None,
|
|
675
694
|
unextractable: bool = False,
|
|
676
|
-
) ->
|
|
695
|
+
) -> tuple[CallableRef, Callable[[], None]]:
|
|
696
|
+
"""
|
|
697
|
+
Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
|
|
698
|
+
"""
|
|
677
699
|
if not isinstance(fn, FunctionType):
|
|
678
700
|
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
|
|
679
701
|
|
|
@@ -687,8 +709,8 @@ def _fn_decl(
|
|
|
687
709
|
|
|
688
710
|
params = list(signature(fn).parameters.values())
|
|
689
711
|
|
|
690
|
-
# If this is an init function, or a classmethod,
|
|
691
|
-
if
|
|
712
|
+
# If this is an init function, or a classmethod, the first arg is not used
|
|
713
|
+
if isinstance(ref, ClassMethodRef | InitRef):
|
|
692
714
|
params = params[1:]
|
|
693
715
|
|
|
694
716
|
if _last_param_variable(params):
|
|
@@ -698,40 +720,37 @@ def _fn_decl(
|
|
|
698
720
|
else:
|
|
699
721
|
var_arg_type = None
|
|
700
722
|
arg_types = tuple(
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
if i == 0 and isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars) and not is_init
|
|
723
|
+
decls.get_paramaterized_class(ref.class_name)
|
|
724
|
+
if i == 0 and isinstance(ref, MethodRef | PropertyRef)
|
|
704
725
|
else resolve_type_annotation(decls, hints[t.name])
|
|
705
726
|
for i, t in enumerate(params)
|
|
706
727
|
)
|
|
707
728
|
|
|
708
729
|
# Resolve all default values as arg types
|
|
709
730
|
arg_defaults = [
|
|
710
|
-
resolve_literal(t, p.default) if p.default is not Parameter.empty else None
|
|
731
|
+
resolve_literal(t, p.default, Thunk.value(decls)) if p.default is not Parameter.empty else None
|
|
711
732
|
for (t, p) in zip(arg_types, params, strict=True)
|
|
712
733
|
]
|
|
713
734
|
|
|
714
735
|
decls.update(*arg_defaults)
|
|
715
736
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
else:
|
|
726
|
-
return_type = resolve_type_annotation(decls, hints["return"])
|
|
737
|
+
return_type = (
|
|
738
|
+
decls.get_paramaterized_class(ref.class_name)
|
|
739
|
+
if isinstance(ref, InitRef)
|
|
740
|
+
else arg_types[0]
|
|
741
|
+
if mutates_first_arg
|
|
742
|
+
else resolve_type_annotation(decls, hints["return"])
|
|
743
|
+
)
|
|
744
|
+
|
|
745
|
+
arg_names = tuple(t.name for t in params)
|
|
727
746
|
|
|
728
747
|
decls |= default
|
|
729
748
|
merged = (
|
|
730
749
|
None
|
|
731
750
|
if merge is None
|
|
732
751
|
else merge(
|
|
733
|
-
RuntimeExpr.
|
|
734
|
-
RuntimeExpr.
|
|
752
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
|
|
753
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
|
|
735
754
|
)
|
|
736
755
|
)
|
|
737
756
|
decls |= merged
|
|
@@ -741,28 +760,47 @@ def _fn_decl(
|
|
|
741
760
|
if on_merge is None
|
|
742
761
|
else _action_likes(
|
|
743
762
|
on_merge(
|
|
744
|
-
RuntimeExpr.
|
|
745
|
-
RuntimeExpr.
|
|
763
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))),
|
|
764
|
+
RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))),
|
|
746
765
|
)
|
|
747
766
|
)
|
|
748
767
|
)
|
|
749
768
|
decls.update(*merge_action)
|
|
750
|
-
|
|
751
|
-
|
|
769
|
+
# defer this in generator so it doesnt resolve for builtins eagerly
|
|
770
|
+
args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
|
|
771
|
+
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
|
|
772
|
+
res_thunk: Callable[[], object]
|
|
773
|
+
# If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
|
|
774
|
+
if not ref:
|
|
775
|
+
tuple_args = tuple(args)
|
|
776
|
+
res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
|
|
777
|
+
assert isinstance(res, RuntimeExpr)
|
|
778
|
+
res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
|
|
779
|
+
decls._unnamed_functions.add(res_ref)
|
|
780
|
+
res_thunk = Thunk.value(res)
|
|
781
|
+
|
|
782
|
+
else:
|
|
783
|
+
signature_ = FunctionSignature(
|
|
752
784
|
return_type=None if mutates_first_arg else return_type,
|
|
753
785
|
var_arg_type=var_arg_type,
|
|
754
786
|
arg_types=arg_types,
|
|
755
|
-
arg_names=
|
|
787
|
+
arg_names=arg_names,
|
|
756
788
|
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
|
|
757
|
-
)
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
789
|
+
)
|
|
790
|
+
decl = FunctionDecl(
|
|
791
|
+
signature=signature_,
|
|
792
|
+
cost=cost,
|
|
793
|
+
egg_name=egg_name,
|
|
794
|
+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
|
|
795
|
+
unextractable=unextractable,
|
|
796
|
+
builtin=is_builtin,
|
|
797
|
+
default=None if default is None else default.__egg_typed_expr__.expr,
|
|
798
|
+
on_merge=tuple(a.action for a in merge_action),
|
|
799
|
+
)
|
|
800
|
+
res_ref = ref
|
|
801
|
+
decls.set_function_decl(ref, decl)
|
|
802
|
+
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
|
|
803
|
+
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk)
|
|
766
804
|
|
|
767
805
|
|
|
768
806
|
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
|
|
@@ -793,7 +831,7 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
|
|
|
793
831
|
Creates a function whose return type is `Unit` and has a default value.
|
|
794
832
|
"""
|
|
795
833
|
decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
|
|
796
|
-
return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
|
|
834
|
+
return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, Thunk.value(FunctionRef(name))))
|
|
797
835
|
|
|
798
836
|
|
|
799
837
|
def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
|
|
@@ -804,19 +842,85 @@ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Dec
|
|
|
804
842
|
return decls
|
|
805
843
|
|
|
806
844
|
|
|
807
|
-
def constant(
|
|
845
|
+
def constant(
|
|
846
|
+
name: str,
|
|
847
|
+
tp: type[EXPR],
|
|
848
|
+
default_replacement: EXPR | None = None,
|
|
849
|
+
/,
|
|
850
|
+
*,
|
|
851
|
+
egg_name: str | None = None,
|
|
852
|
+
ruleset: Ruleset | None = None,
|
|
853
|
+
) -> EXPR:
|
|
808
854
|
"""
|
|
809
855
|
A "constant" is implemented as the instantiation of a value that takes no args.
|
|
810
856
|
This creates a function with `name` and return type `tp` and returns a value of it being called.
|
|
811
857
|
"""
|
|
812
|
-
return cast(
|
|
858
|
+
return cast(
|
|
859
|
+
EXPR, RuntimeExpr(*split_thunk(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
|
|
860
|
+
)
|
|
813
861
|
|
|
814
862
|
|
|
815
|
-
def _constant_thunk(
|
|
863
|
+
def _constant_thunk(
|
|
864
|
+
name: str, tp: type, egg_name: str | None, default_replacement: object, ruleset: Ruleset | None
|
|
865
|
+
) -> tuple[Declarations, TypedExprDecl]:
|
|
816
866
|
decls = Declarations()
|
|
817
|
-
type_ref = resolve_type_annotation(decls, tp)
|
|
818
|
-
|
|
819
|
-
|
|
867
|
+
type_ref = resolve_type_annotation(decls, tp)
|
|
868
|
+
callable_ref = ConstantRef(name)
|
|
869
|
+
decls._constants[name] = ConstantDecl(type_ref.to_just(), egg_name)
|
|
870
|
+
_add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset)
|
|
871
|
+
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
|
|
872
|
+
|
|
873
|
+
|
|
874
|
+
def _create_default_value(
|
|
875
|
+
decls: Declarations,
|
|
876
|
+
ref: CallableRef | None,
|
|
877
|
+
fn: Callable,
|
|
878
|
+
args: Iterable[TypedExprDecl],
|
|
879
|
+
ruleset: Ruleset | None,
|
|
880
|
+
) -> object:
|
|
881
|
+
args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
|
|
882
|
+
|
|
883
|
+
# If this is a classmethod, add the class as the first arg
|
|
884
|
+
if isinstance(ref, ClassMethodRef):
|
|
885
|
+
tp = decls.get_paramaterized_class(ref.class_name)
|
|
886
|
+
args.insert(0, RuntimeClass(Thunk.value(decls), tp))
|
|
887
|
+
with set_current_ruleset(ruleset):
|
|
888
|
+
return fn(*args)
|
|
889
|
+
|
|
890
|
+
|
|
891
|
+
def _add_default_rewrite_function(
|
|
892
|
+
decls: Declarations,
|
|
893
|
+
ref: CallableRef,
|
|
894
|
+
res_type: TypeOrVarRef,
|
|
895
|
+
ruleset: Ruleset | None,
|
|
896
|
+
value_thunk: Callable[[], object],
|
|
897
|
+
) -> None:
|
|
898
|
+
"""
|
|
899
|
+
Helper functions that resolves a value thunk to create the default value.
|
|
900
|
+
"""
|
|
901
|
+
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset)
|
|
902
|
+
|
|
903
|
+
|
|
904
|
+
def _add_default_rewrite(
|
|
905
|
+
decls: Declarations, ref: CallableRef, type_ref: TypeOrVarRef, default_rewrite: object, ruleset: Ruleset | None
|
|
906
|
+
) -> None:
|
|
907
|
+
"""
|
|
908
|
+
Adds a default rewrite for the callable, if the default rewrite is not None
|
|
909
|
+
|
|
910
|
+
Will add it to the ruleset if it is passed in, or add it to the default ruleset on the passed in decls if not.
|
|
911
|
+
"""
|
|
912
|
+
if default_rewrite is None:
|
|
913
|
+
return
|
|
914
|
+
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
|
|
915
|
+
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
|
|
916
|
+
if ruleset:
|
|
917
|
+
ruleset_decls = ruleset._current_egg_decls
|
|
918
|
+
ruleset_decl = ruleset.__egg_ruleset__
|
|
919
|
+
else:
|
|
920
|
+
ruleset_decls = decls
|
|
921
|
+
ruleset_decl = decls.default_ruleset
|
|
922
|
+
ruleset_decl.rules.append(rewrite_decl)
|
|
923
|
+
ruleset_decls |= resolved_value
|
|
820
924
|
|
|
821
925
|
|
|
822
926
|
def _last_param_variable(params: list[Parameter]) -> bool:
|
|
@@ -866,6 +970,8 @@ class GraphvizKwargs(TypedDict, total=False):
|
|
|
866
970
|
max_calls_per_function: int | None
|
|
867
971
|
n_inline_leaves: int
|
|
868
972
|
split_primitive_outputs: bool
|
|
973
|
+
split_functions: list[object]
|
|
974
|
+
include_temporary_functions: bool
|
|
869
975
|
|
|
870
976
|
|
|
871
977
|
@dataclass
|
|
@@ -908,81 +1014,19 @@ class EGraph(_BaseModule):
|
|
|
908
1014
|
raise ValueError(msg)
|
|
909
1015
|
return cmds
|
|
910
1016
|
|
|
911
|
-
def
|
|
912
|
-
|
|
913
|
-
Returns the graphviz representation of the e-graph.
|
|
914
|
-
"""
|
|
915
|
-
return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")}
|
|
916
|
-
|
|
917
|
-
def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
|
|
918
|
-
# By default we want to split primitive outputs
|
|
919
|
-
kwargs.setdefault("split_primitive_outputs", True)
|
|
920
|
-
n_inline = kwargs.pop("n_inline_leaves", 0)
|
|
921
|
-
serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
|
|
922
|
-
serialized.map_ops(self._state.op_mapping())
|
|
923
|
-
for _ in range(n_inline):
|
|
924
|
-
serialized.inline_leaves()
|
|
925
|
-
original = serialized.to_dot()
|
|
926
|
-
# Add link to stylesheet to the graph, so that edges light up on hover
|
|
927
|
-
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
|
|
928
|
-
styles = """/* the lines within the edges */
|
|
929
|
-
.edge:active path,
|
|
930
|
-
.edge:hover path {
|
|
931
|
-
stroke: fuchsia;
|
|
932
|
-
stroke-width: 3;
|
|
933
|
-
stroke-opacity: 1;
|
|
934
|
-
}
|
|
935
|
-
/* arrows are typically drawn with a polygon */
|
|
936
|
-
.edge:active polygon,
|
|
937
|
-
.edge:hover polygon {
|
|
938
|
-
stroke: fuchsia;
|
|
939
|
-
stroke-width: 3;
|
|
940
|
-
fill: fuchsia;
|
|
941
|
-
stroke-opacity: 1;
|
|
942
|
-
fill-opacity: 1;
|
|
943
|
-
}
|
|
944
|
-
/* If you happen to have text and want to color that as well... */
|
|
945
|
-
.edge:active text,
|
|
946
|
-
.edge:hover text {
|
|
947
|
-
fill: fuchsia;
|
|
948
|
-
}"""
|
|
949
|
-
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
|
|
950
|
-
p.write_text(styles)
|
|
951
|
-
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
|
|
952
|
-
return graphviz.Source(with_stylesheet)
|
|
953
|
-
|
|
954
|
-
def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str:
|
|
955
|
-
return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
|
|
956
|
-
|
|
957
|
-
def _repr_html_(self) -> str:
|
|
958
|
-
"""
|
|
959
|
-
Add a _repr_html_ to be an SVG to work with sphinx gallery.
|
|
960
|
-
|
|
961
|
-
ala https://github.com/xflr6/graphviz/pull/121
|
|
962
|
-
until this PR is merged and released
|
|
963
|
-
https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
|
|
964
|
-
"""
|
|
965
|
-
return self.graphviz_svg()
|
|
966
|
-
|
|
967
|
-
def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None:
|
|
968
|
-
"""
|
|
969
|
-
Displays the e-graph in the notebook.
|
|
970
|
-
"""
|
|
971
|
-
if IN_IPYTHON:
|
|
972
|
-
from IPython.display import SVG, display
|
|
973
|
-
|
|
974
|
-
display(SVG(self.graphviz_svg(**kwargs)))
|
|
975
|
-
else:
|
|
976
|
-
self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
|
|
1017
|
+
def _ipython_display_(self) -> None:
|
|
1018
|
+
self.display()
|
|
977
1019
|
|
|
978
1020
|
def input(self, fn: Callable[..., String], path: str) -> None:
|
|
979
1021
|
"""
|
|
980
1022
|
Loads a CSV file and sets it as *input, output of the function.
|
|
981
1023
|
"""
|
|
1024
|
+
self._egraph.run_program(bindings.Input(self._callable_to_egg(fn), path))
|
|
1025
|
+
|
|
1026
|
+
def _callable_to_egg(self, fn: object) -> str:
|
|
982
1027
|
ref, decls = resolve_callable(fn)
|
|
983
1028
|
self._add_decls(decls)
|
|
984
|
-
|
|
985
|
-
self._egraph.run_program(bindings.Input(fn_name, path))
|
|
1029
|
+
return self._state.callable_ref_to_egg(ref)
|
|
986
1030
|
|
|
987
1031
|
def let(self, name: str, expr: EXPR) -> EXPR:
|
|
988
1032
|
"""
|
|
@@ -991,10 +1035,11 @@ class EGraph(_BaseModule):
|
|
|
991
1035
|
action = let(name, expr)
|
|
992
1036
|
self.register(action)
|
|
993
1037
|
runtime_expr = to_runtime_expr(expr)
|
|
1038
|
+
self._add_decls(runtime_expr)
|
|
994
1039
|
return cast(
|
|
995
1040
|
EXPR,
|
|
996
|
-
RuntimeExpr.
|
|
997
|
-
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
|
|
1041
|
+
RuntimeExpr.__from_values__(
|
|
1042
|
+
self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name, True))
|
|
998
1043
|
),
|
|
999
1044
|
)
|
|
1000
1045
|
|
|
@@ -1016,14 +1061,15 @@ class EGraph(_BaseModule):
|
|
|
1016
1061
|
self._add_decls(runtime_expr, schedule)
|
|
1017
1062
|
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
|
|
1018
1063
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1019
|
-
|
|
1064
|
+
# Must also register type
|
|
1065
|
+
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1020
1066
|
self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
|
|
1021
1067
|
extract_report = self._egraph.extract_report()
|
|
1022
1068
|
if not isinstance(extract_report, bindings.Best):
|
|
1023
1069
|
msg = "No extract report saved"
|
|
1024
1070
|
raise ValueError(msg) # noqa: TRY004
|
|
1025
1071
|
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1026
|
-
return cast(EXPR, RuntimeExpr.
|
|
1072
|
+
return cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
|
|
1027
1073
|
|
|
1028
1074
|
def include(self, path: str) -> None:
|
|
1029
1075
|
"""
|
|
@@ -1078,7 +1124,7 @@ class EGraph(_BaseModule):
|
|
|
1078
1124
|
facts = _fact_likes(fact_likes)
|
|
1079
1125
|
self._add_decls(*facts)
|
|
1080
1126
|
egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
|
|
1081
|
-
return bindings.Check(egg_facts)
|
|
1127
|
+
return bindings.Check(bindings.DUMMY_SPAN, egg_facts)
|
|
1082
1128
|
|
|
1083
1129
|
@overload
|
|
1084
1130
|
def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: ...
|
|
@@ -1100,7 +1146,7 @@ class EGraph(_BaseModule):
|
|
|
1100
1146
|
raise ValueError(msg) # noqa: TRY004
|
|
1101
1147
|
(new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
|
|
1102
1148
|
|
|
1103
|
-
res = cast(EXPR, RuntimeExpr.
|
|
1149
|
+
res = cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
|
|
1104
1150
|
if include_cost:
|
|
1105
1151
|
return res, extract_report.cost
|
|
1106
1152
|
return res
|
|
@@ -1118,12 +1164,15 @@ class EGraph(_BaseModule):
|
|
|
1118
1164
|
msg = "Wrong extract report type"
|
|
1119
1165
|
raise ValueError(msg) # noqa: TRY004
|
|
1120
1166
|
new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
|
|
1121
|
-
return [cast(EXPR, RuntimeExpr.
|
|
1167
|
+
return [cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
|
|
1122
1168
|
|
|
1123
1169
|
def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
|
|
1124
|
-
self._state.
|
|
1125
|
-
|
|
1126
|
-
|
|
1170
|
+
expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1171
|
+
self._egraph.run_program(
|
|
1172
|
+
bindings.ActionCommand(
|
|
1173
|
+
bindings.Extract(bindings.DUMMY_SPAN, expr, bindings.Lit(bindings.DUMMY_SPAN, bindings.Int(n)))
|
|
1174
|
+
)
|
|
1175
|
+
)
|
|
1127
1176
|
extract_report = self._egraph.extract_report()
|
|
1128
1177
|
if not extract_report:
|
|
1129
1178
|
msg = "No extract report saved"
|
|
@@ -1181,7 +1230,7 @@ class EGraph(_BaseModule):
|
|
|
1181
1230
|
runtime_expr = to_runtime_expr(expr)
|
|
1182
1231
|
self._add_decls(runtime_expr)
|
|
1183
1232
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1184
|
-
egg_expr = self._state.
|
|
1233
|
+
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1185
1234
|
match typed_expr.tp:
|
|
1186
1235
|
case JustTypeRef("i64"):
|
|
1187
1236
|
return self._egraph.eval_i64(egg_expr)
|
|
@@ -1195,42 +1244,110 @@ class EGraph(_BaseModule):
|
|
|
1195
1244
|
return self._egraph.eval_py_object(egg_expr)
|
|
1196
1245
|
raise TypeError(f"Eval not implemented for {typed_expr.tp}")
|
|
1197
1246
|
|
|
1198
|
-
def
|
|
1199
|
-
self,
|
|
1200
|
-
|
|
1201
|
-
|
|
1247
|
+
def _serialize(
|
|
1248
|
+
self,
|
|
1249
|
+
**kwargs: Unpack[GraphvizKwargs],
|
|
1250
|
+
) -> bindings.SerializedEGraph:
|
|
1251
|
+
max_functions = kwargs.pop("max_functions", None)
|
|
1252
|
+
max_calls_per_function = kwargs.pop("max_calls_per_function", None)
|
|
1253
|
+
split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
|
|
1254
|
+
split_functions = kwargs.pop("split_functions", [])
|
|
1255
|
+
include_temporary_functions = kwargs.pop("include_temporary_functions", False)
|
|
1256
|
+
n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
|
|
1257
|
+
serialized = self._egraph.serialize(
|
|
1258
|
+
[],
|
|
1259
|
+
max_functions=max_functions,
|
|
1260
|
+
max_calls_per_function=max_calls_per_function,
|
|
1261
|
+
include_temporary_functions=include_temporary_functions,
|
|
1262
|
+
)
|
|
1263
|
+
if split_primitive_outputs or split_functions:
|
|
1264
|
+
additional_ops = set(map(self._callable_to_egg, split_functions))
|
|
1265
|
+
serialized.split_e_classes(self._egraph, additional_ops)
|
|
1266
|
+
serialized.map_ops(self._state.op_mapping())
|
|
1202
1267
|
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
while self.run(1).updated and i < max:
|
|
1206
|
-
i += 1
|
|
1207
|
-
dots.append(str(self.graphviz(**kwargs)))
|
|
1208
|
-
return graphviz_widget_with_slider(dots, performance=performance)
|
|
1268
|
+
for _ in range(n_inline_leaves):
|
|
1269
|
+
serialized.inline_leaves()
|
|
1209
1270
|
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
) ->
|
|
1213
|
-
|
|
1271
|
+
return serialized
|
|
1272
|
+
|
|
1273
|
+
def _graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
|
|
1274
|
+
serialized = self._serialize(**kwargs)
|
|
1275
|
+
|
|
1276
|
+
original = serialized.to_dot()
|
|
1277
|
+
# Add link to stylesheet to the graph, so that edges light up on hover
|
|
1278
|
+
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
|
|
1279
|
+
styles = """/* the lines within the edges */
|
|
1280
|
+
.edge:active path,
|
|
1281
|
+
.edge:hover path {
|
|
1282
|
+
stroke: fuchsia;
|
|
1283
|
+
stroke-width: 3;
|
|
1284
|
+
stroke-opacity: 1;
|
|
1285
|
+
}
|
|
1286
|
+
/* arrows are typically drawn with a polygon */
|
|
1287
|
+
.edge:active polygon,
|
|
1288
|
+
.edge:hover polygon {
|
|
1289
|
+
stroke: fuchsia;
|
|
1290
|
+
stroke-width: 3;
|
|
1291
|
+
fill: fuchsia;
|
|
1292
|
+
stroke-opacity: 1;
|
|
1293
|
+
fill-opacity: 1;
|
|
1294
|
+
}
|
|
1295
|
+
/* If you happen to have text and want to color that as well... */
|
|
1296
|
+
.edge:active text,
|
|
1297
|
+
.edge:hover text {
|
|
1298
|
+
fill: fuchsia;
|
|
1299
|
+
}"""
|
|
1300
|
+
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
|
|
1301
|
+
p.write_text(styles)
|
|
1302
|
+
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
|
|
1303
|
+
return graphviz.Source(with_stylesheet)
|
|
1214
1304
|
|
|
1215
|
-
|
|
1305
|
+
def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> None:
|
|
1306
|
+
"""
|
|
1307
|
+
Displays the e-graph.
|
|
1216
1308
|
|
|
1217
|
-
|
|
1309
|
+
If in IPython it will display it inline, otherwise it will write it to a file and open it.
|
|
1310
|
+
"""
|
|
1311
|
+
from IPython.display import SVG, display
|
|
1218
1312
|
|
|
1219
|
-
|
|
1220
|
-
# panel.panel(widget).save(file)
|
|
1313
|
+
from .visualizer_widget import VisualizerWidget
|
|
1221
1314
|
|
|
1222
|
-
|
|
1315
|
+
if graphviz:
|
|
1316
|
+
if IN_IPYTHON:
|
|
1317
|
+
svg = self._graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
|
|
1318
|
+
display(SVG(svg))
|
|
1319
|
+
else:
|
|
1320
|
+
self._graphviz(**kwargs).render(view=True, format="svg", quiet=True)
|
|
1321
|
+
else:
|
|
1322
|
+
serialized = self._serialize(**kwargs)
|
|
1323
|
+
VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()
|
|
1223
1324
|
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1325
|
+
def saturate(self, schedule: Schedule | None = None, *, max: int = 1000, **kwargs: Unpack[GraphvizKwargs]) -> None:
|
|
1326
|
+
"""
|
|
1327
|
+
Saturate the egraph, running the given schedule until the egraph is saturated.
|
|
1328
|
+
It serializes the egraph at each step and returns a widget to visualize the egraph.
|
|
1329
|
+
"""
|
|
1330
|
+
from .visualizer_widget import VisualizerWidget
|
|
1331
|
+
|
|
1332
|
+
def to_json() -> str:
|
|
1333
|
+
return self._serialize(**kwargs).to_json()
|
|
1334
|
+
|
|
1335
|
+
egraphs = [to_json()]
|
|
1336
|
+
i = 0
|
|
1337
|
+
while self.run(schedule or 1).updated and i < max:
|
|
1338
|
+
i += 1
|
|
1339
|
+
egraphs.append(to_json())
|
|
1340
|
+
VisualizerWidget(egraphs=egraphs).display_or_open()
|
|
1227
1341
|
|
|
1228
1342
|
@classmethod
|
|
1229
1343
|
def current(cls) -> EGraph:
|
|
1230
1344
|
"""
|
|
1231
1345
|
Returns the current egraph, which is the one in the context.
|
|
1232
1346
|
"""
|
|
1233
|
-
|
|
1347
|
+
try:
|
|
1348
|
+
return CURRENT_EGRAPH.get()
|
|
1349
|
+
except LookupError:
|
|
1350
|
+
return cls(save_egglog_string=True)
|
|
1234
1351
|
|
|
1235
1352
|
@property
|
|
1236
1353
|
def _egraph(self) -> bindings.EGraph:
|
|
@@ -1382,7 +1499,8 @@ class Ruleset(Schedule):
|
|
|
1382
1499
|
To return the egg decls, we go through our deferred rules and add any we haven't yet
|
|
1383
1500
|
"""
|
|
1384
1501
|
while self.deferred_rule_gens:
|
|
1385
|
-
|
|
1502
|
+
with set_current_ruleset(self):
|
|
1503
|
+
rules = self.deferred_rule_gens.pop()()
|
|
1386
1504
|
self._current_egg_decls.update(*rules)
|
|
1387
1505
|
self.__egg_ruleset__.rules.extend(r.decl for r in rules)
|
|
1388
1506
|
return self._current_egg_decls
|
|
@@ -1622,16 +1740,16 @@ def action_command(action: Action) -> Action:
|
|
|
1622
1740
|
return action
|
|
1623
1741
|
|
|
1624
1742
|
|
|
1625
|
-
def var(name: str, bound: type[
|
|
1743
|
+
def var(name: str, bound: type[T]) -> T:
|
|
1626
1744
|
"""Create a new variable with the given name and type."""
|
|
1627
|
-
return cast(
|
|
1745
|
+
return cast(T, _var(name, bound))
|
|
1628
1746
|
|
|
1629
1747
|
|
|
1630
1748
|
def _var(name: str, bound: object) -> RuntimeExpr:
|
|
1631
1749
|
"""Create a new variable with the given name and type."""
|
|
1632
1750
|
decls = Declarations()
|
|
1633
1751
|
type_ref = resolve_type_annotation(decls, bound)
|
|
1634
|
-
return RuntimeExpr.
|
|
1752
|
+
return RuntimeExpr.__from_values__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name, False)))
|
|
1635
1753
|
|
|
1636
1754
|
|
|
1637
1755
|
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
|
|
@@ -1724,7 +1842,7 @@ class _NeBuilder(Generic[EXPR]):
|
|
|
1724
1842
|
lhs = to_runtime_expr(self.lhs)
|
|
1725
1843
|
rhs = convert_to_same_type(rhs, lhs)
|
|
1726
1844
|
assert isinstance(Unit, RuntimeClass)
|
|
1727
|
-
res = RuntimeExpr.
|
|
1845
|
+
res = RuntimeExpr.__from_values__(
|
|
1728
1846
|
Declarations.create(Unit, lhs, rhs),
|
|
1729
1847
|
TypedExprDecl(
|
|
1730
1848
|
JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
|
|
@@ -1893,3 +2011,19 @@ def _fact_like(fact_like: FactLike) -> Fact:
|
|
|
1893
2011
|
if isinstance(fact_like, Expr):
|
|
1894
2012
|
return expr_fact(fact_like)
|
|
1895
2013
|
return fact_like
|
|
2014
|
+
|
|
2015
|
+
|
|
2016
|
+
_CURRENT_RULESET = ContextVar[Ruleset | None]("CURRENT_RULESET", default=None)
|
|
2017
|
+
|
|
2018
|
+
|
|
2019
|
+
def get_current_ruleset() -> Ruleset | None:
|
|
2020
|
+
return _CURRENT_RULESET.get()
|
|
2021
|
+
|
|
2022
|
+
|
|
2023
|
+
@contextlib.contextmanager
|
|
2024
|
+
def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
|
|
2025
|
+
token = _CURRENT_RULESET.set(r)
|
|
2026
|
+
try:
|
|
2027
|
+
yield
|
|
2028
|
+
finally:
|
|
2029
|
+
_CURRENT_RULESET.reset(token)
|