egglog 9.0.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/config.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ Global configuration for egglog.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Whether to display the type of each node in the graph when printing.
8
+ SHOW_TYPES = False
egglog/conversion.py ADDED
@@ -0,0 +1,262 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from contextlib import contextmanager
5
+ from contextvars import ContextVar
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, TypeVar, cast
8
+
9
+ from .declarations import *
10
+ from .pretty import *
11
+ from .runtime import *
12
+ from .thunk import *
13
+ from .type_constraint_solver import TypeConstraintError
14
+
15
+ if TYPE_CHECKING:
16
+ from collections.abc import Callable, Generator
17
+
18
+ from .egraph import BaseExpr
19
+ from .type_constraint_solver import TypeConstraintSolver
20
+
21
+ __all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"]
22
+ # Mapping from (source type, target type) to a (cost, function) pair; the function takes the runtime value of the source and returns the target
23
+ CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
24
+ # Global declarations storing all convertible types, so we can query whether they have certain methods
25
+ _CONVERSION_DECLS = Declarations.create()
26
+ # Deferred list of declarations to be added to the global declarations, so that we do not trigger processing them
27
+ # until we need them
28
+ _TO_PROCESS_DECLS: list[DeclerationsLike] = []
29
+
30
+
31
def _retrieve_conversion_decls() -> Declarations:
    """
    Return the global conversion declarations, after merging in everything still queued.

    Every entry waiting in ``_TO_PROCESS_DECLS`` is passed to ``_CONVERSION_DECLS.update`` and the
    queue is then emptied, so each queued entry is merged only once.
    """
    pending = _TO_PROCESS_DECLS
    _CONVERSION_DECLS.update(*pending)
    pending.clear()
    return _CONVERSION_DECLS
35
+
36
+
37
+ T = TypeVar("T")
38
+ V = TypeVar("V", bound="BaseExpr")
39
+
40
+
41
class ConvertError(Exception):
    """
    Raised when no conversion to the requested type can be found.
    """
43
+
44
+
45
def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: int = 1) -> None:
    """
    Register ``fn`` as the converter from ``from_type`` to the egglog type ``to_type``.

    ``cost`` is stored with the converter and passed on to the registration step.

    Raises a TypeError if ``to_type`` does not resolve to an egglog type reference.
    """
    # The target is processed first, then the source, matching the order they get queued in.
    target_ref = process_tp(to_type)
    if isinstance(target_ref, JustTypeRef):
        _register_converter(process_tp(from_type), target_ref, fn, cost)
        return
    raise TypeError(f"Expected return type to be a egglog type, got {target_ref}")
53
+
54
+
55
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
    """
    Store ``a_b`` as the converter from ``a`` to ``b``, unless one with an equal or lower cost is already stored.

    Once stored, the converter is chained with every existing one so transitive conversions exist too:
    A->B combined with an existing B->C registers A->C, and an existing D->A combined with A->B
    registers D->B. A chained converter costs the sum of the costs of its two parts.
    """
    # Converting a type to itself never needs a converter.
    if a == b:
        return
    pair = (a, b)
    known = CONVERSIONS.get(pair)
    if known is not None and known[0] <= cost:
        return
    CONVERSIONS[pair] = (cost, a_b)
    # Walk a snapshot, because the recursive calls below insert into CONVERSIONS.
    for (c, d), (other_cost, c_d) in list(CONVERSIONS.items()):
        total_cost = cost + other_cost
        if _is_type_compatible(b, c):
            c_args = c.args if isinstance(c, JustTypeRef) else ()
            _register_converter(a, d, _ComposedConverter(a_b, c_d, c_args), total_cost)
        if _is_type_compatible(a, d):
            a_args = a.args if isinstance(a, JustTypeRef) else ()
            _register_converter(c, b, _ComposedConverter(c_d, a_b, a_args), total_cost)
