egglog 7.0.0__cp312-none-win_amd64.whl → 7.2.0__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Binary file
egglog/bindings.pyi CHANGED
@@ -494,6 +494,12 @@ class Relation:
494
494
  class PrintOverallStatistics:
495
495
  def __init__(self) -> None: ...
496
496
 
497
+ @final
498
+ class UnstableCombinedRuleset:
499
+ name: str
500
+ rulesets: list[str]
501
+ def __init__(self, name: str, rulesets: list[str]) -> None: ...
502
+
497
503
  _Command: TypeAlias = (
498
504
  SetOption
499
505
  | Datatype
@@ -521,6 +527,7 @@ _Command: TypeAlias = (
521
527
  | CheckProof
522
528
  | Relation
523
529
  | PrintOverallStatistics
530
+ | UnstableCombinedRuleset
524
531
  )
525
532
 
526
533
  def termdag_term_to_expr(termdag: TermDag, term: _Term) -> _Expr: ...
egglog/builtins.py CHANGED
@@ -5,10 +5,14 @@ Builtin sorts and function to egg.
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
- from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union
8
+ from functools import partial
9
+ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, overload
10
+
11
+ from typing_extensions import TypeVarTuple, Unpack
9
12
 
10
13
  from .conversion import converter
11
14
  from .egraph import Expr, Unit, function, method
15
+ from .runtime import RuntimeFunction
12
16
 
13
17
  if TYPE_CHECKING:
14
18
  from collections.abc import Callable
@@ -31,6 +35,7 @@ __all__ = [
31
35
  "py_eval",
32
36
  "py_exec",
33
37
  "py_eval_fn",
38
+ "UnstableFn",
34
39
  ]
35
40
 
36
41
 
@@ -461,3 +466,38 @@ def py_exec(code: StringLike, globals: object = PyObject.dict(), locals: object
461
466
  """
462
467
  Copies the locals, execs the Python code, and returns the locals with any updates.
463
468
  """
469
+
470
+
471
+ TS = TypeVarTuple("TS")
472
+
473
+ T1 = TypeVar("T1")
474
+ T2 = TypeVar("T2")
475
+ T3 = TypeVar("T3")
476
+
477
+
478
+ class UnstableFn(Expr, Generic[T, Unpack[TS]], builtin=True):
479
+ @overload
480
+ def __init__(self, f: Callable[[Unpack[TS]], T]) -> None: ...
481
+
482
+ @overload
483
+ def __init__(self, f: Callable[[T1, Unpack[TS]], T], _a: T1, /) -> None: ...
484
+
485
+ @overload
486
+ def __init__(self, f: Callable[[T1, T2, Unpack[TS]], T], _a: T1, _b: T2, /) -> None: ...
487
+
488
+ # Removing due to bug in MyPy
489
+ # https://github.com/python/mypy/issues/17212
490
+ # @overload
491
+ # def __init__(self, f: Callable[[T1, T2, T3, Unpack[TS]], T], _a: T1, _b: T2, _c: T3, /) -> None: ...
492
+
493
+ # etc, for partial application
494
+
495
+ @method(egg_fn="unstable-fn")
496
+ def __init__(self, f, *partial) -> None: ...
497
+
498
+ @method(egg_fn="unstable-app")
499
+ def __call__(self, *args: Unpack[TS]) -> T: ...
500
+
501
+
502
+ converter(RuntimeFunction, UnstableFn, UnstableFn)
503
+ converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
egglog/conversion.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import TYPE_CHECKING, TypeVar, cast
4
+ from typing import TYPE_CHECKING, NewType, TypeVar, cast
5
5
 
6
6
  from .declarations import *
7
7
  from .pretty import *
@@ -16,7 +16,8 @@ if TYPE_CHECKING:
16
16
 
17
17
  __all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"]
18
18
  # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
19
- CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
19
+ TypeName = NewType("TypeName", str)
20
+ CONVERSIONS: dict[tuple[type | TypeName, TypeName], tuple[int, Callable]] = {}
20
21
  # Global declerations to store all convertable types so we can query if they have certain methods or not
21
22
  # Defer it as a thunk so we can register conversions without triggering type signature loading
22
23
  CONVERSIONS_DECLS: Callable[[], Declarations] = Thunk.value(Declarations())
@@ -34,12 +35,12 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
34
35
  Register a converter from some type to an egglog type.
35
36
  """
36
37
  to_type_name = process_tp(to_type)
37
- if not isinstance(to_type_name, JustTypeRef):
38
+ if not isinstance(to_type_name, str):
38
39
  raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
39
40
  _register_converter(process_tp(from_type), to_type_name, fn, cost)
40
41
 
41
42
 
42
- def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
43
+ def _register_converter(a: type | TypeName, b: TypeName, a_b: Callable, cost: int) -> None:
43
44
  """
44
45
  Registers a converter from some type to an egglog type, if not already registered.
45
46
 
@@ -94,14 +95,17 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
94
95
  return resolve_literal(tp.to_var(), source)
95
96
 
96
97
 
97
- def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
98
+ def process_tp(tp: type | RuntimeClass) -> TypeName | type:
98
99
  """
99
100
  Process a type before converting it, to add it to the global declerations and resolve to a ref.
100
101
  """
101
102
  global CONVERSIONS_DECLS
102
103
  if isinstance(tp, RuntimeClass):
103
104
  CONVERSIONS_DECLS = Thunk.fn(_combine_decls, CONVERSIONS_DECLS, tp)
104
- return tp.__egg_tp__.to_just()
105
+ egg_tp = tp.__egg_tp__
106
+ if egg_tp.args:
107
+ raise TypeError(f"Cannot register a converter for a generic type, got {tp}")
108
+ return TypeName(egg_tp.name)
105
109
  return tp
106
110
 
107
111
 
@@ -109,7 +113,7 @@ def _combine_decls(d: Callable[[], Declarations], x: HasDeclerations) -> Declara
109
113
  return Declarations.create(d(), x)
110
114
 
111
115
 
112
- def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
116
+ def min_convertable_tp(a: object, b: object, name: str) -> TypeName:
113
117
  """
114
118
  Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
115
119
  """
@@ -117,14 +121,14 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
117
121
  a_tp = _get_tp(a)
118
122
  b_tp = _get_tp(b)
119
123
  a_converts_to = {
120
- to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
124
+ to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to, name)
121
125
  }
122
126
  b_converts_to = {
123
- to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
127
+ to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to, name)
124
128
  }
