egglog 0.4.0__pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +5 -0
- egglog/bindings.pyi +415 -0
- egglog/bindings.pypy310-pp73-x86_64-linux-gnu.so +0 -0
- egglog/builtins.py +345 -0
- egglog/config.py +8 -0
- egglog/declarations.py +934 -0
- egglog/egraph.py +1041 -0
- egglog/examples/README.rst +5 -0
- egglog/examples/__init__.py +0 -0
- egglog/examples/eqsat_basic.py +43 -0
- egglog/examples/fib.py +28 -0
- egglog/examples/lambda.py +310 -0
- egglog/examples/matrix.py +184 -0
- egglog/examples/ndarrays.py +159 -0
- egglog/examples/resolution.py +84 -0
- egglog/examples/schedule_demo.py +33 -0
- egglog/ipython_magic.py +40 -0
- egglog/monkeypatch.py +33 -0
- egglog/py.typed +0 -0
- egglog/runtime.py +304 -0
- egglog/type_constraint_solver.py +79 -0
- egglog-0.4.0.dist-info/METADATA +53 -0
- egglog-0.4.0.dist-info/RECORD +25 -0
- egglog-0.4.0.dist-info/WHEEL +4 -0
- egglog-0.4.0.dist-info/license_files/LICENSE +21 -0
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""
|
|
2
|
+
N-Dimensional Arrays
|
|
3
|
+
====================
|
|
4
|
+
|
|
5
|
+
Example of building an NDArray in the vein of Mathematics of Arrays.
|
|
6
|
+
"""
|
|
7
|
+
# mypy: disable-error-code=empty-body
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from egglog import *
|
|
11
|
+
|
|
12
|
+
egraph = EGraph()


@egraph.class_
class Value(BaseExpr):
    """
    A scalar value, constructed from an ``i64``.
    """

    def __init__(self, v: i64Like) -> None:
        ...

    def __mul__(self, other: Value) -> Value:
        ...

    def __add__(self, other: Value) -> Value:
        ...


i, j = vars_("i j", i64)
# Constant-fold arithmetic when both operands wrap concrete i64 values.
egraph.register(
    rewrite(Value(i) * Value(j)).to(Value(i * j)),
    rewrite(Value(i) + Value(j)).to(Value(i + j)),
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@egraph.class_
class Values(BaseExpr):
    """
    A sequence of ``Value``s built from a ``Vec[Value]``; used for array shapes and indices.
    """

    def __init__(self, v: Vec[Value]) -> None:
        ...

    # Element at position ``idx``.
    def __getitem__(self, idx: Value) -> Value:
        ...

    # Number of elements, as a Value.
    def length(self) -> Value:
        ...

    # This sequence followed by ``other``.
    def concat(self, other: Values) -> Values:
        ...


@egraph.register
def _values(vs: Vec[Value], other: Vec[Value]):
    # Evaluate Values operations when the underlying vectors are concrete.
    # ``i`` here is the module-level i64 variable created with ``vars_`` for the Value rules.
    yield rewrite(Values(vs)[Value(i)]).to(vs[i])
    yield rewrite(Values(vs).length()).to(Value(vs.length()))
    yield rewrite(Values(vs).concat(Values(other))).to(Values(vs.append(other)))
    # TODO: rules for concatenations of symbolic (non-literal) Values, e.g.
    # yield rewrite(l.concat(r).length()).to(l.length() + r.length())
    # yield rewrite(l.concat(r)[idx])
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@egraph.class_
class NDArray(BaseExpr):
    """
    An n-dimensional array.
    """

    # Element at the multi-dimensional index ``idx``.
    def __getitem__(self, idx: Values) -> Value:
        ...

    # Size of each dimension.
    def shape(self) -> Values:
        ...


# One-dimensional array holding 0, 1, ..., n - 1 (as encoded by the rules below).
@egraph.function
def arange(n: Value) -> NDArray:
    ...


@egraph.register
def _ndarray_arange(n: Value, idx: Values):
    # arange(n) has the single dimension n ...
    yield rewrite(arange(n).shape()).to(Values(Vec(n)))
    # ... and the element at index (k,) is k itself.
    yield rewrite(arange(n)[idx]).to(idx[Value(0)])
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def assert_simplifies(left: BaseExpr, right: BaseExpr) -> None:
    """
    Add ``left`` to the module e-graph, run it for 30 iterations, print the extracted
    form of ``left``, and check that ``left`` has been proven equal to ``right``.
    """
    egraph.register(left)
    egraph.run(30)
    simplest = egraph.extract(left)
    message = f"{left} == {right} ➡ {simplest}"
    print(message)
    egraph.check(eq(left).to(right))


assert_simplifies(arange(Value(10)).shape(), Values(Vec(Value(10))))
assert_simplifies(arange(Value(10))[Values(Vec(Value(0)))], Value(0))
assert_simplifies(arange(Value(10))[Values(Vec(Value(1)))], Value(1))
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# The ``py_*`` constructors wrap Python/NumPy source-code strings. The rules below rewrite
# operations on them into a single string of equivalent source code.
# NOTE(review): operands are joined without parentheses, so nested expressions can come out with
# the wrong precedence (e.g. (a + b) * c is rendered as "a + b * c") — confirm this is acceptable.
@egraph.function
def py_value(s: StringLike) -> Value:
    ...


@egraph.register
def _py_value(l: String, r: String):
    yield rewrite(py_value(l) + py_value(r)).to(py_value(join(l, " + ", r)))
    yield rewrite(py_value(l) * py_value(r)).to(py_value(join(l, " * ", r)))


@egraph.function
def py_values(s: StringLike) -> Values:
    ...


@egraph.register
def _py_values(l: String, r: String):
    yield rewrite(py_values(l)[py_value(r)]).to(py_value(join(l, "[", r, "]")))
    yield rewrite(py_values(l).length()).to(py_value(join("len(", l, ")")))
    yield rewrite(py_values(l).concat(py_values(r))).to(py_values(join(l, " + ", r)))


@egraph.function
def py_ndarray(s: StringLike) -> NDArray:
    ...


@egraph.register
def _py_ndarray(l: String, r: String):
    yield rewrite(py_ndarray(l)[py_values(r)]).to(py_value(join(l, "[", r, "]")))
    yield rewrite(py_ndarray(l).shape()).to(py_values(join(l, ".shape")))
    yield rewrite(arange(py_value(l))).to(py_ndarray(join("np.arange(", l, ")")))


assert_simplifies(py_ndarray("x").shape(), py_values("x.shape"))
assert_simplifies(arange(py_value("x"))[py_values("y")], py_value("np.arange(x)[y]"))
# Disabled — presumably because the arange rule produces py_values("y")[Value(0)], and no rule
# turns a plain Value(0) index into source code, so "y[0]" is never derived.
# assert_simplifies(arange(py_value("x"))[py_values("y")], py_value("y[0]"))
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# Outer product of two arrays (the Python translation below uses np.multiply.outer).
@egraph.function
def cross(l: NDArray, r: NDArray) -> NDArray:
    ...


@egraph.register
def _cross(l: NDArray, r: NDArray, idx: Values):
    # The result's dimensions are l's dimensions followed by r's.
    yield rewrite(cross(l, r).shape()).to(l.shape().concat(r.shape()))
    # NOTE(review): for an outer product one would expect idx to be split, with the leading
    # len(l.shape()) entries indexing l and the rest indexing r. This rule indexes both operands
    # with the full idx — TODO confirm intent (the py_value assert below relies on the current form).
    yield rewrite(cross(l, r)[idx]).to(l[idx] * r[idx])


assert_simplifies(cross(arange(Value(10)), arange(Value(11))).shape(), Values(Vec(Value(10), Value(11))))
assert_simplifies(cross(py_ndarray("x"), py_ndarray("y")).shape(), py_values("x.shape + y.shape"))
assert_simplifies(cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")], py_value("x[idx] * y[idx]"))
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# Translate the outer product of two Python arrays into a NumPy call.
@egraph.register
def _cross_py(l: String, r: String):
    yield rewrite(cross(py_ndarray(l), py_ndarray(r))).to(py_ndarray(join("np.multiply.outer(", l, ", ", r, ")")))


# With this rule registered, the same expression can also be rendered via np.multiply.outer.
assert_simplifies(cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")], py_value("np.multiply.outer(x, y)[idx]"))
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Resolution theorem proving.
|
|
3
|
+
===========================
|
|
4
|
+
"""
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import ClassVar
|
|
8
|
+
|
|
9
|
+
from egglog import *
|
|
10
|
+
|
|
11
|
+
egraph = EGraph()


