egglog 0.4.0__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

@@ -0,0 +1,159 @@
1
+ """
2
+ N-Dimensional Arrays
3
+ ====================
4
+
5
+ Example of building NDArray in the vein of Mathematics of Arrays.
6
+ """
7
+ # mypy: disable-error-code=empty-body
8
+ from __future__ import annotations
9
+
10
+ from egglog import *
11
+
12
+ egraph = EGraph()
13
+
14
+
15
+ @egraph.class_
16
+ class Value(BaseExpr):
17
+ def __init__(self, v: i64Like) -> None:
18
+ ...
19
+
20
+ def __mul__(self, other: Value) -> Value:
21
+ ...
22
+
23
+ def __add__(self, other: Value) -> Value:
24
+ ...
25
+
26
+
27
+ i, j = vars_("i j", i64)
28
+ egraph.register(
29
+ rewrite(Value(i) * Value(j)).to(Value(i * j)),
30
+ rewrite(Value(i) + Value(j)).to(Value(i + j)),
31
+ )
32
+
33
+
34
+ @egraph.class_
35
+ class Values(BaseExpr):
36
+ def __init__(self, v: Vec[Value]) -> None:
37
+ ...
38
+
39
+ def __getitem__(self, idx: Value) -> Value:
40
+ ...
41
+
42
+ def length(self) -> Value:
43
+ ...
44
+
45
+ def concat(self, other: Values) -> Values:
46
+ ...
47
+
48
+
49
+ @egraph.register
50
+ def _values(vs: Vec[Value], other: Vec[Value]):
51
+ yield rewrite(Values(vs)[Value(i)]).to(vs[i])
52
+ yield rewrite(Values(vs).length()).to(Value(vs.length()))
53
+ yield rewrite(Values(vs).concat(Values(other))).to(Values(vs.append(other)))
54
+ # yield rewrite(l.concat(r).length()).to(l.length() + r.length())
55
+ # yield rewrite(l.concat(r)[idx])
56
+
57
+
58
+ @egraph.class_
59
+ class NDArray(BaseExpr):
60
+ """
61
+ An n-dimensional array.
62
+ """
63
+
64
+ def __getitem__(self, idx: Values) -> Value:
65
+ ...
66
+
67
+ def shape(self) -> Values:
68
+ ...
69
+
70
+
71
+ @egraph.function
72
+ def arange(n: Value) -> NDArray:
73
+ ...
74
+
75
+
76
+ @egraph.register
77
+ def _ndarray_arange(n: Value, idx: Values):
78
+ yield rewrite(arange(n).shape()).to(Values(Vec(n)))
79
+ yield rewrite(arange(n)[idx]).to(idx[Value(0)])
80
+
81
+
82
+ def assert_simplifies(left: BaseExpr, right: BaseExpr) -> None:
83
+ """
84
+ Simplify and print
85
+ """
86
+ egraph.register(left)
87
+ egraph.run(30)
88
+ res = egraph.extract(left)
89
+ print(f"{left} == {right} ➡ {res}")
90
+ egraph.check(eq(left).to(right))
91
+
92
+
93
+ assert_simplifies(arange(Value(10)).shape(), Values(Vec(Value(10))))
94
+ assert_simplifies(arange(Value(10))[Values(Vec(Value(0)))], Value(0))
95
+ assert_simplifies(arange(Value(10))[Values(Vec(Value(1)))], Value(1))
96
+
97
+
98
+ @egraph.function
99
+ def py_value(s: StringLike) -> Value:
100
+ ...
101
+
102
+
103
+ @egraph.register
104
+ def _py_value(l: String, r: String):
105
+ yield rewrite(py_value(l) + py_value(r)).to(py_value(join(l, " + ", r)))
106
+ yield rewrite(py_value(l) * py_value(r)).to(py_value(join(l, " * ", r)))
107
+
108
+
109
+ @egraph.function
110
+ def py_values(s: StringLike) -> Values:
111
+ ...
112
+
113
+
114
+ @egraph.register
115
+ def _py_values(l: String, r: String):
116
+ yield rewrite(py_values(l)[py_value(r)]).to(py_value(join(l, "[", r, "]")))
117
+ yield rewrite(py_values(l).length()).to(py_value(join("len(", l, ")")))
118
+ yield rewrite(py_values(l).concat(py_values(r))).to(py_values(join(l, " + ", r)))
119
+
120
+
121
+ @egraph.function
122
+ def py_ndarray(s: StringLike) -> NDArray:
123
+ ...
124
+
125
+
126
+ @egraph.register
127
+ def _py_ndarray(l: String, r: String):
128
+ yield rewrite(py_ndarray(l)[py_values(r)]).to(py_value(join(l, "[", r, "]")))
129
+ yield rewrite(py_ndarray(l).shape()).to(py_values(join(l, ".shape")))
130
+ yield rewrite(arange(py_value(l))).to(py_ndarray(join("np.arange(", l, ")")))
131
+
132
+
133
+ assert_simplifies(py_ndarray("x").shape(), py_values("x.shape"))
134
+ assert_simplifies(arange(py_value("x"))[py_values("y")], py_value("np.arange(x)[y]"))
135
+ # assert_simplifies(arange(py_value("x"))[py_values("y")], py_value("y[0]"))
136
+
137
+
138
+ @egraph.function
139
+ def cross(l: NDArray, r: NDArray) -> NDArray:
140
+ ...
141
+
142
+
143
+ @egraph.register
144
+ def _cross(l: NDArray, r: NDArray, idx: Values):
145
+ yield rewrite(cross(l, r).shape()).to(l.shape().concat(r.shape()))
146
+ yield rewrite(cross(l, r)[idx]).to(l[idx] * r[idx])
147
+
148
+
149
+ assert_simplifies(cross(arange(Value(10)), arange(Value(11))).shape(), Values(Vec(Value(10), Value(11))))
150
+ assert_simplifies(cross(py_ndarray("x"), py_ndarray("y")).shape(), py_values("x.shape + y.shape"))
151
+ assert_simplifies(cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")], py_value("x[idx] * y[idx]"))
152
+
153
+
154
+ @egraph.register
155
+ def _cross_py(l: String, r: String):
156
+ yield rewrite(cross(py_ndarray(l), py_ndarray(r))).to(py_ndarray(join("np.multiply.outer(", l, ", ", r, ")")))
157
+
158
+
159
+ assert_simplifies(cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")], py_value("np.multiply.outer(x, y)[idx]"))
@@ -0,0 +1,84 @@
1
+ """
2
+ Resolution theorem proving.
3
+ ===========================
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from typing import ClassVar
8
+
9
+ from egglog import *
10
+
11
+ egraph = EGraph()
12
+
13
+
14
+ @egraph.class_
15
+ class Bool(BaseExpr):
16
+ FALSE: ClassVar[Bool]
17
+
18
+ def __or__(self, other: Bool) -> Bool: # type: ignore[empty-body]
19
+ ...
20
+
21
+ def __invert__(self) -> Bool: # type: ignore[empty-body]
22
+ ...
23
+
24
+
25
+ # Show off two ways of creating constants, either as top level values or as classvars
26
+ T = egraph.constant("T", Bool)
27
+ F = Bool.FALSE
28
+
29
+ p, a, b, c, as_, bs = vars_("p a b c as bs", Bool)
30
+ egraph.register(
31
+ # clauses are assumed in the normal form (or a (or b (or c False)))
32
+ set_(~F).to(T),
33
+ set_(~T).to(F),
34
+ # "Solving" negation equations
35
+ rule(eq(~p).to(T)).then(union(p).with_(F)),
36
+ rule(eq(~p).to(F)).then(union(p).with_(T)),
37
+ # canonicalize associativity. "append" for clauses terminates with false
38
+ rewrite((a | b) | c).to(a | (b | c)),
39
+ # commutativity
40
+ rewrite(a | (b | c)).to(b | (a | c)),
41
+ # absorption
42
+ rewrite(a | (a | b)).to(a | b),
43
+ rewrite(a | (~a | b)).to(T),
44
+ # Simplification
45
+ rewrite(F | a).to(a),
46
+ rewrite(a | F).to(a),
47
+ rewrite(T | a).to(T),
48
+ rewrite(a | T).to(T),
49
+ # unit propagation
50
+ # This is kind of interesting actually.
51
+ # Looks a bit like equation solving
52
+ rule(eq(T).to(p | F)).then(union(p).with_(T)),
53
+ # resolution
54
+ # This counts on commutativity to bubble everything possible up to the front of the clause.
55
+ rule(
56
+ eq(T).to(a | as_),
57
+ eq(T).to(~a | bs),
58
+ ).then(
59
+ set_(as_ | bs).to(T),
60
+ ),
61
+ )
62
+
63
+
64
+ # Example predicate
65
+ @egraph.function
66
+ def pred(x: i64Like) -> Bool: # type: ignore[empty-body]
67
+ ...
68
+
69
+
70
+ p0 = egraph.define("p0", pred(0))
71
+ p1 = egraph.define("p1", pred(1))
72
+ p2 = egraph.define("p2", pred(2))
73
+ egraph.register(
74
+ set_(p1 | (~p2 | F)).to(T),
75
+ set_(p2 | (~p0 | F)).to(T),
76
+ set_(p0 | (~p1 | F)).to(T),
77
+ union(p1).with_(F),
78
+ set_(~p0 | (~p1 | (p2 | F))).to(T),
79
+ )
80
+ egraph.run(10)
81
+ egraph.check(T != F)
82
+ egraph.check(eq(p0).to(F))
83
+ egraph.check(eq(p2).to(F))
84
+ egraph
@@ -0,0 +1,33 @@
1
+ """
2
+ Schedule demo
3
+ =============
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from egglog import *
8
+
9
+ egraph = EGraph()
10
+
11
+ left = egraph.relation("left", i64)
12
+ right = egraph.relation("right", i64)
13
+
14
+ egraph.register(left(i64(0)), right(i64(0)))
15
+
16
+ x, y = vars_("x y", i64)
17
+
18
+ step_left = egraph.ruleset("step-left")
19
+ egraph.register(rule(left(x), right(x), ruleset=step_left).then(left(x + 1)))
20
+
21
+ step_right = egraph.ruleset("step-right")
22
+ egraph.register(rule(left(x), right(y), eq(x).to(y + 1), ruleset=step_right).then(right(x)))
23
+
24
+ egraph.run(
25
+ seq(
26
+ run(step_right).saturate(),
27
+ run(step_left).saturate(),
28
+ )
29
+ * 10
30
+ )
31
+ egraph.check(left(i64(10)), right(i64(9)))
32
+ egraph.check_fail(left(i64(11)), right(i64(10)))
33
+ egraph
@@ -0,0 +1,40 @@
1
+ from .bindings import EGraph
2
+
3
+ EGRAPH_VAR = "_MAGIC_EGRAPH"
4
+
5
+ try:
6
+ get_ipython() # type: ignore[name-defined]
7
+ in_ipython = True
8
+ except NameError:
9
+ in_ipython = False
10
+
11
+ if in_ipython:
12
+ import graphviz
13
+ from IPython.core.magic import needs_local_scope, register_cell_magic
14
+
15
+ @needs_local_scope
16
+ @register_cell_magic
17
+ def egglog(line, cell, local_ns):
18
+ """
19
+ Run an egglog program
20
+
21
+ Usage:
22
+
23
+ %%egglog [output] [continue] [graph]
24
+ (egglog program)
25
+
26
+ If `output` is specified, the output of the program will be printed.
27
+ If `continue` is specified, the program will be run in the same EGraph as the previous cell.
28
+ If `graph` is specified, the EGraph will be displayed as a graph.
29
+ """
30
+ if EGRAPH_VAR in local_ns and "continue" in line:
31
+ e = local_ns[EGRAPH_VAR]
32
+ else:
33
+ e = EGraph()
34
+ local_ns[EGRAPH_VAR] = e
35
+ cmds = e.parse_program(cell)
36
+ res = e.run_program(*cmds)
37
+ if "output" in line:
38
+ print("\n".join(res))
39
+ if "graph" in line:
40
+ return graphviz.Source(e.to_graphviz_string())
egglog/monkeypatch.py ADDED
@@ -0,0 +1,33 @@
1
+ import sys
2
+ import typing
3
+
4
+ __all__ = ["monkeypatch_forward_ref"]
5
+
6
+
7
+ def monkeypatch_forward_ref():
8
+ """
9
+ Monkeypatch to backport https://github.com/python/cpython/pull/21553.
10
+ Removed recursive guard for simplicity
11
+ Can be removed once Python 3.8 is no longer supported
12
+ """
13
+ if sys.version_info >= (3, 9):
14
+ return
15
+ typing.ForwardRef._evaluate = _evaluate_monkeypatch # type: ignore
16
+
17
+
18
+ def _evaluate_monkeypatch(self, globalns, localns):
19
+ if not self.__forward_evaluated__ or localns is not globalns:
20
+ if globalns is None and localns is None:
21
+ globalns = localns = {}
22
+ elif globalns is None:
23
+ globalns = localns
24
+ elif localns is None:
25
+ localns = globalns
26
+ type_ = typing._type_check( # type: ignore
27
+ eval(self.__forward_code__, globalns, localns),
28
+ "Forward references must evaluate to types.",
29
+ is_argument=self.__forward_is_argument__,
30
+ )
31
+ self.__forward_value__ = typing._eval_type(type_, globalns, localns) # type: ignore
32
+ self.__forward_evaluated__ = True
33
+ return self.__forward_value__
egglog/py.typed ADDED
File without changes
egglog/runtime.py ADDED
@@ -0,0 +1,304 @@
1
+ """
2
+ This module holds a number of types which are only used at runtime to emulate Python objects.
3
+
4
+ Users will not import anything from this module, and statically they won't know these are the types they are using.
5
+
6
+ But at runtime they will be exposed.
7
+
8
+ Note that all their internal fields are prefixed with __egg_ to avoid name collisions with user code, but will end in __
9
+ so they are not mangled by Python and can be accessed by the user.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass, field
15
+ from typing import Collection, Iterable, Optional, Union
16
+
17
+ import black
18
+ from typing_extensions import assert_never
19
+
20
+ from . import config # noqa: F401
21
+ from .declarations import *
22
+ from .declarations import BINARY_METHODS, UNARY_METHODS
23
+ from .type_constraint_solver import *
24
+
25
+ __all__ = [
26
+ "LIT_CLASS_NAMES",
27
+ "RuntimeClass",
28
+ "RuntimeParamaterizedClass",
29
+ "RuntimeClassMethod",
30
+ "RuntimeExpr",
31
+ "RuntimeFunction",
32
+ "ArgType",
33
+ ]
34
+
35
+
36
+ BLACK_MODE = black.Mode(line_length=120) # type: ignore
37
+
38
+ UNIT_CLASS_NAME = "Unit"
39
+ UNARY_LIT_CLASS_NAMES = {"i64", "f64", "String"}
40
+ LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME}
41
+
42
+
43
+ @dataclass
44
+ class RuntimeClass:
45
+ __egg_decls__: ModuleDeclarations
46
+ __egg_name__: str
47
+
48
+ def __call__(self, *args: ArgType) -> RuntimeExpr:
49
+ """
50
+ Create an instance of this kind by calling the __init__ classmethod
51
+ """
52
+ # If this is a literal type, initializing it with a literal should return a literal
53
+ if self.__egg_name__ in UNARY_LIT_CLASS_NAMES:
54
+ assert len(args) == 1
55
+ assert isinstance(args[0], (int, float, str))
56
+ return RuntimeExpr(self.__egg_decls__, TypedExprDecl(JustTypeRef(self.__egg_name__), LitDecl(args[0])))
57
+ if self.__egg_name__ == UNIT_CLASS_NAME:
58
+ assert len(args) == 0
59
+ return RuntimeExpr(self.__egg_decls__, TypedExprDecl(JustTypeRef(self.__egg_name__), LitDecl(None)))
60
+
61
+ return RuntimeClassMethod(self.__egg_decls__, self.__egg_name__, "__init__")(*args)
62
+
63
+ def __dir__(self) -> list[str]:
64
+ cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__)
65
+ possible_methods = list(cls_decl.class_methods) + list(cls_decl.class_variables)
66
+ if "__init__" in possible_methods:
67
+ possible_methods.remove("__init__")
68
+ possible_methods.append("__call__")
69
+ return possible_methods
70
+
71
+ def __getitem__(self, args: tuple[RuntimeTypeArgType, ...] | RuntimeTypeArgType) -> RuntimeParamaterizedClass:
72
+ if not isinstance(args, tuple):
73
+ args = (args,)
74
+ tp = JustTypeRef(self.__egg_name__, tuple(class_to_ref(arg) for arg in args))
75
+ return RuntimeParamaterizedClass(self.__egg_decls__, tp)
76
+
77
+ def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeExpr:
78
+ cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__)
79
+ # if this is a class variable, return an expr for it, otherwise, assume it's a method
80
+ if name in cls_decl.class_variables:
81
+ return_tp = cls_decl.class_variables[name]
82
+ return RuntimeExpr(
83
+ self.__egg_decls__, TypedExprDecl(return_tp, CallDecl(ClassVariableRef(self.__egg_name__, name)))
84
+ )
85
+ return RuntimeClassMethod(self.__egg_decls__, self.__egg_name__, name)
86
+
87
+ def __str__(self) -> str:
88
+ return self.__egg_name__
89
+
90
+ # Make hashable so can go in Union
91
+ def __hash__(self) -> int:
92
+ return hash((id(self.__egg_decls__), self.__egg_name__))
93
+
94
+
95
+ @dataclass
96
+ class RuntimeParamaterizedClass:
97
+ __egg_decls__: ModuleDeclarations
98
+ # Note that this will never be a typevar because we don't use RuntimeParamaterizedClass for maps on their own methods
99
+ # which is the only time we define functions which take typevars
100
+ __egg_tp__: JustTypeRef
101
+
102
+ def __post_init__(self):
103
+ desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).n_type_vars
104
+ if len(self.__egg_tp__.args) != desired_args:
105
+ raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")
106
+
107
+ def __call__(self, *args: ArgType) -> RuntimeExpr:
108
+ return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), "__init__")(*args)
109
+
110
+ def __getattr__(self, name: str) -> RuntimeClassMethod:
111
+ return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), name)
112
+
113
+ def __str__(self) -> str:
114
+ return self.__egg_tp__.pretty()
115
+
116
+
117
+ # Type args can either be typevars or classes
118
+ RuntimeTypeArgType = Union[RuntimeClass, RuntimeParamaterizedClass]
119
+
120
+
121
+ def class_to_ref(cls: RuntimeTypeArgType) -> JustTypeRef:
122
+ if isinstance(cls, RuntimeClass):
123
+ return JustTypeRef(cls.__egg_name__)
124
+ if isinstance(cls, RuntimeParamaterizedClass):
125
+ return cls.__egg_tp__
126
+ assert_never(cls)
127
+
128
+
129
+ @dataclass
130
+ class RuntimeFunction:
131
+ __egg_decls__: ModuleDeclarations
132
+ __egg_name__: str
133
+ __egg_fn_ref__: FunctionRef = field(init=False)
134
+ __egg_fn_decl__: FunctionDecl = field(init=False)
135
+
136
+ def __post_init__(self):
137
+ self.__egg_fn_ref__ = FunctionRef(self.__egg_name__)
138
+ self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_fn_ref__)
139
+
140
+ def __call__(self, *args: ArgType) -> RuntimeExpr:
141
+ return _call(self.__egg_decls__, self.__egg_fn_ref__, self.__egg_fn_decl__, args)
142
+
143
+ def __str__(self) -> str:
144
+ return self.__egg_name__
145
+
146
+
147
+ def _call(
148
+ decls: ModuleDeclarations,
149
+ callable_ref: CallableRef,
150
+ # Not included if this is the != method
151
+ fn_decl: Optional[FunctionDecl],
152
+ args: Collection[ArgType],
153
+ bound_params: Optional[tuple[JustTypeRef, ...]] = None,
154
+ ) -> RuntimeExpr:
155
+ upcasted_args = [_resolve_literal(decls, arg) for arg in args]
156
+
157
+ arg_types = [arg.__egg_typed_expr__.tp for arg in upcasted_args]
158
+
159
+ if bound_params is not None:
160
+ tcs = TypeConstraintSolver.from_type_parameters(bound_params)
161
+ else:
162
+ tcs = TypeConstraintSolver()
163
+
164
+ if fn_decl is not None:
165
+ return_tp = tcs.infer_return_type(fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types)
166
+ else:
167
+ return_tp = JustTypeRef("Unit")
168
+
169
+ arg_decls = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
170
+ expr_decl = CallDecl(callable_ref, arg_decls, bound_params)
171
+ return RuntimeExpr(decls, TypedExprDecl(return_tp, expr_decl))
172
+
173
+
174
+ @dataclass
175
+ class RuntimeClassMethod:
176
+ __egg_decls__: ModuleDeclarations
177
+ # Either a string if it isn't bound or a tp if it is
178
+ __egg_tp__: JustTypeRef | str
179
+ __egg_method_name__: str
180
+ __egg_callable_ref__: ClassMethodRef = field(init=False)
181
+ __egg_fn_decl__: FunctionDecl = field(init=False)
182
+
183
+ def __post_init__(self):
184
+ self.__egg_callable_ref__ = ClassMethodRef(self.class_name, self.__egg_method_name__)
185
+ try:
186
+ self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
187
+ except KeyError:
188
+ raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}")
189
+
190
+ def __call__(self, *args: ArgType) -> RuntimeExpr:
191
+ bound_params = self.__egg_tp__.args if isinstance(self.__egg_tp__, JustTypeRef) else None
192
+ return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, bound_params)
193
+
194
+ def __str__(self) -> str:
195
+ return f"{self.class_name}.{self.__egg_method_name__}"
196
+
197
+ @property
198
+ def class_name(self) -> str:
199
+ if isinstance(self.__egg_tp__, str):
200
+ return self.__egg_tp__
201
+ return self.__egg_tp__.name
202
+
203
+
204
+ @dataclass
205
+ class RuntimeMethod:
206
+ __egg_decls__: ModuleDeclarations
207
+ __egg_typed_expr__: TypedExprDecl
208
+ __egg_method_name__: str
209
+ __egg_callable_ref__: MethodRef = field(init=False)
210
+ __egg_fn_decl__: Optional[FunctionDecl] = field(init=False)
211
+
212
+ def __post_init__(self):
213
+ self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__)
214
+ # Special case for __ne__ which does not have a normal function definition since
215
+ # it relies on type parameters
216
+ if self.__egg_method_name__ == "__ne__":
217
+ self.__egg_fn_decl__ = None
218
+ else:
219
+ try:
220
+ self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
221
+ except KeyError:
222
+ raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}")
223
+
224
+ def __call__(self, *args: ArgType) -> RuntimeExpr:
225
+ first_arg = RuntimeExpr(self.__egg_decls__, self.__egg_typed_expr__)
226
+ args = (first_arg, *args)
227
+ return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args)
228
+
229
+ @property
230
+ def class_name(self) -> str:
231
+ return self.__egg_typed_expr__.tp.name
232
+
233
+
234
+ @dataclass
235
+ class RuntimeExpr:
236
+ __egg_decls__: ModuleDeclarations
237
+ __egg_typed_expr__: TypedExprDecl
238
+
239
+ def __getattr__(self, name: str) -> RuntimeMethod:
240
+ return RuntimeMethod(self.__egg_decls__, self.__egg_typed_expr__, name)
241
+
242
+ def __repr__(self) -> str:
243
+ """
244
+ The repr of the expr is the pretty printed version of the expr.
245
+ """
246
+ return str(self)
247
+
248
+ def __str__(self) -> str:
249
+ pretty_expr = self.__egg_typed_expr__.expr.pretty(parens=False)
250
+ if config.SHOW_TYPES:
251
+ s = f"_: {self.__egg_typed_expr__.tp.pretty()} = {pretty_expr}"
252
+ return black.format_str(s, mode=black.FileMode()).strip()
253
+ else:
254
+ return black.format_str(pretty_expr, mode=black.FileMode(line_length=180)).strip()
255
+
256
+ def __dir__(self) -> Iterable[str]:
257
+ return list(self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).methods)
258
+
259
+ # Have __eq__ take NoReturn (aka Never https://docs.python.org/3/library/typing.html#typing.Never) because
260
+ # we don't want any type that MyPy thinks is an expr to be used with __eq__.
261
+ # That's because we want to reserve __eq__ for domain specific equality checks, overloading this method.
262
+ # To check if two exprs are equal, use the expr_eq method.
263
+ def __eq__(self, other: NoReturn) -> Expr: # type: ignore
264
+ raise NotImplementedError(
265
+ "Do not use == on RuntimeExpr. Compare the __egg_typed_expr__ attribute instead for structural equality."
266
+ )
267
+
268
+
269
+ # Define each of the special methods, since we have already declared them for pretty printing
270
+ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call__"]:
271
+
272
+ def _special_method(self: RuntimeExpr, *args: ArgType, __name: str = name) -> RuntimeExpr:
273
+ return RuntimeMethod(self.__egg_decls__, self.__egg_typed_expr__, __name)(*args)
274
+
275
+ setattr(RuntimeExpr, name, _special_method)
276
+
277
+
278
+ # Args can either be expressions or literals which are automatically promoted
279
+ ArgType = Union[RuntimeExpr, int, str, float]
280
+
281
+
282
+ def _resolve_literal(decls: ModuleDeclarations, arg: ArgType) -> RuntimeExpr:
283
+ if isinstance(arg, int):
284
+ return RuntimeExpr(decls, TypedExprDecl(JustTypeRef("i64"), LitDecl(arg)))
285
+ elif isinstance(arg, float):
286
+ return RuntimeExpr(decls, TypedExprDecl(JustTypeRef("f64"), LitDecl(arg)))
287
+ elif isinstance(arg, str):
288
+ return RuntimeExpr(decls, TypedExprDecl(JustTypeRef("String"), LitDecl(arg)))
289
+ return arg
290
+
291
+
292
+ def _resolve_callable(callable: object) -> CallableRef:
293
+ """
294
+ Resolves a runtime callable into a ref
295
+ """
296
+ if isinstance(callable, RuntimeFunction):
297
+ return FunctionRef(callable.__egg_name__)
298
+ if isinstance(callable, RuntimeClassMethod):
299
+ return ClassMethodRef(callable.class_name, callable.__egg_method_name__)
300
+ if isinstance(callable, RuntimeMethod):
301
+ return MethodRef(callable.__egg_typed_expr__.tp.name, callable.__egg_method_name__)
302
+ if isinstance(callable, RuntimeClass):
303
+ return ClassMethodRef(callable.__egg_name__, "__init__")
304
+ raise NotImplementedError(f"Cannot turn {callable} into a callable ref")