125
- if isinstance(a_tp, JustTypeRef):
129
+ if isinstance(a_tp, str):
126
130
  a_converts_to[a_tp] = 0
127
- if isinstance(b_tp, JustTypeRef):
131
+ if isinstance(b_tp, str):
128
132
  b_converts_to[b_tp] = 0
129
133
  common = set(a_converts_to) & set(b_converts_to)
130
134
  if not common:
@@ -143,28 +147,29 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
143
147
  try:
144
148
  tp_just = tp.to_just()
145
149
  except NotImplementedError:
146
- # If this is a var, it has to be a runtime exprssions
150
+ # If this is a var, it has to be a runtime expession
147
151
  assert isinstance(arg, RuntimeExpr)
148
152
  return arg
149
- if arg_type == tp_just:
153
+ tp_name = TypeName(tp_just.name)
154
+ if arg_type == tp_name:
150
155
  # If the type is an egg type, it has to be a runtime expr
151
156
  assert isinstance(arg, RuntimeExpr)
152
157
  return arg
153
158
  # Try all parent types as well, if we are converting from a Python type
154
159
  for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
155
160
  try:
156
- fn = CONVERSIONS[(cast(JustTypeRef | type, arg_type_instance), tp_just)][1]
161
+ fn = CONVERSIONS[(cast(TypeName | type, arg_type_instance), tp_name)][1]
157
162
  except KeyError:
158
163
  continue
159
164
  break
160
165
  else:
161
- raise ConvertError(f"Cannot convert {arg_type} to {tp_just}")
166
+ raise ConvertError(f"Cannot convert {arg_type} to {tp_name}")
162
167
  return fn(arg)
163
168
 
164
169
 
165
- def _get_tp(x: object) -> JustTypeRef | type:
170
+ def _get_tp(x: object) -> TypeName | type:
166
171
  if isinstance(x, RuntimeExpr):