76
+
77
+
78
+ def _is_type_compatible(source: type | JustTypeRef, target: type | JustTypeRef) -> bool:
79
+ """
80
+ Types must be equal or also support unbound to bound typevar like B -> B[C]
81
+ """
82
+ if source == target:
83
+ return True
84
+ if isinstance(source, JustTypeRef) and isinstance(target, JustTypeRef) and source.args and not target.args:
85
+ return source.name == target.name
86
+ # TODO: Support case where B[T] where T is typevar is mapped to B[C]
87
+ return False
88
+
89
+
90
+ @dataclass
91
+ class _ComposedConverter:
92
+ """
93
+ A converter which is composed of multiple converters.
94
+
95
+ _ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x))
96
+
97
+ We use the dataclass instead of the lambda to make it easier to debug.
98
+ """
99
+
100
+ a_b: Callable
101
+ b_c: Callable
102
+ b_args: tuple[JustTypeRef, ...]
103
+
104
+ def __call__(self, x: object) -> object:
105
+ # if we have A -> B and B[C] -> D then we should use (C,) as the type args
106
+ # when converting from A -> B
107
+ if self.b_args:
108
+ with with_type_args(self.b_args, _retrieve_conversion_decls):
109
+ first_res = self.a_b(x)
110
+ else:
111
+ first_res = self.a_b(x)
112
+ return self.b_c(first_res)
113
+
114
+ def __str__(self) -> str:
115
+ return f"{self.b_c} ∘ {self.a_b}"
116
+
117
+
118
def convert(source: object, target: type[V]) -> V:
    """
    Convert ``source`` into an expression of the egglog class ``target``.

    ``target`` must be a RuntimeClass; its type and declarations thunk are handed to ``resolve_literal``.
    """
    assert isinstance(target, RuntimeClass)
    resolved = resolve_literal(target.__egg_tp__, source, target.__egg_decls_thunk__)
    return cast("V", resolved)
124
+
125
+
126
def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
    """
    Convert ``source`` into an expression with the same type as ``target``.

    The type is read from ``target``'s typed expression, and ``target``'s declarations are
    handed to ``resolve_literal`` wrapped in a thunk.
    """
    target_tp = target.__egg_typed_expr__.tp
    target_var = target_tp.to_var()
    decls_thunk = Thunk.value(target.__egg_decls__)
    return resolve_literal(target_var, source, decls_thunk)
132
+
133
+
134
def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
    """
    Prepare a type for use as a conversion source or target.

    A RuntimeClass is queued so its declarations get merged into the global ones later, and its
    type ref is returned. Anything else is returned unchanged.
    """
    if not isinstance(tp, RuntimeClass):
        return tp
    _TO_PROCESS_DECLS.append(tp)
    return tp.__egg_tp__.to_just()
143
+
144
+
145
def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
    """
    Find the cheapest type that both ``a`` and ``b`` can become and that has a method called ``name``.

    Candidate types are the registered conversion targets of each value's type, plus the value's own
    type (at cost 0) when it is an egglog type with that method. Of the candidates shared by both
    values, the one with the lowest summed conversion cost is returned.

    Raises a ConvertError if the two values share no candidate type.
    """
    decls = _retrieve_conversion_decls()
    a_tp = _get_tp(a)
    b_tp = _get_tp(b)

    def targets_with_method(source_tp: JustTypeRef | type) -> dict[JustTypeRef, int]:
        # Registered targets of ``source_tp`` that define ``name``, mapped to their conversion cost.
        return {
            to: c
            for ((from_, to), (c, _)) in CONVERSIONS.items()
            if from_ == source_tp and decls.has_method(to.name, name)
        }

    a_converts_to = targets_with_method(a_tp)
    b_converts_to = targets_with_method(b_tp)
    # A value that already is an egglog type with the method needs no conversion at all.
    if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
        a_converts_to[a_tp] = 0
    if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
        b_converts_to[b_tp] = 0
    common = set(a_converts_to) & set(b_converts_to)
    if not common:
        raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")

    def combined_cost(tp: JustTypeRef) -> int:
        return a_converts_to[tp] + b_converts_to[tp]

    return min(common, key=combined_cost)
166
+
167
+
168
def identity(x: object) -> object:
    """
    Return the argument unchanged.
    """
    return x
