egglog 7.1.0__cp310-none-win_amd64.whl → 8.0.0__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic; see the package registry's release page for more details.

egglog/declarations.py CHANGED
@@ -13,10 +13,11 @@ from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_c
13
13
  from typing_extensions import Self, assert_never
14
14
 
15
15
  if TYPE_CHECKING:
16
- from collections.abc import Callable, Iterable
16
+ from collections.abc import Callable, Iterable, Mapping
17
17
 
18
18
 
19
19
  __all__ = [
20
+ "replace_typed_expr",
20
21
  "Declarations",
21
22
  "DeclerationsLike",
22
23
  "DelayedDeclerations",
@@ -29,6 +30,7 @@ __all__ = [
29
30
  "MethodRef",
30
31
  "ClassMethodRef",
31
32
  "FunctionRef",
33
+ "UnnamedFunctionRef",
32
34
  "ConstantRef",
33
35
  "ClassVariableRef",
34
36
  "PropertyRef",
@@ -71,6 +73,9 @@ __all__ = [
71
73
  "CommandDecl",
72
74
  "SpecialFunctions",
73
75
  "FunctionSignature",
76
+ "DefaultRewriteDecl",
77
+ "InitRef",
78
+ "HasDeclerations",
74
79
  ]
75
80
 
76
81
 
@@ -80,7 +85,14 @@ class DelayedDeclerations:
80
85
 
81
86
  @property
82
87
  def __egg_decls__(self) -> Declarations:
83
- return self.__egg_decls_thunk__()
88
+ thunk = self.__egg_decls_thunk__
89
+ try:
90
+ return thunk()
91
+ # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
92
+ # instead raise explicitly
93
+ except AttributeError as err:
94
+ msg = f"Cannot resolve declerations for {self}"
95
+ raise RuntimeError(msg) from err
84
96
 
85
97
 
86
98
  @runtime_checkable
@@ -108,11 +120,18 @@ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[D
108
120
 
109
121
  @dataclass
110
122
  class Declarations:
123
+ _unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set)
111
124
  _functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
112
125
  _constants: dict[str, ConstantDecl] = field(default_factory=dict)
113
126
  _classes: dict[str, ClassDecl] = field(default_factory=dict)
114
127
  _rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
115
128
 
129
+ @property
130
+ def default_ruleset(self) -> RulesetDecl:
131
+ ruleset = self._rulesets[""]
132
+ assert isinstance(ruleset, RulesetDecl)
133
+ return ruleset
134
+
116
135
  @classmethod
117
136
  def create(cls, *others: DeclerationsLike) -> Declarations:
118
137
  others = upcast_declerations(others)
@@ -127,7 +146,7 @@ class Declarations:
127
146
 
128
147
  def copy(self) -> Declarations:
129
148
  new = Declarations()
130
- new |= self
149
+ self.update_other(new)
131
150
  return new
132
151
 
133
152
  def update(self, *others: DeclerationsLike) -> None:
@@ -154,9 +173,13 @@ class Declarations:
154
173
  other._functions |= self._functions
155
174
  other._classes |= self._classes
156
175
  other._constants |= self._constants
176
+ # Must combine rulesets bc the empty ruleset might be different, bc DefaultRewriteDecl
177
+ # is added to functions.
178
+ combined_default_rules: set[RewriteOrRuleDecl] = {*self.default_ruleset.rules, *other.default_ruleset.rules}
157
179
  other._rulesets |= self._rulesets
180
+ other._rulesets[""] = RulesetDecl(list(combined_default_rules))
158
181
 
159
- def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
182
+ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
160
183
  match ref:
161
184
  case FunctionRef(name):
162
185
  return self._functions[name]
@@ -170,8 +193,31 @@ class Declarations:
170
193
  return self._classes[class_name].class_methods[name]
171
194
  case PropertyRef(class_name, property_name):
172
195
  return self._classes[class_name].properties[property_name]
196
+ case InitRef(class_name):
197
+ init_fn = self._classes[class_name].init
198
+ assert init_fn
199
+ return init_fn
200
+ case UnnamedFunctionRef():
201
+ return ref.to_function_decl()
173
202
  assert_never(ref)
174
203
 
204
+ def set_function_decl(
205
+ self, ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef, decl: FunctionDecl
206
+ ) -> None:
207
+ match ref:
208
+ case FunctionRef(name):
209
+ self._functions[name] = decl
210
+ case MethodRef(class_name, method_name):
211
+ self._classes[class_name].methods[method_name] = decl
212
+ case ClassMethodRef(class_name, name):
213
+ self._classes[class_name].class_methods[name] = decl
214
+ case PropertyRef(class_name, property_name):
215
+ self._classes[class_name].properties[property_name] = decl
216
+ case InitRef(class_name):
217
+ self._classes[class_name].init = decl
218
+ case _:
219
+ assert_never(ref)
220
+
175
221
  def has_method(self, class_name: str, method_name: str) -> bool | None:
176
222
  """
177
223
  Returns whether the given class has the given method, or None if we cant find the class.
@@ -183,12 +229,20 @@ class Declarations:
183
229
  def get_class_decl(self, name: str) -> ClassDecl:
184
230
  return self._classes[name]
185
231
 
232
+ def get_paramaterized_class(self, name: str) -> TypeRefWithVars:
233
+ """
234
+ Returns a class reference with type parameters, if the class is paramaterized.
235
+ """
236
+ type_vars = self._classes[name].type_vars
237
+ return TypeRefWithVars(name, tuple(map(ClassTypeVarRef, type_vars)))
238
+
186
239
 
187
240
  @dataclass
188
241
  class ClassDecl:
189
242
  egg_name: str | None = None
190
243
  type_vars: tuple[str, ...] = ()
191
244
  builtin: bool = False
245
+ init: FunctionDecl | None = None
192
246
  class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
193
247
  # These have to be seperate from class_methods so that printing them can be done easily
194
248
  class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
@@ -271,6 +325,37 @@ TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
271
325
  ##
272
326
 
273
327
 
328
+ @dataclass(frozen=True)
329
+ class UnnamedFunctionRef:
330
+ """
331
+ A reference to a function that doesn't have a name, but does have a body.
332
+ """
333
+
334
+ # tuple of var arg names and their types
335
+ args: tuple[TypedExprDecl, ...]
336
+ res: TypedExprDecl
337
+
338
+ def to_function_decl(self) -> FunctionDecl:
339
+ arg_types = []
340
+ arg_names = []
341
+ for a in self.args:
342
+ arg_types.append(a.tp.to_var())
343
+ assert isinstance(a.expr, VarDecl)
344
+ arg_names.append(a.expr.name)
345
+ return FunctionDecl(
346
+ FunctionSignature(
347
+ arg_types=tuple(arg_types),
348
+ arg_names=tuple(arg_names),
349
+ arg_defaults=(None,) * len(self.args),
350
+ return_type=self.res.tp.to_var(),
351
+ ),
352
+ )
353
+
354
+ @property
355
+ def egg_name(self) -> None | str:
356
+ return None
357
+
358
+
274
359
  @dataclass(frozen=True)
275
360
  class FunctionRef:
276
361
  name: str
@@ -293,6 +378,11 @@ class ClassMethodRef:
293
378
  method_name: str
294
379
 
295
380
 
381
+ @dataclass(frozen=True)
382
+ class InitRef:
383
+ class_name: str
384
+
385
+
296
386
  @dataclass(frozen=True)
297
387
  class ClassVariableRef:
298
388
  class_name: str
@@ -305,7 +395,16 @@ class PropertyRef:
305
395
  property_name: str
306
396
 
307
397
 
308
- CallableRef: TypeAlias = FunctionRef | ConstantRef | MethodRef | ClassMethodRef | ClassVariableRef | PropertyRef
398
+ CallableRef: TypeAlias = (
399
+ FunctionRef
400
+ | ConstantRef
401
+ | MethodRef
402
+ | ClassMethodRef
403
+ | InitRef
404
+ | ClassVariableRef
405
+ | PropertyRef
406
+ | UnnamedFunctionRef
407
+ )
309
408
 
310
409
 
311
410
  ##
@@ -378,7 +477,6 @@ class FunctionSignature:
378
477
  @dataclass(frozen=True)
379
478
  class FunctionDecl:
380
479
  signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
381
-
382
480
  # Egg params
383
481
  builtin: bool = False
384
482
  egg_name: str | None = None
@@ -402,6 +500,8 @@ CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
402
500
  @dataclass(frozen=True)
403
501
  class VarDecl:
404
502
  name: str
503
+ # Differentiate between let bound vars and vars created in rules so that they won't shadow in egglog, by adding a prefix
504
+ is_let: bool
405
505
 
406
506
 
407
507
  @dataclass(frozen=True)
@@ -458,7 +558,7 @@ class CallDecl:
458
558
  bound_tp_params: tuple[JustTypeRef, ...] | None = None
459
559
 
460
560
  def __post_init__(self) -> None:
461
- if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
561
+ if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
462
562
  msg = "Cannot bind type parameters to a non-class method callable."
463
563
  raise ValueError(msg)
464
564
 
@@ -508,6 +608,38 @@ class TypedExprDecl:
508
608
  return l
509
609
 
510
610
 
611
+ def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl:
612
+ """
613
+ Replace all the typed expressions in the given typed expression with the replacements.
614
+ """
615
+ # keep track of the traversed expressions for memoization
616
+ traversed: dict[TypedExprDecl, TypedExprDecl] = {}
617
+
618
+ def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl:
619
+ if typed_expr in traversed:
620
+ return traversed[typed_expr]
621
+ if typed_expr in replacements:
622
+ res = replacements[typed_expr]
623
+ else:
624
+ match typed_expr.expr:
625
+ case (
626
+ CallDecl(callable, args, bound_tp_params)
627
+ | PartialCallDecl(CallDecl(callable, args, bound_tp_params))
628
+ ):
629
+ new_args = tuple(_inner(a) for a in args)
630
+ call_decl = CallDecl(callable, new_args, bound_tp_params)
631
+ res = TypedExprDecl(
632
+ typed_expr.tp,
633
+ call_decl if isinstance(typed_expr.expr, CallDecl) else PartialCallDecl(call_decl),
634
+ )
635
+ case _:
636
+ res = typed_expr
637
+ traversed[typed_expr] = res
638
+ return res
639
+
640
+ return _inner(typed_expr)
641
+
642
+
511
643
  ##
512
644
  # Schedules
513
645
  ##
@@ -629,7 +761,13 @@ class RuleDecl:
629
761
  name: str | None
630
762
 
631
763
 
632
- RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl
764
+ @dataclass(frozen=True)
765
+ class DefaultRewriteDecl:
766
+ ref: CallableRef
767
+ expr: ExprDecl
768
+
769
+
770
+ RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl
633
771
 
634
772
 
635
773
  @dataclass(frozen=True)