167
- return x.__egg_typed_expr__.tp
172
+ return TypeName(x.__egg_typed_expr__.tp.name)
168
173
  tp = type(x)
169
174
  # If this value has a custom metaclass, let's use that as our index instead of the type
170
175
  if type(tp) != type:
egglog/declarations.py CHANGED
@@ -39,6 +39,7 @@ __all__ = [
39
39
  "CallableDecl",
40
40
  "VarDecl",
41
41
  "PyObjectDecl",
42
+ "PartialCallDecl",
42
43
  "LitType",
43
44
  "LitDecl",
44
45
  "CallDecl",
@@ -46,6 +47,7 @@ __all__ = [
46
47
  "TypedExprDecl",
47
48
  "ClassDecl",
48
49
  "RulesetDecl",
50
+ "CombinedRulesetDecl",
49
51
  "SaturateDecl",
50
52
  "RepeatDecl",
51
53
  "SequenceDecl",
@@ -67,6 +69,10 @@ __all__ = [
67
69
  "RewriteOrRuleDecl",
68
70
  "ActionCommandDecl",
69
71
  "CommandDecl",
72
+ "SpecialFunctions",
73
+ "FunctionSignature",
74
+ "DefaultRewriteDecl",
75
+ "InitRef",
70
76
  ]
71
77
 
72
78
 
@@ -76,7 +82,13 @@ class DelayedDeclerations:
76
82
 
77
83
  @property
78
84
  def __egg_decls__(self) -> Declarations:
79
- return self.__egg_decls_thunk__()
85
+ try:
86
+ return self.__egg_decls_thunk__()
87
+ # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
88
+ # instead raise explicitly
89
+ except AttributeError as err:
90
+ msg = "Error resolving declerations"
91
+ raise RuntimeError(msg) from err
80
92
 
81
93
 
82
94
  @runtime_checkable
@@ -88,9 +100,6 @@ class HasDeclerations(Protocol):
88
100
  DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
89
101
 
90
102
 
91
- # TODO: Make all ClassDecls take deferred type refs, which return new decls when resolving.
92
-
93
-
94
103
  def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
95
104
  d = []
96
105
  for l in declerations_like:
@@ -110,7 +119,13 @@ class Declarations:
110
119
  _functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
111
120
  _constants: dict[str, ConstantDecl] = field(default_factory=dict)
112
121
  _classes: dict[str, ClassDecl] = field(default_factory=dict)
113
- _rulesets: dict[str, RulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
122
+ _rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
123
+
124
+ @property
125
+ def default_ruleset(self) -> RulesetDecl:
126
+ ruleset = self._rulesets[""]
127
+ assert isinstance(ruleset, RulesetDecl)
128
+ return ruleset
114
129
 
115
130
  @classmethod
116
131
  def create(cls, *others: DeclerationsLike) -> Declarations:
@@ -126,7 +141,7 @@ class Declarations:
126
141
 
127
142
  def copy(self) -> Declarations:
128
143
  new = Declarations()
129
- new |= self
144
+ self.update_other(new)
130
145
  return new
131
146
 
132
147
  def update(self, *others: DeclerationsLike) -> None:
@@ -153,9 +168,13 @@ class Declarations:
153
168
  other._functions |= self._functions
154
169
  other._classes |= self._classes
155
170
  other._constants |= self._constants
171
+ # Must combine rulesets bc the empty ruleset might be different, bc DefaultRewriteDecl
172
+ # is added to functions.
173
+ combined_default_rules: set[RewriteOrRuleDecl] = {*self.default_ruleset.rules, *other.default_ruleset.rules}
156
174
  other._rulesets |= self._rulesets
175
+ other._rulesets[""] = RulesetDecl(list(combined_default_rules))
157
176
 
158
- def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
177
+ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
159
178
  match ref:
160
179
  case FunctionRef(name):
161
180
  return self._functions[name]
@@ -169,8 +188,29 @@ class Declarations:
169
188
  return self._classes[class_name].class_methods[name]
170
189
  case PropertyRef(class_name, property_name):
171
190
  return self._classes[class_name].properties[property_name]
191
+ case InitRef(class_name):
192
+ init_fn = self._classes[class_name].init
193
+ assert init_fn
194
+ return init_fn
172
195
  assert_never(ref)
173
196
 
197
+ def set_function_decl(
198
+ self, ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef, decl: FunctionDecl
199
+ ) -> None:
200
+ match ref:
201
+ case FunctionRef(name):
202
+ self._functions[name] = decl
203
+ case MethodRef(class_name, method_name):
204
+ self._classes[class_name].methods[method_name] = decl
205
+ case ClassMethodRef(class_name, name):
206
+ self._classes[class_name].class_methods[name] = decl
207
+ case PropertyRef(class_name, property_name):
208
+ self._classes[class_name].properties[property_name] = decl
209
+ case InitRef(class_name):
210
+ self._classes[class_name].init = decl
211
+ case _:
212
+ assert_never(ref)
213
+
174
214
  def has_method(self, class_name: str, method_name: str) -> bool | None:
175
215
  """
176
216
  Returns whether the given class has the given method, or None if we cant find the class.
@@ -182,12 +222,20 @@ class Declarations:
182
222
  def get_class_decl(self, name: str) -> ClassDecl:
183
223
  return self._classes[name]
184
224
 
225
+ def get_paramaterized_class(self, name: str) -> TypeRefWithVars:
226
+ """
227
+ Returns a class reference with type parameters, if the class is paramaterized.
228
+ """
229
+ type_vars = self._classes[name].type_vars
230
+ return TypeRefWithVars(name, tuple(map(ClassTypeVarRef, type_vars)))
231
+
185
232
 
186
233
  @dataclass
187
234
  class ClassDecl:
188
235
  egg_name: str | None = None
189
236
  type_vars: tuple[str, ...] = ()
190
237
  builtin: bool = False
238
+ init: FunctionDecl | None = None
191
239
  class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
192
240
  # These have to be seperate from class_methods so that printing them can be done easily
193
241
  class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
@@ -196,7 +244,7 @@ class ClassDecl:
196
244
  preserved_methods: dict[str, Callable] = field(default_factory=dict)
197
245
 
198
246
 
199
- @dataclass
247
+ @dataclass(frozen=True)
200
248
  class RulesetDecl:
201
249
  rules: list[RewriteOrRuleDecl]
202
250
 
@@ -206,6 +254,11 @@ class RulesetDecl:
206
254
  return hash((type(self), tuple(self.rules)))
207
255
 
208
256
 
257
+ @dataclass(frozen=True)
258
+ class CombinedRulesetDecl:
259
+ rulesets: tuple[str, ...]
260
+
261
+
209
262
  # Have two different types of type refs, one that can include vars recursively and one that cannot.
210
263
  # We only use the one with vars for classmethods and methods, and the other one for egg references as
211
264
  # well as runtime values.
@@ -287,6 +340,11 @@ class ClassMethodRef:
287
340
  method_name: str
288
341
 
289
342
 
343
+ @dataclass(frozen=True)
344
+ class InitRef:
345
+ class_name: str
346
+
347
+
290
348
  @dataclass(frozen=True)
291
349
  class ClassVariableRef:
292
350
  class_name: str
@@ -299,7 +357,9 @@ class PropertyRef:
299
357
  property_name: str
300
358
 
301
359
 
302
- CallableRef: TypeAlias = FunctionRef | ConstantRef | MethodRef | ClassMethodRef | ClassVariableRef | PropertyRef
360
+ CallableRef: TypeAlias = (
361
+ FunctionRef | ConstantRef | MethodRef | ClassMethodRef | InitRef | ClassVariableRef | PropertyRef
362
+ )
303
363
 
304
364
 
305
365
  ##
@@ -316,10 +376,12 @@ class RelationDecl:
316
376
 
317
377
  def to_function_decl(self) -> FunctionDecl:
318
378
  return FunctionDecl(
319
- arg_types=tuple(a.to_var() for a in self.arg_types),
320
- arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
321
- arg_defaults=self.arg_defaults,
322
- return_type=TypeRefWithVars("Unit"),
379
+ FunctionSignature(
380
+ arg_types=tuple(a.to_var() for a in self.arg_types),
381
+ arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
382
+ arg_defaults=self.arg_defaults,
383
+ return_type=TypeRefWithVars("Unit"),
384
+ ),
323
385
  egg_name=self.egg_name,
324
386
  default=LitDecl(None),
325
387
  )
@@ -336,25 +398,40 @@ class ConstantDecl:
336
398
 
337
399
  def to_function_decl(self) -> FunctionDecl:
338
400
  return FunctionDecl(
339
- arg_types=(),
340
- arg_names=(),
341
- arg_defaults=(),
342
- return_type=self.type_ref.to_var(),
401
+ FunctionSignature(return_type=self.type_ref.to_var()),
343
402
  egg_name=self.egg_name,
344
403
  )
345
404
 
346
405
 
406
+ # special cases for partial function creation and application, which cannot use the normal python rules
407
+ SpecialFunctions: TypeAlias = Literal["fn-partial", "fn-app"]
408
+
409
+
347
410
  @dataclass(frozen=True)
348
- class FunctionDecl:
349
- # All args are delayed except for relations converted to function decls
350
- arg_types: tuple[TypeOrVarRef, ...]
351
- arg_names: tuple[str, ...]
411
+ class FunctionSignature:
412
+ arg_types: tuple[TypeOrVarRef, ...] = ()
413
+ arg_names: tuple[str, ...] = ()
352
414
  # List of defaults. None for any arg which doesn't have one.
353
- arg_defaults: tuple[ExprDecl | None, ...]
415
+ arg_defaults: tuple[ExprDecl | None, ...] = ()
354
416
  # If None, then the first arg is mutated and returned
355
- return_type: TypeOrVarRef | None
417
+ return_type: TypeOrVarRef | None = None
356
418
  var_arg_type: TypeOrVarRef | None = None
357
419
 
420
+ @property
421
+ def semantic_return_type(self) -> TypeOrVarRef:
422
+ """
423
+ The type that is returned by the function, which wil be in the first arg if it mutates it.
424
+ """
425
+ return self.return_type or self.arg_types[0]
426
+
427
+ @property
428
+ def mutates(self) -> bool:
429
+ return self.return_type is None
430
+
431
+
432
+ @dataclass(frozen=True)
433
+ class FunctionDecl:
434
+ signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
358
435
  # Egg params
359
436
  builtin: bool = False
360
437
  egg_name: str | None = None
@@ -367,17 +444,6 @@ class FunctionDecl:
367
444
  def to_function_decl(self) -> FunctionDecl:
368
445
  return self
369
446
 
370
- @property
371
- def semantic_return_type(self) -> TypeOrVarRef:
372
- """
373
- The type that is returned by the function, which wil be in the first arg if it mutates it.
374
- """
375
- return self.return_type or self.arg_types[0]
376
-
377
- @property
378
- def mutates(self) -> bool:
379
- return self.return_type is None
380
-
381
447
 
382
448
  CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
383
449
 
@@ -445,7 +511,7 @@ class CallDecl:
445
511
  bound_tp_params: tuple[JustTypeRef, ...] | None = None
446
512
 
447
513
  def __post_init__(self) -> None:
448
- if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
514
+ if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
449
515
  msg = "Cannot bind type parameters to a non-class method callable."
450
516
  raise ValueError(msg)
451
517
 
@@ -463,7 +529,20 @@ class CallDecl:
463
529
  return hash(self) == hash(other)
464
530
 
465
531
 
466
- ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl
532
+ @dataclass(frozen=True)
533
+ class PartialCallDecl:
534
+ """
535
+ A partially applied function aka a function sort.
536
+
537
+ Note it does not need to have any args, in which case it's just a function pointer.
538
+
539
+ Seperated from the call decl so it's clear it is translated to a `unstable-fn` call.
540
+ """
541
+
542
+ call: CallDecl
543
+
544
+
545
+ ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
467
546
 
468
547
 
469
548
  @dataclass(frozen=True)
@@ -603,7 +682,13 @@ class RuleDecl:
603
682
  name: str | None
604
683
 
605
684
 
606
- RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl
685
+ @dataclass(frozen=True)
686
+ class DefaultRewriteDecl:
687
+ ref: CallableRef
688
+ expr: ExprDecl
689
+
690
+
691
+ RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl
607
692
 
608
693
 
609
694
  @dataclass(frozen=True)