egglog 11.2.0__cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +13 -0
- egglog/bindings.cpython-314-x86_64-linux-gnu.so +0 -0
- egglog/bindings.pyi +734 -0
- egglog/builtins.py +1133 -0
- egglog/config.py +8 -0
- egglog/conversion.py +286 -0
- egglog/declarations.py +912 -0
- egglog/deconstruct.py +173 -0
- egglog/egraph.py +1875 -0
- egglog/egraph_state.py +680 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +3 -0
- egglog/examples/bignum.py +32 -0
- egglog/examples/bool.py +38 -0
- egglog/examples/eqsat_basic.py +44 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/higher_order_functions.py +42 -0
- egglog/examples/jointree.py +67 -0
- egglog/examples/lambda_.py +287 -0
- egglog/examples/matrix.py +175 -0
- egglog/examples/multiset.py +60 -0
- egglog/examples/ndarrays.py +144 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +34 -0
- egglog/exp/__init__.py +3 -0
- egglog/exp/array_api.py +2019 -0
- egglog/exp/array_api_jit.py +51 -0
- egglog/exp/array_api_loopnest.py +74 -0
- egglog/exp/array_api_numba.py +69 -0
- egglog/exp/array_api_program_gen.py +510 -0
- egglog/exp/program_gen.py +425 -0
- egglog/exp/siu_examples.py +32 -0
- egglog/ipython_magic.py +41 -0
- egglog/pretty.py +509 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +712 -0
- egglog/thunk.py +97 -0
- egglog/type_constraint_solver.py +113 -0
- egglog/version_compat.py +87 -0
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35777 -0
- egglog/visualizer_widget.py +39 -0
- egglog-11.2.0.dist-info/METADATA +74 -0
- egglog-11.2.0.dist-info/RECORD +46 -0
- egglog-11.2.0.dist-info/WHEEL +4 -0
- egglog-11.2.0.dist-info/licenses/LICENSE +21 -0
egglog/thunk.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import TYPE_CHECKING, Generic, TypeVar
|
|
5
|
+
|
|
6
|
+
from typing_extensions import TypeVarTuple, Unpack
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Public surface of this module.
__all__ = ["Thunk", "split_thunk"]

# Type parameters shared by the thunk helpers:
#   T  - result type of a thunk (also the first element of a split pair)
#   V  - second element of a split pair
#   TS - variadic positional-argument types stored in an unresolved thunk
T = TypeVar("T")
V = TypeVar("V")
TS = TypeVarTuple("TS")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def split_thunk(fn: Callable[[], tuple[T, V]]) -> tuple[Callable[[], T], Callable[[], V]]:
    """
    Split a zero-argument callable that returns a pair into two zero-argument callables.

    The first returned callable yields element 0 of ``fn()``; the second yields element 1.
    Each of them calls ``fn`` again on every invocation, so ``fn`` is expected to be cheap to
    call repeatedly (for example a ``Thunk``, which caches its result).
    """
    splitter = _Split(fn)
    first_half, second_half = splitter.left, splitter.right
    return first_half, second_half
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class _Split(Generic[T, V]):
    """
    Wrap a pair-producing callable and expose each element through its own bound method.
    """

    # Zero-argument callable returning a ``(first, second)`` pair; invoked by each accessor.
    fn: Callable[[], tuple[T, V]]

    def left(self) -> T:
        """Call ``fn`` and return element 0 of its result."""
        pair = self.fn()
        return pair[0]

    def right(self) -> V:
        """Call ``fn`` and return element 1 of its result."""
        pair = self.fn()
        return pair[1]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
