egglog-7.2.0-cp311-none-win_amd64.whl → egglog-8.0.1-cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. (The original page linked to more details at this point.)

egglog/declarations.py CHANGED
@@ -13,10 +13,11 @@ from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_c
13
13
  from typing_extensions import Self, assert_never
14
14
 
15
15
  if TYPE_CHECKING:
16
- from collections.abc import Callable, Iterable
16
+ from collections.abc import Callable, Iterable, Mapping
17
17
 
18
18
 
19
19
  __all__ = [
20
+ "replace_typed_expr",
20
21
  "Declarations",
21
22
  "DeclerationsLike",
22
23
  "DelayedDeclerations",
@@ -29,6 +30,7 @@ __all__ = [
29
30
  "MethodRef",
30
31
  "ClassMethodRef",
31
32
  "FunctionRef",
33
+ "UnnamedFunctionRef",
32
34
  "ConstantRef",
33
35
  "ClassVariableRef",
34
36
  "PropertyRef",
@@ -73,6 +75,7 @@ __all__ = [
73
75
  "FunctionSignature",
74
76
  "DefaultRewriteDecl",
75
77
  "InitRef",
78
+ "HasDeclerations",
76
79
  ]
77
80
 
78
81
 
@@ -82,12 +85,13 @@ class DelayedDeclerations:
82
85
 
83
86
  @property
84
87
  def __egg_decls__(self) -> Declarations:
88
+ thunk = self.__egg_decls_thunk__
85
89
  try:
86
- return self.__egg_decls_thunk__()
90
+ return thunk()
87
91
  # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
88
92
  # instead raise explicitly
89
93
  except AttributeError as err:
90
- msg = "Error resolving declerations"
94
+ msg = f"Cannot resolve declerations for {self}"
91
95
  raise RuntimeError(msg) from err
92
96
 
93
97
 
@@ -116,6 +120,7 @@ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[D
116
120
 
117
121
  @dataclass
118
122
  class Declarations:
123
+ _unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set)
119
124
  _functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
120
125
  _constants: dict[str, ConstantDecl] = field(default_factory=dict)
121
126
  _classes: dict[str, ClassDecl] = field(default_factory=dict)
@@ -192,6 +197,8 @@ class Declarations:
192
197
  init_fn = self._classes[class_name].init
193
198
  assert init_fn
194
199
  return init_fn
200
+ case UnnamedFunctionRef():
201
+ return ref.to_function_decl()
195
202
  assert_never(ref)
196
203
 
197
204
  def set_function_decl(
@@ -318,6 +325,37 @@ TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
318
325
  ##
319
326
 
320
327
 
328
+ @dataclass(frozen=True)
329
+ class UnnamedFunctionRef:
330
+ """
331
+ A reference to a function that doesn't have a name, but does have a body.
332
+ """
333
+
334
+ # tuple of var arg names and their types
335
+ args: tuple[TypedExprDecl, ...]
336
+ res: TypedExprDecl
337
+
338
+ def to_function_decl(self) -> FunctionDecl:
339
+ arg_types = []
340
+ arg_names = []
341
+ for a in self.args:
342
+ arg_types.append(a.tp.to_var())
343
+ assert isinstance(a.expr, VarDecl)
344
+ arg_names.append(a.expr.name)
345
+ return FunctionDecl(
346
+ FunctionSignature(
347
+ arg_types=tuple(arg_types),
348
+ arg_names=tuple(arg_names),
349
+ arg_defaults=(None,) * len(self.args),
350
+ return_type=self.res.tp.to_var(),
351
+ ),
352
+ )
353
+
354
+ @property
355
+ def egg_name(self) -> None | str:
356
+ return None
357
+
358
+
321
359
  @dataclass(frozen=True)
322
360
  class FunctionRef:
323
361
  name: str
@@ -358,7 +396,14 @@ class PropertyRef:
358
396
 
359
397
 
360
398
  CallableRef: TypeAlias = (
361
- FunctionRef | ConstantRef | MethodRef | ClassMethodRef | InitRef | ClassVariableRef | PropertyRef
399
+ FunctionRef
400
+ | ConstantRef
401
+ | MethodRef
402
+ | ClassMethodRef
403
+ | InitRef
404
+ | ClassVariableRef
405
+ | PropertyRef
406
+ | UnnamedFunctionRef
362
407
  )
363
408
 
364
409
 
@@ -455,6 +500,8 @@ CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
455
500
  @dataclass(frozen=True)
456
501
  class VarDecl:
457
502
  name: str
503
+ # Differentiate between let bound vars and vars created in rules so that they won't shadow in egglog, by adding a prefix
504
+ is_let: bool
458
505
 
459
506
 
460
507
  @dataclass(frozen=True)
@@ -561,6 +608,37 @@ class TypedExprDecl:
561
608
  return l
562
609
 
563
610
 
611
+ def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl:
612
+ """
613
+ Replace all the typed expressions in the given typed expression with the replacements.
614
+ """
615
+ # keep track of the traversed expressions for memoization
616
+ traversed: dict[TypedExprDecl, TypedExprDecl] = {}
617
+
618
+ def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl:
619
+ if typed_expr in traversed:
620
+ return traversed[typed_expr]
621
+ if typed_expr in replacements:
622
+ res = replacements[typed_expr]
623
+ else:
624
+ match typed_expr.expr:
625
+ case CallDecl(callable, args, bound_tp_params) | PartialCallDecl(
626
+ CallDecl(callable, args, bound_tp_params)
627
+ ):
628
+ new_args = tuple(_inner(a) for a in args)
629
+ call_decl = CallDecl(callable, new_args, bound_tp_params)
630
+ res = TypedExprDecl(
631
+ typed_expr.tp,
632
+ call_decl if isinstance(typed_expr.expr, CallDecl) else PartialCallDecl(call_decl),
633
+ )
634
+ case _:
635
+ res = typed_expr
636
+ traversed[typed_expr] = res
637
+ return res
638
+
639
+ return _inner(typed_expr)
640
+
641
+
564
642
  ##
565
643
  # Schedules
566
644
  ##