170
+
171
+
172
+ TYPE_ARGS = ContextVar[tuple[RuntimeClass, ...]]("TYPE_ARGS")
173
+
174
+
175
def get_type_args() -> tuple[type, ...]:
    """
    Return the type args currently stored in the ``TYPE_ARGS`` context variable.

    The stored value is returned as is; the cast only changes the static type.
    """
    current = TYPE_ARGS.get()
    return cast("tuple[type, ...]", current)
180
+
181
+
182
@contextmanager
def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declarations]) -> Generator[None, None, None]:
    """
    Context manager that sets ``TYPE_ARGS`` for the duration of the ``with`` body.

    Each ref in ``args`` is wrapped in a RuntimeClass built from ``decls``. The previous value of
    ``TYPE_ARGS`` is restored on exit, including when the body raises.
    """
    runtime_args = tuple(RuntimeClass(decls, a.to_var()) for a in args)
    token = TYPE_ARGS.set(runtime_args)
    try:
        yield
    finally:
        TYPE_ARGS.reset(token)
189
+
190
+
191
def resolve_literal(
    tp: TypeOrVarRef,
    arg: object,
    decls: Callable[[], Declarations] = _retrieve_conversion_decls,
    tcs: TypeConstraintSolver | None = None,
    cls_name: str | None = None,
) -> RuntimeExpr:
    """
    Convert ``arg`` to the type ``tp``, raising a ConvertError if no conversion is registered.

    If ``tp`` contains type variables, the type constraint solver (when given) is asked to
    substitute them. If it cannot, ``arg`` must already be a runtime expression and its own
    type is used instead.

    Without a solver, an ``arg`` for a type with variables must already be a runtime
    expression and is returned as is.
    """
    arg_type = _get_tp(arg)

    try:
        tp_just = tp.to_just()
    except NotImplementedError:
        # The type still has type variables in it.
        if not tcs:
            # If this is a var, it has to be a runtime expression
            assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
            return arg
        # Try to resolve the generic args first, based on the existing type constraint solver
        try:
            tp_just = tcs.substitute_typevars(tp, cls_name)
        except TypeConstraintError:
            # If we can't resolve the type var yet, then just assume the arg has the right type
            assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
            tp_just = arg.__egg_typed_expr__.tp
    if tcs:
        tcs.infer_typevars(tp, tp_just, cls_name)
    if arg_type == tp_just:
        # If the type is an egg type, it has to be a runtime expr
        assert isinstance(arg, RuntimeExpr)
        return arg
    # Try all parent types as well, if we are converting from a Python type
    candidates = arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]
    for candidate in candidates:
        exact_key = (candidate, tp_just)
        if exact_key in CONVERSIONS:
            fn = CONVERSIONS[exact_key][1]
            break
        # For generics, also accept a converter registered for the type without its args
        if tp_just.args:
            general_key = (candidate, JustTypeRef(tp_just.name))
            if general_key in CONVERSIONS:
                fn = CONVERSIONS[general_key][1]
                break
    else:
        # No candidate had a registered converter
        raise ConvertError(f"Cannot convert {arg_type} to {tp_just}")
    with with_type_args(tp_just.args, decls):
        return fn(arg)
244
+
245
+
246
def _debug_print_converers() -> None:
    """
    Print a mapping of all source types to the target types that have a conversion function.

    One line is written to stdout per source type, listing every target registered for it
    in ``CONVERSIONS``. Nothing is printed when no converters are registered.
    """
    source_to_targets = defaultdict(list)
    for source, target in CONVERSIONS:
        source_to_targets[source].append(target)
    # The grouping used to be built and then dropped, so this debugging helper printed nothing.
    for source, targets in source_to_targets.items():
        print(f"{source} -> {', '.join(str(target) for target in targets)}")
253
+
254
+
255
def _get_tp(x: object) -> JustTypeRef | type:
    """
    Return the key used to look up conversions for ``x``.

    For a runtime expression this is the type of its typed expression. For any other value it
    is the value's class, or the metaclass of that class when the metaclass is not plain ``type``.
    """
    if isinstance(x, RuntimeExpr):
        return x.__egg_typed_expr__.tp
    cls = type(x)
    metaclass = type(cls)
    # If this value has a custom metaclass, use that as the index instead of the class
    return cls if metaclass is type else metaclass