|
|
36
|
+
class Thunk(Generic[T, Unpack[TS]]):
|
|
37
|
+
"""
|
|
38
|
+
Cached delayed function call.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
|
|
42
|
+
|
|
43
|
+
@classmethod
|
|
44
|
+
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], context: str | None = None) -> Thunk[T, Unpack[TS]]:
|
|
45
|
+
"""
|
|
46
|
+
Create a thunk based on some functions and some partial args.
|
|
47
|
+
|
|
48
|
+
If the function is called while it is being resolved recursively it will raise an exception.
|
|
49
|
+
"""
|
|
50
|
+
return cls(Unresolved(fn, args, context))
|
|
51
|
+
|
|
52
|
+
@classmethod
|
|
53
|
+
def value(cls, value: T) -> Thunk[T]:
|
|
54
|
+
return Thunk(Resolved(value))
|
|
55
|
+
|
|
56
|
+
def __call__(self) -> T:
|
|
57
|
+
match self.state:
|
|
58
|
+
case Resolved(value):
|
|
59
|
+
return value
|
|
60
|
+
case Unresolved(fn, args, context):
|
|
61
|
+
self.state = Resolving()
|
|
62
|
+
try:
|
|
63
|
+
res = fn(*args)
|
|
64
|
+
except Exception as e:
|
|
65
|
+
self.state = Error(e, context)
|
|
66
|
+
raise e from None
|
|
67
|
+
else:
|
|
68
|
+
self.state = Resolved(res)
|
|
69
|
+
return res
|
|
70
|
+
case Resolving():
|
|
71
|
+
msg = "Recursively resolving thunk"
|
|
72
|
+
raise ValueError(msg)
|
|
73
|
+
case Error(e):
|
|
74
|
+
raise e
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass
class Resolved(Generic[T]):
    """Thunk state after the function returned successfully."""

    # Result returned by the thunk's function; handed back on every call.
    value: T
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclass
class Unresolved(Generic[T, Unpack[TS]]):
    """Initial thunk state: a function and its arguments that have not been called yet."""

    # Function invoked the first time the thunk is called.
    fn: Callable[[Unpack[TS]], T]
    # Positional arguments passed to ``fn``.
    args: tuple[Unpack[TS]]
    # Optional description, stored alongside any exception ``fn`` raises.
    context: str | None
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@dataclass
class Resolving:
    """Marker state for a thunk whose function is currently running."""
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
class Error:
    """Thunk state after the function raised an exception."""

    # The exception raised by the thunk's function; re-raised on every later call.
    e: Exception
    # The context string given when the thunk was created, if any.
    context: str | None
|
|
egglog/type_constraint_solver.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Provides a class for solving type constraints."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from itertools import chain, repeat
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
from typing_extensions import assert_never
|
|
11
|
+
|
|
12
|
+
from .declarations import *
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from collections.abc import Collection, Iterable
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Public names exported by this module.
__all__ = ["TypeConstraintError", "TypeConstraintSolver"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TypeConstraintError(RuntimeError):
    """
    Raised when type constraints cannot be satisfied while inferring types,
    e.g. a mismatched type name or a type variable bound to two different types.
    """
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class TypeConstraintSolver:
|
|
27
|
+
"""
|
|
28
|
+
Given some typevars and types, solves the constraints to resolve the typevars.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
_decls: Declarations = field(repr=False)
|
|
32
|
+
# Mapping of class name to mapping of bound class typevar to type
|
|
33
|
+
_cls_typevar_index_to_type: defaultdict[str, dict[ClassTypeVarRef, JustTypeRef]] = field(
|
|
34
|
+
default_factory=lambda: defaultdict(dict)
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
def bind_class(self, ref: JustTypeRef) -> None:
|
|
38
|
+
"""
|
|
39
|
+
Bind the typevars of a class to the given types.
|
|
40
|
+
Used for a situation like Map[int, str].create().
|
|
41
|
+
"""
|
|
42
|
+
name = ref.name
|
|
43
|
+
cls_typevars = self._decls.get_class_decl(name).type_vars
|
|
44
|
+
if len(cls_typevars) != len(ref.args):
|
|
45
|
+
raise TypeConstraintError(f"Mismatch of typevars {cls_typevars} and {ref}")
|
|
46
|
+
bound_typevars = self._cls_typevar_index_to_type[name]
|
|
47
|
+
for i, arg in enumerate(ref.args):
|
|
48
|
+
bound_typevars[cls_typevars[i]] = arg
|
|
49
|
+
|
|
50
|
+
def infer_arg_types(
|
|
51
|
+
self,
|
|
52
|
+
fn_args: Collection[TypeOrVarRef],
|
|
53
|
+
fn_return: TypeOrVarRef,
|
|
54
|
+
fn_var_args: TypeOrVarRef | None,
|
|
55
|
+
return_: JustTypeRef,
|
|
56
|
+
cls_name: str | None,
|
|
57
|
+
) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...] | None]:
|
|
58
|
+
"""
|
|
59
|
+
Given a return type, infer the argument types. If there is a variable arg, it returns an infinite iterable.
|
|
60
|
+
|
|
61
|
+
Also returns the bound type params if the class name is passed in.
|
|
62
|
+
"""
|
|
63
|
+
self.infer_typevars(fn_return, return_, cls_name)
|
|
64
|
+
arg_types: Iterable[JustTypeRef] = [self.substitute_typevars(a, cls_name) for a in fn_args]
|
|
65
|
+
if fn_var_args:
|
|
66
|
+
# Need to be generator so it can be infinite for variable args
|
|
67
|
+
arg_types = chain(arg_types, repeat(self.substitute_typevars(fn_var_args, cls_name)))
|
|
68
|
+
bound_typevars = (
|
|
69
|
+
tuple(
|
|
70
|
+
v
|
|
71
|
+
# Sort by the index of the typevar in the class
|
|
72
|
+
for _, v in sorted(
|
|
73
|
+
self._cls_typevar_index_to_type[cls_name].items(),
|
|
74
|
+
key=lambda kv: self._decls.get_class_decl(cls_name).type_vars.index(kv[0]),
|
|
75
|
+
)
|
|
76
|
+
)
|
|
77
|
+
if cls_name
|
|
78
|
+
else None
|
|
79
|
+
)
|
|
80
|
+
return arg_types, bound_typevars
|
|
81
|
+
|
|
82
|
+
def infer_typevars(self, fn_arg: TypeOrVarRef, arg: JustTypeRef, cls_name: str | None = None) -> None:
|
|
83
|
+
match fn_arg:
|
|
84
|
+
case TypeRefWithVars(cls_name, fn_args):
|
|
85
|
+
if cls_name != arg.name:
|
|
86
|
+
raise TypeConstraintError(f"Expected {cls_name}, got {arg.name}")
|
|
87
|
+
for inner_fn_arg, inner_arg in zip(fn_args, arg.args, strict=True):
|
|
88
|
+
self.infer_typevars(inner_fn_arg, inner_arg, cls_name)
|
|
89
|
+
case ClassTypeVarRef():
|
|
90
|
+
if cls_name is None:
|
|
91
|
+
msg = "Cannot infer typevar without class name"
|
|
92
|
+
raise RuntimeError(msg)
|
|
93
|
+
|
|
94
|
+
class_typevars = self._cls_typevar_index_to_type[cls_name]
|
|
95
|
+
if fn_arg in class_typevars:
|
|
96
|
+
if class_typevars[fn_arg] != arg:
|
|
97
|
+
raise TypeConstraintError(f"Expected {class_typevars[fn_arg]}, got {arg}")
|
|
98
|
+
else:
|
|
99
|
+
class_typevars[fn_arg] = arg
|
|
100
|
+
case _:
|
|
101
|
+
assert_never(fn_arg)
|
|
102
|
+
|
|
103
|
+
def substitute_typevars(self, tp: TypeOrVarRef, cls_name: str | None = None) -> JustTypeRef:
|
|
104
|
+
match tp:
|
|
105
|
+
case ClassTypeVarRef():
|
|
106
|
+
assert cls_name is not None
|
|
107
|
+
try:
|
|
108
|
+
return self._cls_typevar_index_to_type[cls_name][tp]
|
|
109
|
+
except KeyError as e:
|
|
110
|
+
raise TypeConstraintError(f"Not enough bound typevars for {tp!r} in class {cls_name}") from e
|
|
111
|
+
case TypeRefWithVars(name, args):
|
|
112
|
+
return JustTypeRef(name, tuple(self.substitute_typevars(arg, cls_name) for arg in args))
|
|
113
|
+
assert_never(tp)
|
egglog/version_compat.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import sys
|
|
3
|
+
import types
|
|
4
|
+
import typing
|
|
5
|
+
|
|
6
|
+
# True on interpreters older than 3.11. Those lack ``BaseException.add_note``, and this
# module applies its ``typing`` monkeypatches only when this flag is set.
BEFORE_3_11 = sys.version_info[:2] < (3, 11)