@egraph.class_
class Bool(BaseExpr):
    """
    A propositional term; clauses are built with ``|`` (or) and ``~`` (not).
    """

    # Declared as a class variable; bound to ``F`` below.
    FALSE: ClassVar[Bool]

    def __or__(self, other: Bool) -> Bool:  # type: ignore[empty-body]
        ...

    def __invert__(self) -> Bool:  # type: ignore[empty-body]
        ...


# Show off two ways of creating constants: either as top-level values or as class variables
T = egraph.constant("T", Bool)
F = Bool.FALSE
|
|
28
|
+
|
|
29
|
+
# ``as`` is a Python keyword, so its Python variable is named ``as_``.
p, a, b, c, as_, bs = vars_("p a b c as bs", Bool)
egraph.register(
    # clauses are assumed in the normal form (or a (or b (or c False)))
    set_(~F).to(T),
    set_(~T).to(F),
    # "Solving" negation equations
    rule(eq(~p).to(T)).then(union(p).with_(F)),
    rule(eq(~p).to(F)).then(union(p).with_(T)),
    # canonicalize associativity: clauses are right-nested, like a list terminated with F
    rewrite((a | b) | c).to(a | (b | c)),
    # commutativity: swap the first two literals of a clause
    rewrite(a | (b | c)).to(b | (a | c)),
    # absorption of a repeated literal
    rewrite(a | (a | b)).to(a | b),
    # tautology: a clause containing both a and ~a is true
    rewrite(a | (~a | b)).to(T),
    # Simplification
    rewrite(F | a).to(a),
    rewrite(a | F).to(a),
    rewrite(T | a).to(T),
    rewrite(a | T).to(T),
    # unit propagation: a true single-literal clause makes that literal true.
    # This is kind of interesting actually.
    # Looks a bit like equation solving
    rule(eq(T).to(p | F)).then(union(p).with_(T)),
    # resolution: from the true clauses (a | as) and (~a | bs), derive (as | bs)
    # This counts on commutativity to bubble everything possible up to the front of the clause.
    rule(
        eq(T).to(a | as_),
        eq(T).to(~a | bs),
    ).then(
        set_(as_ | bs).to(T),
    ),
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Example predicate
|
|
65
|
+
# Example predicate
@egraph.function
def pred(x: i64Like) -> Bool:  # type: ignore[empty-body]
    ...


p0 = egraph.define("p0", pred(0))
p1 = egraph.define("p1", pred(1))
p2 = egraph.define("p2", pred(2))
# Assert a small clause set (each clause right-nested and terminated with F) plus the fact p1 == F.
egraph.register(
    set_(p1 | (~p2 | F)).to(T),
    set_(p2 | (~p0 | F)).to(T),
    set_(p0 | (~p1 | F)).to(T),
    union(p1).with_(F),
    set_(~p0 | (~p1 | (p2 | F))).to(T),
)
egraph.run(10)
# p1 == F makes the first clause force p2 == F, which makes the second clause force p0 == F.
# The derivation must do this without collapsing T and F.
egraph.check(T != F)
egraph.check(eq(p0).to(F))
egraph.check(eq(p2).to(F))
# Bare final expression — presumably so notebook/gallery renderers display the e-graph.
egraph
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Schedule demo
|
|
3
|
+
=============
|
|
4
|
+
"""
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from egglog import *
|
|
8
|
+
|
|
9
|
+
egraph = EGraph()

# Two unary relations over i64, both seeded with 0.
left = egraph.relation("left", i64)
right = egraph.relation("right", i64)

egraph.register(left(i64(0)), right(i64(0)))

x, y = vars_("x y", i64)

# step-left: when left and right both contain x, add x + 1 to left.
step_left = egraph.ruleset("step-left")
egraph.register(rule(left(x), right(x), ruleset=step_left).then(left(x + 1)))

# step-right: when left contains x and right contains y with x == y + 1, add x to right.
step_right = egraph.ruleset("step-right")
egraph.register(rule(left(x), right(y), eq(x).to(y + 1), ruleset=step_right).then(right(x)))

# Alternate the two rulesets 10 times, saturating each one in turn.
egraph.run(
    seq(
        run(step_right).saturate(),
        run(step_left).saturate(),
    )
    * 10
)
# left(10) and right(9) were derived, but left(11) and right(10) were not both derived.
egraph.check(left(i64(10)), right(i64(9)))
egraph.check_fail(left(i64(11)), right(i64(10)))
# Bare final expression — presumably so notebook/gallery renderers display the e-graph.
egraph
|
egglog/ipython_magic.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from .bindings import EGraph
|
|
2
|
+
|
|
3
|
+
# Name of the variable in the user's namespace that holds the e-graph used by the magic
EGRAPH_VAR = "_MAGIC_EGRAPH"


def _magic_flags(line):
    """
    Return the set of whitespace-separated option words on an ``%%egglog`` magic line.

    Options are matched as whole words, so a word that merely *contains* ``graph``,
    ``output`` or ``continue`` does not switch that option on.
    """
    return frozenset(line.split())


# ``get_ipython`` is only defined when running inside IPython; use it to detect that.
try:
    get_ipython()  # type: ignore[name-defined]
    in_ipython = True
except NameError:
    in_ipython = False

if in_ipython:
    import graphviz
    from IPython.core.magic import needs_local_scope, register_cell_magic

    @needs_local_scope
    @register_cell_magic
    def egglog(line, cell, local_ns):
        """
        Run an egglog program

        Usage:

            %%egglog [output] [continue] [graph]
            (egglog program)

        Options are whitespace-separated words.
        If `output` is specified, the output of the program will be printed.
        If `continue` is specified, the program will be run in the same EGraph as the previous cell.
        If `graph` is specified, the EGraph will be displayed as a graph.
        """
        flags = _magic_flags(line)
        # Reuse the e-graph stored by an earlier cell only when explicitly continuing.
        if EGRAPH_VAR in local_ns and "continue" in flags:
            e = local_ns[EGRAPH_VAR]
        else:
            e = EGraph()
            local_ns[EGRAPH_VAR] = e
        cmds = e.parse_program(cell)
        res = e.run_program(*cmds)
        if "output" in flags:
            print("\n".join(res))
        if "graph" in flags:
            return graphviz.Source(e.to_graphviz_string())
|
egglog/monkeypatch.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
__all__ = ["monkeypatch_forward_ref"]


def monkeypatch_forward_ref():
    """
    Backport https://github.com/python/cpython/pull/21553 by patching ``typing.ForwardRef._evaluate``.

    The recursive guard from the upstream change is omitted for simplicity.
    This can be removed once Python 3.8 is no longer supported.
    """
    # Python 3.9 and newer already include the upstream fix, so only older versions are patched.
    needs_backport = sys.version_info < (3, 9)
    if needs_backport:
        typing.ForwardRef._evaluate = _evaluate_monkeypatch  # type: ignore
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _evaluate_monkeypatch(self, globalns, localns):
    """
    Replacement for ``typing.ForwardRef._evaluate`` (Python 3.8 signature).

    Evaluates the reference's compiled code in the given namespaces and type-checks the result.
    It then also resolves forward references nested inside that result via ``typing._eval_type``,
    which is the behaviour the backported CPython change adds. The value is cached on the
    ForwardRef and reused, unless the global and local namespace arguments are different objects.
    """
    # Serve the cached value when there is one and the namespaces are not distinct.
    if self.__forward_evaluated__ and localns is globalns:
        return self.__forward_value__

    # Fill in whichever namespace is missing from the other, or share one fresh dict.
    if globalns is None and localns is None:
        globalns = localns = {}
    if globalns is None:
        globalns = localns
    if localns is None:
        localns = globalns

    # NOTE: eval() of annotation source code, exactly as typing itself does.
    evaluated = eval(self.__forward_code__, globalns, localns)
    checked = typing._type_check(  # type: ignore
        evaluated,
        "Forward references must evaluate to types.",
        is_argument=self.__forward_is_argument__,
    )
    self.__forward_value__ = typing._eval_type(checked, globalns, localns)  # type: ignore
    self.__forward_evaluated__ = True
    return self.__forward_value__
|
egglog/py.typed
ADDED
|
File without changes
|
egglog/runtime.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module holds a number of types which are only used at runtime to emulate Python objects.
|
|
3
|
+
|
|
4
|
+
Users will not import anything from this module, and statically they won't know these are the types they are using.
|
|
5
|
+
|
|
6
|
+
But at runtime they will be exposed.
|
|
7
|
+
|
|
8
|
+
Note that all their internal fields are prefixed with __egg_ to avoid name collisions with user code, but will end in __
|
|
9
|
+
so they are not mangled by Python and can be accessed by the user.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from typing import Collection, Iterable, Optional, Union
|
|
16
|
+
|
|
17
|
+
import black
|
|
18
|
+
from typing_extensions import assert_never
|
|
19
|
+
|
|
20
|
+
from . import config # noqa: F401
|
|
21
|
+
from .declarations import *
|
|
22
|
+
from .declarations import BINARY_METHODS, UNARY_METHODS
|
|
23
|
+
from .type_constraint_solver import *
|
|
24
|
+
|
|
25
|
+
# Public API of this module (note that RuntimeMethod is not exported).
__all__ = [
    "LIT_CLASS_NAMES",
    "RuntimeClass",
    "RuntimeParamaterizedClass",
    "RuntimeClassMethod",
    "RuntimeExpr",
    "RuntimeFunction",
    "ArgType",
]


# NOTE(review): not referenced elsewhere in this module (RuntimeExpr.__str__ builds its own
# black modes); presumably imported by other modules — confirm before changing or removing.
BLACK_MODE = black.Mode(line_length=120)  # type: ignore

# Name of the builtin unit class; RuntimeClass.__call__ builds it from a ``None`` literal with no arguments.
UNIT_CLASS_NAME = "Unit"
# Builtin classes that RuntimeClass.__call__ constructs directly from a single Python literal.
UNARY_LIT_CLASS_NAMES = {"i64", "f64", "String"}
# All builtin literal classes.
LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class RuntimeClass:
    """
    Runtime stand-in for the egglog class named ``__egg_name__`` in ``__egg_decls__``.

    * calling it constructs an instance (see ``__call__``),
    * subscripting it binds type arguments (e.g. ``Vec[Value]``),
    * attribute access returns a class-variable expression or a class method.
    """

    __egg_decls__: ModuleDeclarations
    __egg_name__: str

    def __call__(self, *args: ArgType) -> RuntimeExpr:
        """
        Create an instance of this kind by calling the __init__ classmethod.

        The builtin literal classes are special-cased. The unary ones wrap exactly one Python
        literal (int, float or str) and ``Unit`` takes no arguments.
        """
        cls_name = self.__egg_name__
        decls = self.__egg_decls__
        if cls_name in UNARY_LIT_CLASS_NAMES:
            # Initializing a literal type with a literal yields a literal expression.
            assert len(args) == 1
            literal = args[0]
            assert isinstance(literal, (int, float, str))
            return RuntimeExpr(decls, TypedExprDecl(JustTypeRef(cls_name), LitDecl(literal)))
        if cls_name == UNIT_CLASS_NAME:
            assert len(args) == 0
            return RuntimeExpr(decls, TypedExprDecl(JustTypeRef(cls_name), LitDecl(None)))
        constructor = RuntimeClassMethod(decls, cls_name, "__init__")
        return constructor(*args)

    def __dir__(self) -> list[str]:
        # Class methods and class variables; the constructor is exposed as calling the class itself.
        cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__)
        names = [*cls_decl.class_methods, *cls_decl.class_variables]
        if "__init__" in names:
            names.remove("__init__")
            names.append("__call__")
        return names

    def __getitem__(self, args: tuple[RuntimeTypeArgType, ...] | RuntimeTypeArgType) -> RuntimeParamaterizedClass:
        """Bind one or more type arguments, e.g. ``Vec[Value]``."""
        type_args = args if isinstance(args, tuple) else (args,)
        tp = JustTypeRef(self.__egg_name__, tuple(map(class_to_ref, type_args)))
        return RuntimeParamaterizedClass(self.__egg_decls__, tp)

    def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeExpr:
        # Class variables evaluate to expressions. Any other name is assumed to be a class method
        # (RuntimeClassMethod raises AttributeError when no such method is declared).
        cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__)
        if name not in cls_decl.class_variables:
            return RuntimeClassMethod(self.__egg_decls__, self.__egg_name__, name)
        var_tp = cls_decl.class_variables[name]
        var_expr = CallDecl(ClassVariableRef(self.__egg_name__, name))
        return RuntimeExpr(self.__egg_decls__, TypedExprDecl(var_tp, var_expr))

    def __str__(self) -> str:
        return self.__egg_name__

    # Hashable so instances can be used inside a typing Union
    def __hash__(self) -> int:
        return hash((id(self.__egg_decls__), self.__egg_name__))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
class RuntimeParamaterizedClass:
    """
    Runtime stand-in for a generic egglog class applied to concrete type arguments.

    This is what ``RuntimeClass.__getitem__`` returns, e.g. for ``Vec[Value]``.
    """

    __egg_decls__: ModuleDeclarations
    # Note that this will never be a typevar because we don't use RuntimeParamaterizedClass for maps on their own methods
    # which is the only time we define function which take typevars
    __egg_tp__: JustTypeRef

    def __post_init__(self):
        """Check that the number of type arguments matches the class declaration."""
        desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).n_type_vars
        if len(self.__egg_tp__.args) != desired_args:
            raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")

    def __call__(self, *args: ArgType) -> RuntimeExpr:
        """Construct an instance via the class's ``__init__``, keeping the bound type arguments."""
        return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), "__init__")(*args)

    def __getattr__(self, name: str) -> RuntimeClassMethod:
        """Look up a class method, keeping the bound type arguments."""
        return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), name)

    def __str__(self) -> str:
        return self.__egg_tp__.pretty()

    # Make hashable, like RuntimeClass, so it can also go in a Union. Without an explicit __hash__,
    # the dataclass-generated __eq__ leaves instances unhashable.
    # NOTE(review): assumes JustTypeRef is hashable (presumably a frozen dataclass). If it is not,
    # hash() still raises TypeError exactly as it did before this method existed.
    def __hash__(self) -> int:
        return hash((id(self.__egg_decls__), self.__egg_tp__))
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Type arguments are either plain runtime classes or already-parameterized classes
RuntimeTypeArgType = Union[RuntimeClass, RuntimeParamaterizedClass]


