egglog 7.0.0__cp312-none-win_amd64.whl → 7.2.0__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +7 -0
- egglog/builtins.py +41 -1
- egglog/conversion.py +22 -17
- egglog/declarations.py +122 -37
- egglog/egraph.py +219 -78
- egglog/egraph_state.py +124 -54
- egglog/examples/higher_order_functions.py +50 -0
- egglog/exp/array_api.py +12 -9
- egglog/pretty.py +71 -15
- egglog/runtime.py +118 -33
- egglog/thunk.py +17 -6
- egglog/type_constraint_solver.py +5 -4
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/METADATA +10 -10
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/RECORD +17 -16
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/WHEEL +0 -0
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/license_files/LICENSE +0 -0
|
Binary file
|
egglog/bindings.pyi
CHANGED
|
@@ -494,6 +494,12 @@ class Relation:
|
|
|
494
494
|
class PrintOverallStatistics:
|
|
495
495
|
def __init__(self) -> None: ...
|
|
496
496
|
|
|
497
|
+
@final
|
|
498
|
+
class UnstableCombinedRuleset:
|
|
499
|
+
name: str
|
|
500
|
+
rulesets: list[str]
|
|
501
|
+
def __init__(self, name: str, rulesets: list[str]) -> None: ...
|
|
502
|
+
|
|
497
503
|
_Command: TypeAlias = (
|
|
498
504
|
SetOption
|
|
499
505
|
| Datatype
|
|
@@ -521,6 +527,7 @@ _Command: TypeAlias = (
|
|
|
521
527
|
| CheckProof
|
|
522
528
|
| Relation
|
|
523
529
|
| PrintOverallStatistics
|
|
530
|
+
| UnstableCombinedRuleset
|
|
524
531
|
)
|
|
525
532
|
|
|
526
533
|
def termdag_term_to_expr(termdag: TermDag, term: _Term) -> _Expr: ...
|
egglog/builtins.py
CHANGED
|
@@ -5,10 +5,14 @@ Builtin sorts and function to egg.
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, overload
|
|
10
|
+
|
|
11
|
+
from typing_extensions import TypeVarTuple, Unpack
|
|
9
12
|
|
|
10
13
|
from .conversion import converter
|
|
11
14
|
from .egraph import Expr, Unit, function, method
|
|
15
|
+
from .runtime import RuntimeFunction
|
|
12
16
|
|
|
13
17
|
if TYPE_CHECKING:
|
|
14
18
|
from collections.abc import Callable
|
|
@@ -31,6 +35,7 @@ __all__ = [
|
|
|
31
35
|
"py_eval",
|
|
32
36
|
"py_exec",
|
|
33
37
|
"py_eval_fn",
|
|
38
|
+
"UnstableFn",
|
|
34
39
|
]
|
|
35
40
|
|
|
36
41
|
|
|
@@ -461,3 +466,38 @@ def py_exec(code: StringLike, globals: object = PyObject.dict(), locals: object
|
|
|
461
466
|
"""
|
|
462
467
|
Copies the locals, execs the Python code, and returns the locals with any updates.
|
|
463
468
|
"""
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
TS = TypeVarTuple("TS")
|
|
472
|
+
|
|
473
|
+
T1 = TypeVar("T1")
|
|
474
|
+
T2 = TypeVar("T2")
|
|
475
|
+
T3 = TypeVar("T3")
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
class UnstableFn(Expr, Generic[T, Unpack[TS]], builtin=True):
|
|
479
|
+
@overload
|
|
480
|
+
def __init__(self, f: Callable[[Unpack[TS]], T]) -> None: ...
|
|
481
|
+
|
|
482
|
+
@overload
|
|
483
|
+
def __init__(self, f: Callable[[T1, Unpack[TS]], T], _a: T1, /) -> None: ...
|
|
484
|
+
|
|
485
|
+
@overload
|
|
486
|
+
def __init__(self, f: Callable[[T1, T2, Unpack[TS]], T], _a: T1, _b: T2, /) -> None: ...
|
|
487
|
+
|
|
488
|
+
# Removing due to bug in MyPy
|
|
489
|
+
# https://github.com/python/mypy/issues/17212
|
|
490
|
+
# @overload
|
|
491
|
+
# def __init__(self, f: Callable[[T1, T2, T3, Unpack[TS]], T], _a: T1, _b: T2, _c: T3, /) -> None: ...
|
|
492
|
+
|
|
493
|
+
# etc, for partial application
|
|
494
|
+
|
|
495
|
+
@method(egg_fn="unstable-fn")
|
|
496
|
+
def __init__(self, f, *partial) -> None: ...
|
|
497
|
+
|
|
498
|
+
@method(egg_fn="unstable-app")
|
|
499
|
+
def __call__(self, *args: Unpack[TS]) -> T: ...
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
converter(RuntimeFunction, UnstableFn, UnstableFn)
|
|
503
|
+
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
|
egglog/conversion.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import TYPE_CHECKING, TypeVar, cast
|
|
4
|
+
from typing import TYPE_CHECKING, NewType, TypeVar, cast
|
|
5
5
|
|
|
6
6
|
from .declarations import *
|
|
7
7
|
from .pretty import *
|
|
@@ -16,7 +16,8 @@ if TYPE_CHECKING:
|
|
|
16
16
|
|
|
17
17
|
__all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"]
|
|
18
18
|
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
|
|
19
|
-
|
|
19
|
+
TypeName = NewType("TypeName", str)
|
|
20
|
+
CONVERSIONS: dict[tuple[type | TypeName, TypeName], tuple[int, Callable]] = {}
|
|
20
21
|
# Global declerations to store all convertable types so we can query if they have certain methods or not
|
|
21
22
|
# Defer it as a thunk so we can register conversions without triggering type signature loading
|
|
22
23
|
CONVERSIONS_DECLS: Callable[[], Declarations] = Thunk.value(Declarations())
|
|
@@ -34,12 +35,12 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
|
|
|
34
35
|
Register a converter from some type to an egglog type.
|
|
35
36
|
"""
|
|
36
37
|
to_type_name = process_tp(to_type)
|
|
37
|
-
if not isinstance(to_type_name,
|
|
38
|
+
if not isinstance(to_type_name, str):
|
|
38
39
|
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
|
|
39
40
|
_register_converter(process_tp(from_type), to_type_name, fn, cost)
|
|
40
41
|
|
|
41
42
|
|
|
42
|
-
def _register_converter(a: type |
|
|
43
|
+
def _register_converter(a: type | TypeName, b: TypeName, a_b: Callable, cost: int) -> None:
|
|
43
44
|
"""
|
|
44
45
|
Registers a converter from some type to an egglog type, if not already registered.
|
|
45
46
|
|
|
@@ -94,14 +95,17 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
|
|
|
94
95
|
return resolve_literal(tp.to_var(), source)
|
|
95
96
|
|
|
96
97
|
|
|
97
|
-
def process_tp(tp: type | RuntimeClass) ->
|
|
98
|
+
def process_tp(tp: type | RuntimeClass) -> TypeName | type:
|
|
98
99
|
"""
|
|
99
100
|
Process a type before converting it, to add it to the global declerations and resolve to a ref.
|
|
100
101
|
"""
|
|
101
102
|
global CONVERSIONS_DECLS
|
|
102
103
|
if isinstance(tp, RuntimeClass):
|
|
103
104
|
CONVERSIONS_DECLS = Thunk.fn(_combine_decls, CONVERSIONS_DECLS, tp)
|
|
104
|
-
|
|
105
|
+
egg_tp = tp.__egg_tp__
|
|
106
|
+
if egg_tp.args:
|
|
107
|
+
raise TypeError(f"Cannot register a converter for a generic type, got {tp}")
|
|
108
|
+
return TypeName(egg_tp.name)
|
|
105
109
|
return tp
|
|
106
110
|
|
|
107
111
|
|
|
@@ -109,7 +113,7 @@ def _combine_decls(d: Callable[[], Declarations], x: HasDeclerations) -> Declara
|
|
|
109
113
|
return Declarations.create(d(), x)
|
|
110
114
|
|
|
111
115
|
|
|
112
|
-
def min_convertable_tp(a: object, b: object, name: str) ->
|
|
116
|
+
def min_convertable_tp(a: object, b: object, name: str) -> TypeName:
|
|
113
117
|
"""
|
|
114
118
|
Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
|
|
115
119
|
"""
|
|
@@ -117,14 +121,14 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
|
|
|
117
121
|
a_tp = _get_tp(a)
|
|
118
122
|
b_tp = _get_tp(b)
|
|
119
123
|
a_converts_to = {
|
|
120
|
-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to
|
|
124
|
+
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to, name)
|
|
121
125
|
}
|
|
122
126
|
b_converts_to = {
|
|
123
|
-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to
|
|
127
|
+
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to, name)
|
|
124
128
|
}
|
|
125
|
-
if isinstance(a_tp,
|
|
129
|
+
if isinstance(a_tp, str):
|
|
126
130
|
a_converts_to[a_tp] = 0
|
|
127
|
-
if isinstance(b_tp,
|
|
131
|
+
if isinstance(b_tp, str):
|
|
128
132
|
b_converts_to[b_tp] = 0
|
|
129
133
|
common = set(a_converts_to) & set(b_converts_to)
|
|
130
134
|
if not common:
|
|
@@ -143,28 +147,29 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
|
|
|
143
147
|
try:
|
|
144
148
|
tp_just = tp.to_just()
|
|
145
149
|
except NotImplementedError:
|
|
146
|
-
# If this is a var, it has to be a runtime
|
|
150
|
+
# If this is a var, it has to be a runtime expession
|
|
147
151
|
assert isinstance(arg, RuntimeExpr)
|
|
148
152
|
return arg
|
|
149
|
-
|
|
153
|
+
tp_name = TypeName(tp_just.name)
|
|
154
|
+
if arg_type == tp_name:
|
|
150
155
|
# If the type is an egg type, it has to be a runtime expr
|
|
151
156
|
assert isinstance(arg, RuntimeExpr)
|
|
152
157
|
return arg
|
|
153
158
|
# Try all parent types as well, if we are converting from a Python type
|
|
154
159
|
for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
|
|
155
160
|
try:
|
|
156
|
-
fn = CONVERSIONS[(cast(
|
|
161
|
+
fn = CONVERSIONS[(cast(TypeName | type, arg_type_instance), tp_name)][1]
|
|
157
162
|
except KeyError:
|
|
158
163
|
continue
|
|
159
164
|
break
|
|
160
165
|
else:
|
|
161
|
-
raise ConvertError(f"Cannot convert {arg_type} to {
|
|
166
|
+
raise ConvertError(f"Cannot convert {arg_type} to {tp_name}")
|
|
162
167
|
return fn(arg)
|
|
163
168
|
|
|
164
169
|
|
|
165
|
-
def _get_tp(x: object) ->
|
|
170
|
+
def _get_tp(x: object) -> TypeName | type:
|
|
166
171
|
if isinstance(x, RuntimeExpr):
|
|
167
|
-
return x.__egg_typed_expr__.tp
|
|
172
|
+
return TypeName(x.__egg_typed_expr__.tp.name)
|
|
168
173
|
tp = type(x)
|
|
169
174
|
# If this value has a custom metaclass, let's use that as our index instead of the type
|
|
170
175
|
if type(tp) != type:
|
egglog/declarations.py
CHANGED
|
@@ -39,6 +39,7 @@ __all__ = [
|
|
|
39
39
|
"CallableDecl",
|
|
40
40
|
"VarDecl",
|
|
41
41
|
"PyObjectDecl",
|
|
42
|
+
"PartialCallDecl",
|
|
42
43
|
"LitType",
|
|
43
44
|
"LitDecl",
|
|
44
45
|
"CallDecl",
|
|
@@ -46,6 +47,7 @@ __all__ = [
|
|
|
46
47
|
"TypedExprDecl",
|
|
47
48
|
"ClassDecl",
|
|
48
49
|
"RulesetDecl",
|
|
50
|
+
"CombinedRulesetDecl",
|
|
49
51
|
"SaturateDecl",
|
|
50
52
|
"RepeatDecl",
|
|
51
53
|
"SequenceDecl",
|
|
@@ -67,6 +69,10 @@ __all__ = [
|
|
|
67
69
|
"RewriteOrRuleDecl",
|
|
68
70
|
"ActionCommandDecl",
|
|
69
71
|
"CommandDecl",
|
|
72
|
+
"SpecialFunctions",
|
|
73
|
+
"FunctionSignature",
|
|
74
|
+
"DefaultRewriteDecl",
|
|
75
|
+
"InitRef",
|
|
70
76
|
]
|
|
71
77
|
|
|
72
78
|
|
|
@@ -76,7 +82,13 @@ class DelayedDeclerations:
|
|
|
76
82
|
|
|
77
83
|
@property
|
|
78
84
|
def __egg_decls__(self) -> Declarations:
|
|
79
|
-
|
|
85
|
+
try:
|
|
86
|
+
return self.__egg_decls_thunk__()
|
|
87
|
+
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
|
|
88
|
+
# instead raise explicitly
|
|
89
|
+
except AttributeError as err:
|
|
90
|
+
msg = "Error resolving declerations"
|
|
91
|
+
raise RuntimeError(msg) from err
|
|
80
92
|
|
|
81
93
|
|
|
82
94
|
@runtime_checkable
|
|
@@ -88,9 +100,6 @@ class HasDeclerations(Protocol):
|
|
|
88
100
|
DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
|
|
89
101
|
|
|
90
102
|
|
|
91
|
-
# TODO: Make all ClassDecls take deferred type refs, which return new decls when resolving.
|
|
92
|
-
|
|
93
|
-
|
|
94
103
|
def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
|
|
95
104
|
d = []
|
|
96
105
|
for l in declerations_like:
|
|
@@ -110,7 +119,13 @@ class Declarations:
|
|
|
110
119
|
_functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
|
|
111
120
|
_constants: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
112
121
|
_classes: dict[str, ClassDecl] = field(default_factory=dict)
|
|
113
|
-
_rulesets: dict[str, RulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
|
|
122
|
+
_rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
|
|
123
|
+
|
|
124
|
+
@property
|
|
125
|
+
def default_ruleset(self) -> RulesetDecl:
|
|
126
|
+
ruleset = self._rulesets[""]
|
|
127
|
+
assert isinstance(ruleset, RulesetDecl)
|
|
128
|
+
return ruleset
|
|
114
129
|
|
|
115
130
|
@classmethod
|
|
116
131
|
def create(cls, *others: DeclerationsLike) -> Declarations:
|
|
@@ -126,7 +141,7 @@ class Declarations:
|
|
|
126
141
|
|
|
127
142
|
def copy(self) -> Declarations:
|
|
128
143
|
new = Declarations()
|
|
129
|
-
new
|
|
144
|
+
self.update_other(new)
|
|
130
145
|
return new
|
|
131
146
|
|
|
132
147
|
def update(self, *others: DeclerationsLike) -> None:
|
|
@@ -153,9 +168,13 @@ class Declarations:
|
|
|
153
168
|
other._functions |= self._functions
|
|
154
169
|
other._classes |= self._classes
|
|
155
170
|
other._constants |= self._constants
|
|
171
|
+
# Must combine rulesets bc the empty ruleset might be different, bc DefaultRewriteDecl
|
|
172
|
+
# is added to functions.
|
|
173
|
+
combined_default_rules: set[RewriteOrRuleDecl] = {*self.default_ruleset.rules, *other.default_ruleset.rules}
|
|
156
174
|
other._rulesets |= self._rulesets
|
|
175
|
+
other._rulesets[""] = RulesetDecl(list(combined_default_rules))
|
|
157
176
|
|
|
158
|
-
def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
|
|
177
|
+
def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
|
|
159
178
|
match ref:
|
|
160
179
|
case FunctionRef(name):
|
|
161
180
|
return self._functions[name]
|
|
@@ -169,8 +188,29 @@ class Declarations:
|
|
|
169
188
|
return self._classes[class_name].class_methods[name]
|
|
170
189
|
case PropertyRef(class_name, property_name):
|
|
171
190
|
return self._classes[class_name].properties[property_name]
|
|
191
|
+
case InitRef(class_name):
|
|
192
|
+
init_fn = self._classes[class_name].init
|
|
193
|
+
assert init_fn
|
|
194
|
+
return init_fn
|
|
172
195
|
assert_never(ref)
|
|
173
196
|
|
|
197
|
+
def set_function_decl(
|
|
198
|
+
self, ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef, decl: FunctionDecl
|
|
199
|
+
) -> None:
|
|
200
|
+
match ref:
|
|
201
|
+
case FunctionRef(name):
|
|
202
|
+
self._functions[name] = decl
|
|
203
|
+
case MethodRef(class_name, method_name):
|
|
204
|
+
self._classes[class_name].methods[method_name] = decl
|
|
205
|
+
case ClassMethodRef(class_name, name):
|
|
206
|
+
self._classes[class_name].class_methods[name] = decl
|
|
207
|
+
case PropertyRef(class_name, property_name):
|
|
208
|
+
self._classes[class_name].properties[property_name] = decl
|
|
209
|
+
case InitRef(class_name):
|
|
210
|
+
self._classes[class_name].init = decl
|
|
211
|
+
case _:
|
|
212
|
+
assert_never(ref)
|
|
213
|
+
|
|
174
214
|
def has_method(self, class_name: str, method_name: str) -> bool | None:
|
|
175
215
|
"""
|
|
176
216
|
Returns whether the given class has the given method, or None if we cant find the class.
|
|
@@ -182,12 +222,20 @@ class Declarations:
|
|
|
182
222
|
def get_class_decl(self, name: str) -> ClassDecl:
|
|
183
223
|
return self._classes[name]
|
|
184
224
|
|
|
225
|
+
def get_paramaterized_class(self, name: str) -> TypeRefWithVars:
|
|
226
|
+
"""
|
|
227
|
+
Returns a class reference with type parameters, if the class is paramaterized.
|
|
228
|
+
"""
|
|
229
|
+
type_vars = self._classes[name].type_vars
|
|
230
|
+
return TypeRefWithVars(name, tuple(map(ClassTypeVarRef, type_vars)))
|
|
231
|
+
|
|
185
232
|
|
|
186
233
|
@dataclass
|
|
187
234
|
class ClassDecl:
|
|
188
235
|
egg_name: str | None = None
|
|
189
236
|
type_vars: tuple[str, ...] = ()
|
|
190
237
|
builtin: bool = False
|
|
238
|
+
init: FunctionDecl | None = None
|
|
191
239
|
class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
|
|
192
240
|
# These have to be seperate from class_methods so that printing them can be done easily
|
|
193
241
|
class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
|
|
@@ -196,7 +244,7 @@ class ClassDecl:
|
|
|
196
244
|
preserved_methods: dict[str, Callable] = field(default_factory=dict)
|
|
197
245
|
|
|
198
246
|
|
|
199
|
-
@dataclass
|
|
247
|
+
@dataclass(frozen=True)
|
|
200
248
|
class RulesetDecl:
|
|
201
249
|
rules: list[RewriteOrRuleDecl]
|
|
202
250
|
|
|
@@ -206,6 +254,11 @@ class RulesetDecl:
|
|
|
206
254
|
return hash((type(self), tuple(self.rules)))
|
|
207
255
|
|
|
208
256
|
|
|
257
|
+
@dataclass(frozen=True)
|
|
258
|
+
class CombinedRulesetDecl:
|
|
259
|
+
rulesets: tuple[str, ...]
|
|
260
|
+
|
|
261
|
+
|
|
209
262
|
# Have two different types of type refs, one that can include vars recursively and one that cannot.
|
|
210
263
|
# We only use the one with vars for classmethods and methods, and the other one for egg references as
|
|
211
264
|
# well as runtime values.
|
|
@@ -287,6 +340,11 @@ class ClassMethodRef:
|
|
|
287
340
|
method_name: str
|
|
288
341
|
|
|
289
342
|
|
|
343
|
+
@dataclass(frozen=True)
|
|
344
|
+
class InitRef:
|
|
345
|
+
class_name: str
|
|
346
|
+
|
|
347
|
+
|
|
290
348
|
@dataclass(frozen=True)
|
|
291
349
|
class ClassVariableRef:
|
|
292
350
|
class_name: str
|
|
@@ -299,7 +357,9 @@ class PropertyRef:
|
|
|
299
357
|
property_name: str
|
|
300
358
|
|
|
301
359
|
|
|
302
|
-
CallableRef: TypeAlias =
|
|
360
|
+
CallableRef: TypeAlias = (
|
|
361
|
+
FunctionRef | ConstantRef | MethodRef | ClassMethodRef | InitRef | ClassVariableRef | PropertyRef
|
|
362
|
+
)
|
|
303
363
|
|
|
304
364
|
|
|
305
365
|
##
|
|
@@ -316,10 +376,12 @@ class RelationDecl:
|
|
|
316
376
|
|
|
317
377
|
def to_function_decl(self) -> FunctionDecl:
|
|
318
378
|
return FunctionDecl(
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
379
|
+
FunctionSignature(
|
|
380
|
+
arg_types=tuple(a.to_var() for a in self.arg_types),
|
|
381
|
+
arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
|
|
382
|
+
arg_defaults=self.arg_defaults,
|
|
383
|
+
return_type=TypeRefWithVars("Unit"),
|
|
384
|
+
),
|
|
323
385
|
egg_name=self.egg_name,
|
|
324
386
|
default=LitDecl(None),
|
|
325
387
|
)
|
|
@@ -336,25 +398,40 @@ class ConstantDecl:
|
|
|
336
398
|
|
|
337
399
|
def to_function_decl(self) -> FunctionDecl:
|
|
338
400
|
return FunctionDecl(
|
|
339
|
-
|
|
340
|
-
arg_names=(),
|
|
341
|
-
arg_defaults=(),
|
|
342
|
-
return_type=self.type_ref.to_var(),
|
|
401
|
+
FunctionSignature(return_type=self.type_ref.to_var()),
|
|
343
402
|
egg_name=self.egg_name,
|
|
344
403
|
)
|
|
345
404
|
|
|
346
405
|
|
|
406
|
+
# special cases for partial function creation and application, which cannot use the normal python rules
|
|
407
|
+
SpecialFunctions: TypeAlias = Literal["fn-partial", "fn-app"]
|
|
408
|
+
|
|
409
|
+
|
|
347
410
|
@dataclass(frozen=True)
|
|
348
|
-
class
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
arg_names: tuple[str, ...]
|
|
411
|
+
class FunctionSignature:
|
|
412
|
+
arg_types: tuple[TypeOrVarRef, ...] = ()
|
|
413
|
+
arg_names: tuple[str, ...] = ()
|
|
352
414
|
# List of defaults. None for any arg which doesn't have one.
|
|
353
|
-
arg_defaults: tuple[ExprDecl | None, ...]
|
|
415
|
+
arg_defaults: tuple[ExprDecl | None, ...] = ()
|
|
354
416
|
# If None, then the first arg is mutated and returned
|
|
355
|
-
return_type: TypeOrVarRef | None
|
|
417
|
+
return_type: TypeOrVarRef | None = None
|
|
356
418
|
var_arg_type: TypeOrVarRef | None = None
|
|
357
419
|
|
|
420
|
+
@property
|
|
421
|
+
def semantic_return_type(self) -> TypeOrVarRef:
|
|
422
|
+
"""
|
|
423
|
+
The type that is returned by the function, which wil be in the first arg if it mutates it.
|
|
424
|
+
"""
|
|
425
|
+
return self.return_type or self.arg_types[0]
|
|
426
|
+
|
|
427
|
+
@property
|
|
428
|
+
def mutates(self) -> bool:
|
|
429
|
+
return self.return_type is None
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
@dataclass(frozen=True)
|
|
433
|
+
class FunctionDecl:
|
|
434
|
+
signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
|
|
358
435
|
# Egg params
|
|
359
436
|
builtin: bool = False
|
|
360
437
|
egg_name: str | None = None
|
|
@@ -367,17 +444,6 @@ class FunctionDecl:
|
|
|
367
444
|
def to_function_decl(self) -> FunctionDecl:
|
|
368
445
|
return self
|
|
369
446
|
|
|
370
|
-
@property
|
|
371
|
-
def semantic_return_type(self) -> TypeOrVarRef:
|
|
372
|
-
"""
|
|
373
|
-
The type that is returned by the function, which wil be in the first arg if it mutates it.
|
|
374
|
-
"""
|
|
375
|
-
return self.return_type or self.arg_types[0]
|
|
376
|
-
|
|
377
|
-
@property
|
|
378
|
-
def mutates(self) -> bool:
|
|
379
|
-
return self.return_type is None
|
|
380
|
-
|
|
381
447
|
|
|
382
448
|
CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
|
|
383
449
|
|
|
@@ -445,7 +511,7 @@ class CallDecl:
|
|
|
445
511
|
bound_tp_params: tuple[JustTypeRef, ...] | None = None
|
|
446
512
|
|
|
447
513
|
def __post_init__(self) -> None:
|
|
448
|
-
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
|
|
514
|
+
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
|
|
449
515
|
msg = "Cannot bind type parameters to a non-class method callable."
|
|
450
516
|
raise ValueError(msg)
|
|
451
517
|
|
|
@@ -463,7 +529,20 @@ class CallDecl:
|
|
|
463
529
|
return hash(self) == hash(other)
|
|
464
530
|
|
|
465
531
|
|
|
466
|
-
|
|
532
|
+
@dataclass(frozen=True)
|
|
533
|
+
class PartialCallDecl:
|
|
534
|
+
"""
|
|
535
|
+
A partially applied function aka a function sort.
|
|
536
|
+
|
|
537
|
+
Note it does not need to have any args, in which case it's just a function pointer.
|
|
538
|
+
|
|
539
|
+
Seperated from the call decl so it's clear it is translated to a `unstable-fn` call.
|
|
540
|
+
"""
|
|
541
|
+
|
|
542
|
+
call: CallDecl
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
|
|
467
546
|
|
|
468
547
|
|
|
469
548
|
@dataclass(frozen=True)
|
|
@@ -603,7 +682,13 @@ class RuleDecl:
|
|
|
603
682
|
name: str | None
|
|
604
683
|
|
|
605
684
|
|
|
606
|
-
|
|
685
|
+
@dataclass(frozen=True)
|
|
686
|
+
class DefaultRewriteDecl:
|
|
687
|
+
ref: CallableRef
|
|
688
|
+
expr: ExprDecl
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl
|
|
607
692
|
|
|
608
693
|
|
|
609
694
|
@dataclass(frozen=True)
|