__all__ = ["add_note"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def add_note(message: str, exc: BaseException) -> BaseException:
    """
    Attach ``message`` as a note to ``exc`` and return the same exception object.

    Backwards compatible with Python <= 3.10, which has no ``BaseException.add_note``:
    there the exception is returned unchanged and the note is dropped.
    """
    if not BEFORE_3_11:
        exc.add_note(message)
    return exc
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# On Python < 3.11 (i.e. 3.10), monkeypatch two private helpers of the ``typing`` module.
# The goal is for RuntimeClass objects nested inside generic aliases to be handled like generic
# aliases themselves: their ``__parameters__`` are collected as type variables, and they are
# re-subscripted when the enclosing alias is subscripted.
# Both function bodies appear to be copies of the CPython 3.10 ``typing`` implementations.
# Each adds RuntimeClass to one isinstance check (see the "MONKEYPATCH CHANGE" markers).
# Verify against the stdlib source before editing.
if BEFORE_3_11:

    @typing.no_type_check
    def _collect_type_vars_monkeypatch(types_, typevar_types=None):
        """
        Collect all type variable contained
        in types in order of first appearance (lexicographic order). For example::

            _collect_type_vars((T, List[S, T])) == (T, S)

        Installed below as ``typing._collect_type_vars``. It also descends into RuntimeClass
        instances via their ``__parameters__``.

        Args:
            types_: Iterable of types / aliases to scan.
            typevar_types: Class (or tuple of classes) treated as type variables;
                defaults to ``typing.TypeVar``.

        Returns:
            Tuple of the distinct type variables found, in first-seen order.
        """
        # Local import, presumably to avoid a circular import with .runtime at load time - TODO confirm.
        from .runtime import RuntimeClass  # noqa: PLC0415

        if typevar_types is None:
            typevar_types = typing.TypeVar
        tvars = []
        for t in types_:
            if isinstance(t, typevar_types) and t not in tvars:
                tvars.append(t)
            # **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
            # Assumes RuntimeClass exposes ``__parameters__`` the way generic aliases do.
            if isinstance(t, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)):
                tvars.extend([t for t in t.__parameters__ if t not in tvars])
        return tuple(tvars)

    typing._collect_type_vars = _collect_type_vars_monkeypatch  # type: ignore[attr-defined]

    @typing.no_type_check
    @typing._tp_cache
    def __getitem__monkeypatch(self, params):  # noqa: C901, PLR0912
        """
        Replacement for ``typing._GenericAlias.__getitem__``, which subscripts an
        already-parameterized alias such as ``List[T][int]``.

        Substitutes ``params`` for the alias's ``__parameters__`` in each of its ``__args__``.
        RuntimeClass arguments that have free ``__parameters__`` are also re-subscripted with
        the substituted values.

        Raises:
            TypeError: when subscripting an already-subscripted ``Generic``/``Protocol``, or when a
                ParamSpec is substituted with something that is not a parameter expression
                (the ``typing`` helpers called here may raise further errors).
        """
        # Local import, presumably to avoid a circular import with .runtime - TODO confirm.
        from .runtime import RuntimeClass  # noqa: PLC0415

        if self.__origin__ in (typing.Generic, typing.Protocol):
            # Can't subscript Generic[...] or Protocol[...].
            raise TypeError(f"Cannot subscript already-subscripted {self}")
        if not isinstance(params, tuple):
            params = (params,)
        params = tuple(typing._type_convert(p) for p in params)
        # Aliases with ParamSpec parameters go through typing's ParamSpec preparation; the rest go
        # through ``_check_generic`` (presumably a parameter-count check).
        if self._paramspec_tvars and any(isinstance(t, typing.ParamSpec) for t in self.__parameters__):
            params = typing._prepare_paramspec_params(self, params)
        else:
            typing._check_generic(self, params, len(self.__parameters__))

        # Map each of this alias's type parameters to the value supplied for it.
        subst = dict(zip(self.__parameters__, params, strict=False))
        new_args = []
        for arg in self.__args__:
            if isinstance(arg, self._typevar_types):
                if isinstance(arg, typing.ParamSpec):
                    arg = subst[arg]  # noqa: PLW2901
                    if not typing._is_param_expr(arg):
                        raise TypeError(f"Expected a list of types, an ellipsis, ParamSpec, or Concatenate. Got {arg}")
                else:
                    arg = subst[arg]  # noqa: PLW2901
            # **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
            # Nested aliases (now including RuntimeClass) with free parameters are subscripted with
            # the substituted values. This assumes RuntimeClass supports ``__parameters__`` and ``[]``.
            elif isinstance(arg, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)):
                subparams = arg.__parameters__
                if subparams:
                    subargs = tuple(subst[x] for x in subparams)
                    arg = arg[subargs]  # noqa: PLW2901
            # Required to flatten out the args for CallableGenericAlias
            if self.__origin__ == collections.abc.Callable and isinstance(arg, tuple):
                new_args.extend(arg)
            else:
                new_args.append(arg)
        return self.copy_with(tuple(new_args))

    typing._GenericAlias.__getitem__ = __getitem__monkeypatch  # type: ignore[attr-defined]
|