def class_to_ref(cls: RuntimeTypeArgType) -> JustTypeRef:
    """
    Convert a runtime class, plain or parameterized, into the ``JustTypeRef`` it denotes.
    """
    if isinstance(cls, RuntimeParamaterizedClass):
        # Already carries its full type reference, including type arguments
        return cls.__egg_tp__
    if isinstance(cls, RuntimeClass):
        return JustTypeRef(cls.__egg_name__)
    assert_never(cls)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@dataclass
class RuntimeFunction:
    """
    Runtime stand-in for the top-level egglog function named ``__egg_name__``.
    """

    __egg_decls__: ModuleDeclarations
    __egg_name__: str
    # Both derived from the name in __post_init__
    __egg_fn_ref__: FunctionRef = field(init=False)
    __egg_fn_decl__: FunctionDecl = field(init=False)

    def __post_init__(self):
        fn_ref = FunctionRef(self.__egg_name__)
        self.__egg_fn_ref__ = fn_ref
        self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(fn_ref)

    def __call__(self, *args: ArgType) -> RuntimeExpr:
        """Build a call expression of this function applied to ``args``."""
        decls, ref, decl = self.__egg_decls__, self.__egg_fn_ref__, self.__egg_fn_decl__
        return _call(decls, ref, decl, args)

    def __str__(self) -> str:
        return self.__egg_name__
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _call(
    decls: ModuleDeclarations,
    callable_ref: CallableRef,
    # Not included if this is the != method
    fn_decl: Optional[FunctionDecl],
    args: Collection[ArgType],
    bound_params: Optional[tuple[JustTypeRef, ...]] = None,
) -> RuntimeExpr:
    """
    Build the RuntimeExpr for calling ``callable_ref`` on ``args``.

    Python literals among ``args`` are promoted to expressions first. The return type is inferred
    from ``fn_decl`` and the argument types, honouring any ``bound_params`` type arguments.
    Without a declaration (the ``!=`` case) the result type is ``Unit``.
    """
    resolved = [_resolve_literal(decls, a) for a in args]
    typed_args = tuple(r.__egg_typed_expr__ for r in resolved)

    if bound_params is None:
        solver = TypeConstraintSolver()
    else:
        solver = TypeConstraintSolver.from_type_parameters(bound_params)

    if fn_decl is None:
        return_tp = JustTypeRef("Unit")
    else:
        actual_types = [t.tp for t in typed_args]
        return_tp = solver.infer_return_type(
            fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, actual_types
        )

    call = CallDecl(callable_ref, typed_args, bound_params)
    return RuntimeExpr(decls, TypedExprDecl(return_tp, call))
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@dataclass
class RuntimeClassMethod:
    """
    A class method (including ``__init__``) of an egglog class, ready to be called.
    """

    __egg_decls__: ModuleDeclarations
    # Either the bare class name (no type arguments bound) or a JustTypeRef carrying the bound type arguments
    __egg_tp__: JustTypeRef | str
    __egg_method_name__: str
    __egg_callable_ref__: ClassMethodRef = field(init=False)
    __egg_fn_decl__: FunctionDecl = field(init=False)

    def __post_init__(self):
        self.__egg_callable_ref__ = ClassMethodRef(self.class_name, self.__egg_method_name__)
        try:
            self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
        except KeyError:
            # AttributeError (not KeyError) so that getattr/hasattr on the owning runtime class behave.
            # The internal lookup failure is suppressed because it adds nothing for the user.
            raise AttributeError(
                f"Class {self.class_name} does not have method {self.__egg_method_name__}"
            ) from None

    def __call__(self, *args: ArgType) -> RuntimeExpr:
        """Build a call expression, passing along any bound type arguments."""
        bound_params = self.__egg_tp__.args if isinstance(self.__egg_tp__, JustTypeRef) else None
        return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, bound_params)

    def __str__(self) -> str:
        return f"{self.class_name}.{self.__egg_method_name__}"

    @property
    def class_name(self) -> str:
        """Name of the owning class, whether or not type arguments are bound."""
        if isinstance(self.__egg_tp__, str):
            return self.__egg_tp__
        return self.__egg_tp__.name
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@dataclass
class RuntimeMethod:
    """
    A method looked up on an expression (``expr.name``), with that expression bound as the first argument.
    """

    __egg_decls__: ModuleDeclarations
    # The receiver expression, passed as the first argument when the method is called
    __egg_typed_expr__: TypedExprDecl
    __egg_method_name__: str
    __egg_callable_ref__: MethodRef = field(init=False)
    # None only for __ne__, which has no declaration (see __post_init__)
    __egg_fn_decl__: Optional[FunctionDecl] = field(init=False)

    def __post_init__(self):
        self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__)
        # Special case for __ne__, which does not have a normal function definition since
        # it relies on type parameters
        if self.__egg_method_name__ == "__ne__":
            self.__egg_fn_decl__ = None
        else:
            try:
                self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
            except KeyError:
                # AttributeError (not KeyError) keeps RuntimeExpr.__getattr__ compatible with getattr/hasattr.
                # The internal lookup failure is suppressed because it adds nothing for the user.
                raise AttributeError(
                    f"Class {self.class_name} does not have method {self.__egg_method_name__}"
                ) from None

    def __call__(self, *args: ArgType) -> RuntimeExpr:
        """Build a call expression with the receiver prepended to ``args``."""
        first_arg = RuntimeExpr(self.__egg_decls__, self.__egg_typed_expr__)
        args = (first_arg, *args)
        return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args)

    @property
    def class_name(self) -> str:
        """Name of the receiver expression's class."""
        return self.__egg_typed_expr__.tp.name
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@dataclass
class RuntimeExpr:
    """
    A runtime egglog expression: a typed expression declaration together with the module
    declarations it was built against.
    """

    __egg_decls__: ModuleDeclarations
    __egg_typed_expr__: TypedExprDecl

    def __getattr__(self, name: str) -> RuntimeMethod:
        # Any attribute not found normally is treated as a method of the expression's class.
        # RuntimeMethod raises AttributeError when the class has no such method (except for __ne__).
        return RuntimeMethod(self.__egg_decls__, self.__egg_typed_expr__, name)

    def __repr__(self) -> str:
        """
        The repr of the expr is the pretty printed version of the expr.
        """
        return str(self)

    def __str__(self) -> str:
        # Pretty print the expression and format it with black. When config.SHOW_TYPES is set,
        # it is rendered as an annotated assignment ``_: <type> = <expr>`` so the type is visible.
        pretty_expr = self.__egg_typed_expr__.expr.pretty(parens=False)
        if config.SHOW_TYPES:
            s = f"_: {self.__egg_typed_expr__.tp.pretty()} = {pretty_expr}"
            return black.format_str(s, mode=black.FileMode()).strip()
        else:
            return black.format_str(pretty_expr, mode=black.FileMode(line_length=180)).strip()

    def __dir__(self) -> Iterable[str]:
        # Names of the methods declared on this expression's class
        return list(self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).methods)

    # __eq__ takes NoReturn (aka Never, https://docs.python.org/3/library/typing.html#typing.Never) because
    # we don't want any type that MyPy thinks is an expr to be used with __eq__.
    # That's because we want to reserve __eq__ for domain specific equality checks, overloading this method.
    # To check whether two exprs are structurally equal, compare their __egg_typed_expr__ attributes instead.
    def __eq__(self, other: NoReturn) -> Expr:  # type: ignore
        raise NotImplementedError(
            "Do not use == on RuntimeExpr. Compare the __egg_typed_expr__ attribute instead for structural equality."
        )
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _make_special_method(method_name: str):
    """
    Build the RuntimeExpr implementation of the special method ``method_name``.

    The returned function forwards to ``RuntimeMethod``, so e.g. ``expr[idx]`` becomes a call of
    the ``__getitem__`` method declared on ``expr``'s class. A closure is used instead of a
    default-argument trick, so the generated methods expose no extra keyword parameter.
    """

    def _special_method(self: RuntimeExpr, *args: ArgType) -> RuntimeExpr:
        return RuntimeMethod(self.__egg_decls__, self.__egg_typed_expr__, method_name)(*args)

    # Give tracebacks and introspection a meaningful name
    _special_method.__name__ = method_name
    _special_method.__qualname__ = f"RuntimeExpr.{method_name}"
    return _special_method


