egglog 7.2.0__cp310-none-win_amd64.whl → 8.0.0__cp310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +83 -4
- egglog/egraph.py +241 -173
- egglog/egraph_state.py +137 -61
- egglog/examples/higher_order_functions.py +3 -8
- egglog/exp/array_api.py +274 -92
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +84 -40
- egglog/runtime.py +52 -39
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/METADATA +33 -32
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
|
Binary file
|
egglog/bindings.pyi
CHANGED
|
@@ -14,6 +14,7 @@ class SerializedEGraph:
|
|
|
14
14
|
def to_dot(self) -> str: ...
|
|
15
15
|
def to_json(self) -> str: ...
|
|
16
16
|
def map_ops(self, map: dict[str, str]) -> None: ...
|
|
17
|
+
def split_e_classes(self, egraph: EGraph, ops: set[str]) -> None: ...
|
|
17
18
|
|
|
18
19
|
@final
|
|
19
20
|
class PyObjectSort:
|
|
@@ -32,7 +33,7 @@ class EGraph:
|
|
|
32
33
|
record: bool = False,
|
|
33
34
|
) -> None: ...
|
|
34
35
|
def commands(self) -> str | None: ...
|
|
35
|
-
def parse_program(self, __input: str,
|
|
36
|
+
def parse_program(self, __input: str, /, filename: str | None = None) -> list[_Command]: ...
|
|
36
37
|
def run_program(self, *commands: _Command) -> list[str]: ...
|
|
37
38
|
def extract_report(self) -> _ExtractReport | None: ...
|
|
38
39
|
def run_report(self) -> RunReport | None: ...
|
|
@@ -43,7 +44,6 @@ class EGraph:
|
|
|
43
44
|
max_functions: int | None = None,
|
|
44
45
|
max_calls_per_function: int | None = None,
|
|
45
46
|
include_temporary_functions: bool = False,
|
|
46
|
-
split_primitive_outputs: bool = False,
|
|
47
47
|
) -> SerializedEGraph: ...
|
|
48
48
|
def eval_py_object(self, __expr: _Expr) -> object: ...
|
|
49
49
|
def eval_i64(self, __expr: _Expr) -> int: ...
|
|
@@ -56,6 +56,25 @@ class EGraph:
|
|
|
56
56
|
class EggSmolError(Exception):
|
|
57
57
|
context: str
|
|
58
58
|
|
|
59
|
+
##
|
|
60
|
+
# Spans
|
|
61
|
+
##
|
|
62
|
+
|
|
63
|
+
@final
|
|
64
|
+
class SrcFile:
|
|
65
|
+
def __init__(self, name: str, contents: str | None = None) -> None: ...
|
|
66
|
+
name: str
|
|
67
|
+
contents: str | None
|
|
68
|
+
|
|
69
|
+
@final
|
|
70
|
+
class Span:
|
|
71
|
+
def __init__(self, file: SrcFile, start: int, end: int) -> None: ...
|
|
72
|
+
file: SrcFile
|
|
73
|
+
start: int
|
|
74
|
+
end: int
|
|
75
|
+
|
|
76
|
+
DUMMY_SPAN: Span = ...
|
|
77
|
+
|
|
59
78
|
##
|
|
60
79
|
# Literals
|
|
61
80
|
##
|
|
@@ -92,17 +111,20 @@ _Literal: TypeAlias = Int | F64 | String | Bool | Unit
|
|
|
92
111
|
|
|
93
112
|
@final
|
|
94
113
|
class Lit:
|
|
95
|
-
def __init__(self, value: _Literal) -> None: ...
|
|
114
|
+
def __init__(self, span: Span, value: _Literal) -> None: ...
|
|
115
|
+
span: Span
|
|
96
116
|
value: _Literal
|
|
97
117
|
|
|
98
118
|
@final
|
|
99
119
|
class Var:
|
|
100
|
-
def __init__(self, name: str) -> None: ...
|
|
120
|
+
def __init__(self, span: Span, name: str) -> None: ...
|
|
121
|
+
span: Span
|
|
101
122
|
name: str
|
|
102
123
|
|
|
103
124
|
@final
|
|
104
125
|
class Call:
|
|
105
|
-
def __init__(self, name: str, args: list[_Expr]) -> None: ...
|
|
126
|
+
def __init__(self, span: Span, name: str, args: list[_Expr]) -> None: ...
|
|
127
|
+
span: Span
|
|
106
128
|
name: str
|
|
107
129
|
args: list[_Expr]
|
|
108
130
|
|
|
@@ -142,7 +164,8 @@ class TermDag:
|
|
|
142
164
|
|
|
143
165
|
@final
|
|
144
166
|
class Eq:
|
|
145
|
-
def __init__(self, exprs: list[_Expr]) -> None: ...
|
|
167
|
+
def __init__(self, span: Span, exprs: list[_Expr]) -> None: ...
|
|
168
|
+
span: Span
|
|
146
169
|
exprs: list[_Expr]
|
|
147
170
|
|
|
148
171
|
@final
|
|
@@ -172,43 +195,50 @@ _Change: TypeAlias = Delete | Subsume
|
|
|
172
195
|
|
|
173
196
|
@final
|
|
174
197
|
class Let:
|
|
175
|
-
def __init__(self, lhs: str, rhs: _Expr) -> None: ...
|
|
198
|
+
def __init__(self, span: Span, lhs: str, rhs: _Expr) -> None: ...
|
|
199
|
+
span: Span
|
|
176
200
|
lhs: str
|
|
177
201
|
rhs: _Expr
|
|
178
202
|
|
|
179
203
|
@final
|
|
180
204
|
class Set:
|
|
181
|
-
def __init__(self, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
|
|
205
|
+
def __init__(self, span: Span, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
|
|
206
|
+
span: Span
|
|
182
207
|
lhs: str
|
|
183
208
|
args: list[_Expr]
|
|
184
209
|
rhs: _Expr
|
|
185
210
|
|
|
186
211
|
@final
|
|
187
212
|
class Change:
|
|
213
|
+
span: Span
|
|
188
214
|
change: _Change
|
|
189
215
|
sym: str
|
|
190
216
|
args: list[_Expr]
|
|
191
|
-
def __init__(self, change: _Change, sym: str, args: list[_Expr]) -> None: ...
|
|
217
|
+
def __init__(self, span: Span, change: _Change, sym: str, args: list[_Expr]) -> None: ...
|
|
192
218
|
|
|
193
219
|
@final
|
|
194
220
|
class Union:
|
|
195
|
-
def __init__(self, lhs: _Expr, rhs: _Expr) -> None: ...
|
|
221
|
+
def __init__(self, span: Span, lhs: _Expr, rhs: _Expr) -> None: ...
|
|
222
|
+
span: Span
|
|
196
223
|
lhs: _Expr
|
|
197
224
|
rhs: _Expr
|
|
198
225
|
|
|
199
226
|
@final
|
|
200
227
|
class Panic:
|
|
201
|
-
def __init__(self, msg: str) -> None: ...
|
|
228
|
+
def __init__(self, span: Span, msg: str) -> None: ...
|
|
229
|
+
span: Span
|
|
202
230
|
msg: str
|
|
203
231
|
|
|
204
232
|
@final
|
|
205
233
|
class Expr_: # noqa: N801
|
|
206
|
-
def __init__(self, expr: _Expr) -> None: ...
|
|
234
|
+
def __init__(self, span: Span, expr: _Expr) -> None: ...
|
|
235
|
+
span: Span
|
|
207
236
|
expr: _Expr
|
|
208
237
|
|
|
209
238
|
@final
|
|
210
239
|
class Extract:
|
|
211
|
-
def __init__(self, expr: _Expr, variants: _Expr) -> None: ...
|
|
240
|
+
def __init__(self, span: Span, expr: _Expr, variants: _Expr) -> None: ...
|
|
241
|
+
span: Span
|
|
212
242
|
expr: _Expr
|
|
213
243
|
variants: _Expr
|
|
214
244
|
|
|
@@ -256,17 +286,19 @@ class Schema:
|
|
|
256
286
|
|
|
257
287
|
@final
|
|
258
288
|
class Rule:
|
|
289
|
+
span: Span
|
|
259
290
|
head: list[_Action]
|
|
260
291
|
body: list[_Fact]
|
|
261
|
-
def __init__(self, head: list[_Action], body: list[_Fact]) -> None: ...
|
|
292
|
+
def __init__(self, span: Span, head: list[_Action], body: list[_Fact]) -> None: ...
|
|
262
293
|
|
|
263
294
|
@final
|
|
264
295
|
class Rewrite:
|
|
296
|
+
span: Span
|
|
265
297
|
lhs: _Expr
|
|
266
298
|
rhs: _Expr
|
|
267
299
|
conditions: list[_Fact]
|
|
268
300
|
|
|
269
|
-
def __init__(self, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... # noqa: B006
|
|
301
|
+
def __init__(self, span: Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... # noqa: B006
|
|
270
302
|
|
|
271
303
|
@final
|
|
272
304
|
class RunConfig:
|
|
@@ -322,24 +354,28 @@ _ExtractReport: TypeAlias = Variants | Best
|
|
|
322
354
|
|
|
323
355
|
@final
|
|
324
356
|
class Saturate:
|
|
357
|
+
span: Span
|
|
325
358
|
schedule: _Schedule
|
|
326
|
-
def __init__(self, schedule: _Schedule) -> None: ...
|
|
359
|
+
def __init__(self, span: Span, schedule: _Schedule) -> None: ...
|
|
327
360
|
|
|
328
361
|
@final
|
|
329
362
|
class Repeat:
|
|
363
|
+
span: Span
|
|
330
364
|
length: int
|
|
331
365
|
schedule: _Schedule
|
|
332
|
-
def __init__(self, length: int, schedule: _Schedule) -> None: ...
|
|
366
|
+
def __init__(self, span: Span, length: int, schedule: _Schedule) -> None: ...
|
|
333
367
|
|
|
334
368
|
@final
|
|
335
369
|
class Run:
|
|
370
|
+
span: Span
|
|
336
371
|
config: RunConfig
|
|
337
|
-
def __init__(self, config: RunConfig) -> None: ...
|
|
372
|
+
def __init__(self, span: Span, config: RunConfig) -> None: ...
|
|
338
373
|
|
|
339
374
|
@final
|
|
340
375
|
class Sequence:
|
|
376
|
+
span: Span
|
|
341
377
|
schedules: list[_Schedule]
|
|
342
|
-
def __init__(self, schedules: list[_Schedule]) -> None: ...
|
|
378
|
+
def __init__(self, span: Span, schedules: list[_Schedule]) -> None: ...
|
|
343
379
|
|
|
344
380
|
_Schedule: TypeAlias = Saturate | Repeat | Run | Sequence
|
|
345
381
|
|
|
@@ -361,9 +397,10 @@ class Datatype:
|
|
|
361
397
|
|
|
362
398
|
@final
|
|
363
399
|
class Declare:
|
|
400
|
+
span: Span
|
|
364
401
|
name: str
|
|
365
402
|
sort: str
|
|
366
|
-
def __init__(self, name: str, sort: str) -> None: ...
|
|
403
|
+
def __init__(self, span: Span, name: str, sort: str) -> None: ...
|
|
367
404
|
|
|
368
405
|
@final
|
|
369
406
|
class Sort:
|
|
@@ -421,9 +458,10 @@ class Simplify:
|
|
|
421
458
|
|
|
422
459
|
@final
|
|
423
460
|
class Calc:
|
|
461
|
+
span: Span
|
|
424
462
|
identifiers: list[IdentSort]
|
|
425
463
|
exprs: list[_Expr]
|
|
426
|
-
def __init__(self, identifiers: list[IdentSort], exprs: list[_Expr]) -> None: ...
|
|
464
|
+
def __init__(self, span: Span, identifiers: list[IdentSort], exprs: list[_Expr]) -> None: ...
|
|
427
465
|
|
|
428
466
|
@final
|
|
429
467
|
class QueryExtract:
|
|
@@ -433,14 +471,16 @@ class QueryExtract:
|
|
|
433
471
|
|
|
434
472
|
@final
|
|
435
473
|
class Check:
|
|
474
|
+
span: Span
|
|
436
475
|
facts: list[_Fact]
|
|
437
|
-
def __init__(self, facts: list[_Fact]) -> None: ...
|
|
476
|
+
def __init__(self, span: Span, facts: list[_Fact]) -> None: ...
|
|
438
477
|
|
|
439
478
|
@final
|
|
440
479
|
class PrintFunction:
|
|
480
|
+
span: Span
|
|
441
481
|
name: str
|
|
442
482
|
length: int
|
|
443
|
-
def __init__(self, name: str, length: int) -> None: ...
|
|
483
|
+
def __init__(self, span: Span, name: str, length: int) -> None: ...
|
|
444
484
|
|
|
445
485
|
@final
|
|
446
486
|
class PrintSize:
|
egglog/builtins.py
CHANGED
|
@@ -6,17 +6,21 @@ Builtin sorts and function to egg.
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
8
|
from functools import partial
|
|
9
|
-
from
|
|
9
|
+
from types import FunctionType
|
|
10
|
+
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload
|
|
10
11
|
|
|
11
12
|
from typing_extensions import TypeVarTuple, Unpack
|
|
12
13
|
|
|
13
|
-
from .conversion import converter
|
|
14
|
-
from .egraph import Expr, Unit, function, method
|
|
15
|
-
from .
|
|
14
|
+
from .conversion import converter, get_type_args
|
|
15
|
+
from .egraph import Expr, Unit, function, get_current_ruleset, method
|
|
16
|
+
from .functionalize import functionalize
|
|
17
|
+
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
|
|
18
|
+
from .thunk import Thunk
|
|
16
19
|
|
|
17
20
|
if TYPE_CHECKING:
|
|
18
21
|
from collections.abc import Callable
|
|
19
22
|
|
|
23
|
+
|
|
20
24
|
__all__ = [
|
|
21
25
|
"i64",
|
|
22
26
|
"i64Like",
|
|
@@ -80,7 +84,7 @@ class Bool(Expr, egg_sort="bool", builtin=True):
|
|
|
80
84
|
converter(bool, Bool, Bool)
|
|
81
85
|
|
|
82
86
|
# The types which can be convertered into an i64
|
|
83
|
-
i64Like = Union["i64", int] # noqa: N816
|
|
87
|
+
i64Like: TypeAlias = Union["i64", int] # noqa: N816, PYI042
|
|
84
88
|
|
|
85
89
|
|
|
86
90
|
class i64(Expr, builtin=True): # noqa: N801
|
|
@@ -182,7 +186,7 @@ converter(int, i64, i64)
|
|
|
182
186
|
def count_matches(s: StringLike, pattern: StringLike) -> i64: ...
|
|
183
187
|
|
|
184
188
|
|
|
185
|
-
f64Like = Union["f64", float] # noqa: N816
|
|
189
|
+
f64Like: TypeAlias = Union["f64", float] # noqa: N816, PYI042
|
|
186
190
|
|
|
187
191
|
|
|
188
192
|
class f64(Expr, builtin=True): # noqa: N801
|
|
@@ -404,6 +408,12 @@ class Vec(Expr, Generic[T], builtin=True):
|
|
|
404
408
|
@method(egg_fn="rebuild")
|
|
405
409
|
def rebuild(self) -> Vec[T]: ...
|
|
406
410
|
|
|
411
|
+
@method(egg_fn="vec-remove")
|
|
412
|
+
def remove(self, index: i64Like) -> Vec[T]: ...
|
|
413
|
+
|
|
414
|
+
@method(egg_fn="vec-set")
|
|
415
|
+
def set(self, index: i64Like, value: T) -> Vec[T]: ...
|
|
416
|
+
|
|
407
417
|
|
|
408
418
|
class PyObject(Expr, builtin=True):
|
|
409
419
|
def __init__(self, value: object) -> None: ...
|
|
@@ -501,3 +511,36 @@ class UnstableFn(Expr, Generic[T, Unpack[TS]], builtin=True):
|
|
|
501
511
|
|
|
502
512
|
converter(RuntimeFunction, UnstableFn, UnstableFn)
|
|
503
513
|
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def _convert_function(a: FunctionType) -> UnstableFn:
|
|
517
|
+
"""
|
|
518
|
+
Converts a function type to an unstable function
|
|
519
|
+
|
|
520
|
+
Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals
|
|
521
|
+
which are runtime expressions with `var`s in them and add them as args to the function
|
|
522
|
+
"""
|
|
523
|
+
# Update annotations of a to be the type we are trying to convert to
|
|
524
|
+
return_tp, *arg_tps = get_type_args()
|
|
525
|
+
a.__annotations__ = {
|
|
526
|
+
"return": return_tp,
|
|
527
|
+
# The first varnames should always be the arg names
|
|
528
|
+
**dict(zip(a.__code__.co_varnames, arg_tps, strict=False)),
|
|
529
|
+
}
|
|
530
|
+
# Modify name to make it unique
|
|
531
|
+
# a.__name__ = f"{a.__name__} {hash(a.__code__)}"
|
|
532
|
+
transformed_fn = functionalize(a, value_to_annotation)
|
|
533
|
+
assert isinstance(transformed_fn, partial)
|
|
534
|
+
return UnstableFn(
|
|
535
|
+
function(ruleset=get_current_ruleset(), use_body_as_name=True)(transformed_fn.func), *transformed_fn.args
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def value_to_annotation(a: object) -> type | None:
|
|
540
|
+
# only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
|
|
541
|
+
if not isinstance(a, RuntimeExpr):
|
|
542
|
+
return None
|
|
543
|
+
return cast(type, RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
converter(FunctionType, UnstableFn, _convert_function)
|
egglog/conversion.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from contextvars import ContextVar
|
|
3
5
|
from dataclasses import dataclass
|
|
4
6
|
from typing import TYPE_CHECKING, NewType, TypeVar, cast
|
|
5
7
|
|
|
@@ -9,9 +11,8 @@ from .runtime import *
|
|
|
9
11
|
from .thunk import *
|
|
10
12
|
|
|
11
13
|
if TYPE_CHECKING:
|
|
12
|
-
from collections.abc import Callable
|
|
14
|
+
from collections.abc import Callable, Generator
|
|
13
15
|
|
|
14
|
-
from .declarations import HasDeclerations
|
|
15
16
|
from .egraph import Expr
|
|
16
17
|
|
|
17
18
|
__all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"]
|
|
@@ -84,7 +85,7 @@ def convert(source: object, target: type[V]) -> V:
|
|
|
84
85
|
Convert a source object to a target type.
|
|
85
86
|
"""
|
|
86
87
|
assert isinstance(target, RuntimeClass)
|
|
87
|
-
return cast(V, resolve_literal(target.__egg_tp__, source))
|
|
88
|
+
return cast(V, resolve_literal(target.__egg_tp__, source, target.__egg_decls_thunk__))
|
|
88
89
|
|
|
89
90
|
|
|
90
91
|
def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
|
|
@@ -92,7 +93,7 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
|
|
|
92
93
|
Convert a source object to the same type as the target.
|
|
93
94
|
"""
|
|
94
95
|
tp = target.__egg_typed_expr__.tp
|
|
95
|
-
return resolve_literal(tp.to_var(), source)
|
|
96
|
+
return resolve_literal(tp.to_var(), source, Thunk.value(target.__egg_decls__))
|
|
96
97
|
|
|
97
98
|
|
|
98
99
|
def process_tp(tp: type | RuntimeClass) -> TypeName | type:
|
|
@@ -140,7 +141,28 @@ def identity(x: object) -> object:
|
|
|
140
141
|
return x
|
|
141
142
|
|
|
142
143
|
|
|
143
|
-
|
|
144
|
+
TYPE_ARGS = ContextVar[tuple[RuntimeClass, ...]]("TYPE_ARGS")
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def get_type_args() -> tuple[RuntimeClass, ...]:
|
|
148
|
+
"""
|
|
149
|
+
Get the type args for the type being converted.
|
|
150
|
+
"""
|
|
151
|
+
return TYPE_ARGS.get()
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@contextmanager
|
|
155
|
+
def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declarations]) -> Generator[None, None, None]:
|
|
156
|
+
token = TYPE_ARGS.set(tuple(RuntimeClass(decls, a.to_var()) for a in args))
|
|
157
|
+
try:
|
|
158
|
+
yield
|
|
159
|
+
finally:
|
|
160
|
+
TYPE_ARGS.reset(token)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def resolve_literal(
|
|
164
|
+
tp: TypeOrVarRef, arg: object, decls: Callable[[], Declarations] = CONVERSIONS_DECLS
|
|
165
|
+
) -> RuntimeExpr:
|
|
144
166
|
arg_type = _get_tp(arg)
|
|
145
167
|
|
|
146
168
|
# If we have any type variables, dont bother trying to resolve the literal, just return the arg
|
|
@@ -148,7 +170,7 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
|
|
|
148
170
|
tp_just = tp.to_just()
|
|
149
171
|
except NotImplementedError:
|
|
150
172
|
# If this is a var, it has to be a runtime expession
|
|
151
|
-
assert isinstance(arg, RuntimeExpr)
|
|
173
|
+
assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
|
|
152
174
|
return arg
|
|
153
175
|
tp_name = TypeName(tp_just.name)
|
|
154
176
|
if arg_type == tp_name:
|
|
@@ -158,13 +180,14 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
|
|
|
158
180
|
# Try all parent types as well, if we are converting from a Python type
|
|
159
181
|
for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
|
|
160
182
|
try:
|
|
161
|
-
fn = CONVERSIONS[(
|
|
183
|
+
fn = CONVERSIONS[(arg_type_instance, tp_name)][1]
|
|
162
184
|
except KeyError:
|
|
163
185
|
continue
|
|
164
186
|
break
|
|
165
187
|
else:
|
|
166
188
|
raise ConvertError(f"Cannot convert {arg_type} to {tp_name}")
|
|
167
|
-
|
|
189
|
+
with with_type_args(tp_just.args, decls):
|
|
190
|
+
return fn(arg)
|
|
168
191
|
|
|
169
192
|
|
|
170
193
|
def _get_tp(x: object) -> TypeName | type:
|
egglog/declarations.py
CHANGED
|
@@ -13,10 +13,11 @@ from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_c
|
|
|
13
13
|
from typing_extensions import Self, assert_never
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
|
-
from collections.abc import Callable, Iterable
|
|
16
|
+
from collections.abc import Callable, Iterable, Mapping
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
__all__ = [
|
|
20
|
+
"replace_typed_expr",
|
|
20
21
|
"Declarations",
|
|
21
22
|
"DeclerationsLike",
|
|
22
23
|
"DelayedDeclerations",
|
|
@@ -29,6 +30,7 @@ __all__ = [
|
|
|
29
30
|
"MethodRef",
|
|
30
31
|
"ClassMethodRef",
|
|
31
32
|
"FunctionRef",
|
|
33
|
+
"UnnamedFunctionRef",
|
|
32
34
|
"ConstantRef",
|
|
33
35
|
"ClassVariableRef",
|
|
34
36
|
"PropertyRef",
|
|
@@ -73,6 +75,7 @@ __all__ = [
|
|
|
73
75
|
"FunctionSignature",
|
|
74
76
|
"DefaultRewriteDecl",
|
|
75
77
|
"InitRef",
|
|
78
|
+
"HasDeclerations",
|
|
76
79
|
]
|
|
77
80
|
|
|
78
81
|
|
|
@@ -82,12 +85,13 @@ class DelayedDeclerations:
|
|
|
82
85
|
|
|
83
86
|
@property
|
|
84
87
|
def __egg_decls__(self) -> Declarations:
|
|
88
|
+
thunk = self.__egg_decls_thunk__
|
|
85
89
|
try:
|
|
86
|
-
return
|
|
90
|
+
return thunk()
|
|
87
91
|
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
|
|
88
92
|
# instead raise explicitly
|
|
89
93
|
except AttributeError as err:
|
|
90
|
-
msg = "
|
|
94
|
+
msg = f"Cannot resolve declerations for {self}"
|
|
91
95
|
raise RuntimeError(msg) from err
|
|
92
96
|
|
|
93
97
|
|
|
@@ -116,6 +120,7 @@ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[D
|
|
|
116
120
|
|
|
117
121
|
@dataclass
|
|
118
122
|
class Declarations:
|
|
123
|
+
_unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set)
|
|
119
124
|
_functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
|
|
120
125
|
_constants: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
121
126
|
_classes: dict[str, ClassDecl] = field(default_factory=dict)
|
|
@@ -192,6 +197,8 @@ class Declarations:
|
|
|
192
197
|
init_fn = self._classes[class_name].init
|
|
193
198
|
assert init_fn
|
|
194
199
|
return init_fn
|
|
200
|
+
case UnnamedFunctionRef():
|
|
201
|
+
return ref.to_function_decl()
|
|
195
202
|
assert_never(ref)
|
|
196
203
|
|
|
197
204
|
def set_function_decl(
|
|
@@ -318,6 +325,37 @@ TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
|
|
|
318
325
|
##
|
|
319
326
|
|
|
320
327
|
|
|
328
|
+
@dataclass(frozen=True)
|
|
329
|
+
class UnnamedFunctionRef:
|
|
330
|
+
"""
|
|
331
|
+
A reference to a function that doesn't have a name, but does have a body.
|
|
332
|
+
"""
|
|
333
|
+
|
|
334
|
+
# tuple of var arg names and their types
|
|
335
|
+
args: tuple[TypedExprDecl, ...]
|
|
336
|
+
res: TypedExprDecl
|
|
337
|
+
|
|
338
|
+
def to_function_decl(self) -> FunctionDecl:
|
|
339
|
+
arg_types = []
|
|
340
|
+
arg_names = []
|
|
341
|
+
for a in self.args:
|
|
342
|
+
arg_types.append(a.tp.to_var())
|
|
343
|
+
assert isinstance(a.expr, VarDecl)
|
|
344
|
+
arg_names.append(a.expr.name)
|
|
345
|
+
return FunctionDecl(
|
|
346
|
+
FunctionSignature(
|
|
347
|
+
arg_types=tuple(arg_types),
|
|
348
|
+
arg_names=tuple(arg_names),
|
|
349
|
+
arg_defaults=(None,) * len(self.args),
|
|
350
|
+
return_type=self.res.tp.to_var(),
|
|
351
|
+
),
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
@property
|
|
355
|
+
def egg_name(self) -> None | str:
|
|
356
|
+
return None
|
|
357
|
+
|
|
358
|
+
|
|
321
359
|
@dataclass(frozen=True)
|
|
322
360
|
class FunctionRef:
|
|
323
361
|
name: str
|
|
@@ -358,7 +396,14 @@ class PropertyRef:
|
|
|
358
396
|
|
|
359
397
|
|
|
360
398
|
CallableRef: TypeAlias = (
|
|
361
|
-
FunctionRef
|
|
399
|
+
FunctionRef
|
|
400
|
+
| ConstantRef
|
|
401
|
+
| MethodRef
|
|
402
|
+
| ClassMethodRef
|
|
403
|
+
| InitRef
|
|
404
|
+
| ClassVariableRef
|
|
405
|
+
| PropertyRef
|
|
406
|
+
| UnnamedFunctionRef
|
|
362
407
|
)
|
|
363
408
|
|
|
364
409
|
|
|
@@ -455,6 +500,8 @@ CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
|
|
|
455
500
|
@dataclass(frozen=True)
|
|
456
501
|
class VarDecl:
|
|
457
502
|
name: str
|
|
503
|
+
# Differentiate between let bound vars and vars created in rules so that they won't shadow in egglog, by adding a prefix
|
|
504
|
+
is_let: bool
|
|
458
505
|
|
|
459
506
|
|
|
460
507
|
@dataclass(frozen=True)
|
|
@@ -561,6 +608,38 @@ class TypedExprDecl:
|
|
|
561
608
|
return l
|
|
562
609
|
|
|
563
610
|
|
|
611
|
+
def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl:
|
|
612
|
+
"""
|
|
613
|
+
Replace all the typed expressions in the given typed expression with the replacements.
|
|
614
|
+
"""
|
|
615
|
+
# keep track of the traversed expressions for memoization
|
|
616
|
+
traversed: dict[TypedExprDecl, TypedExprDecl] = {}
|
|
617
|
+
|
|
618
|
+
def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl:
|
|
619
|
+
if typed_expr in traversed:
|
|
620
|
+
return traversed[typed_expr]
|
|
621
|
+
if typed_expr in replacements:
|
|
622
|
+
res = replacements[typed_expr]
|
|
623
|
+
else:
|
|
624
|
+
match typed_expr.expr:
|
|
625
|
+
case (
|
|
626
|
+
CallDecl(callable, args, bound_tp_params)
|
|
627
|
+
| PartialCallDecl(CallDecl(callable, args, bound_tp_params))
|
|
628
|
+
):
|
|
629
|
+
new_args = tuple(_inner(a) for a in args)
|
|
630
|
+
call_decl = CallDecl(callable, new_args, bound_tp_params)
|
|
631
|
+
res = TypedExprDecl(
|
|
632
|
+
typed_expr.tp,
|
|
633
|
+
call_decl if isinstance(typed_expr.expr, CallDecl) else PartialCallDecl(call_decl),
|
|
634
|
+
)
|
|
635
|
+
case _:
|
|
636
|
+
res = typed_expr
|
|
637
|
+
traversed[typed_expr] = res
|
|
638
|
+
return res
|
|
639
|
+
|
|
640
|
+
return _inner(typed_expr)
|
|
641
|
+
|
|
642
|
+
|
|
564
643
|
##
|
|
565
644
|
# Schedules
|
|
566
645
|
##
|