# Define each of the special methods, since we have already declared them for pretty printing.
# They must be set on the class itself: Python looks special methods up on the type, so
# RuntimeExpr.__getattr__ is never consulted for them.
for _method_name in [*BINARY_METHODS, *UNARY_METHODS, "__getitem__", "__call__"]:
    setattr(RuntimeExpr, _method_name, _make_special_method(_method_name))
# Don't leak the loop variable into the module namespace
del _method_name
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
# Arguments are either runtime expressions or Python literals, which get promoted automatically
ArgType = Union[RuntimeExpr, int, str, float]


def _resolve_literal(decls: ModuleDeclarations, arg: ArgType) -> RuntimeExpr:
    """
    Promote a Python literal to a RuntimeExpr of the matching builtin type, or return ``arg`` unchanged.

    ``int`` maps to ``i64``, ``float`` to ``f64`` and ``str`` to ``String``. The checks use
    ``isinstance`` in that order, so ``int`` subclasses such as ``bool`` are also treated as ``i64``.
    """
    for py_type, egg_type_name in ((int, "i64"), (float, "f64"), (str, "String")):
        if isinstance(arg, py_type):
            return RuntimeExpr(decls, TypedExprDecl(JustTypeRef(egg_type_name), LitDecl(arg)))
    return arg
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def _resolve_callable(callable: object) -> CallableRef:
    """
    Resolves a runtime callable into a ref.

    Supports runtime functions, class methods, methods bound to an expression, and classes
    (resolved to their ``__init__``). Anything else raises NotImplementedError.
    """
    if isinstance(callable, RuntimeClass):
        ref: CallableRef = ClassMethodRef(callable.__egg_name__, "__init__")
    elif isinstance(callable, RuntimeClassMethod):
        ref = ClassMethodRef(callable.class_name, callable.__egg_method_name__)
    elif isinstance(callable, RuntimeMethod):
        ref = MethodRef(callable.__egg_typed_expr__.tp.name, callable.__egg_method_name__)
    elif isinstance(callable, RuntimeFunction):
        ref = FunctionRef(callable.__egg_name__)
    else:
        raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
    return ref
|