egglog 6.0.0__cp311-none-win_amd64.whl → 6.1.0__cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/egraph.py CHANGED
@@ -170,13 +170,11 @@ class _BaseModule:
170
170
 
171
171
  @deprecated("Remove this decorator and move the egg_sort to the class statement, i.e. E(Expr, egg_sort='MySort').")
172
172
  @overload
173
- def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]:
174
- ...
173
+ def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]: ...
175
174
 
176
175
  @deprecated("Remove this decorator. Simply subclassing Expr is enough now.")
177
176
  @overload
178
- def class_(self, cls: TYPE, /) -> TYPE:
179
- ...
177
+ def class_(self, cls: TYPE, /) -> TYPE: ...
180
178
 
181
179
  def class_(self, *args, **kwargs) -> Any:
182
180
  """
@@ -201,8 +199,7 @@ class _BaseModule:
201
199
  self,
202
200
  *,
203
201
  preserve: Literal[True],
204
- ) -> Callable[[CALLABLE], CALLABLE]:
205
- ...
202
+ ) -> Callable[[CALLABLE], CALLABLE]: ...
206
203
 
207
204
  @overload
208
205
  def method(
@@ -214,8 +211,7 @@ class _BaseModule:
214
211
  on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None,
215
212
  mutates_self: bool = False,
216
213
  unextractable: bool = False,
217
- ) -> Callable[[CALLABLE], CALLABLE]:
218
- ...
214
+ ) -> Callable[[CALLABLE], CALLABLE]: ...
219
215
 
220
216
  @overload
221
217
  def method(
@@ -228,8 +224,7 @@ class _BaseModule:
228
224
  on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
229
225
  mutates_self: bool = False,
230
226
  unextractable: bool = False,
231
- ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
232
- ...
227
+ ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
233
228
 
234
229
  @deprecated("Use top level method function instead")
235
230
  def method(
@@ -249,8 +244,7 @@ class _BaseModule:
249
244
  )
250
245
 
251
246
  @overload
252
- def function(self, fn: CALLABLE, /) -> CALLABLE:
253
- ...
247
+ def function(self, fn: CALLABLE, /) -> CALLABLE: ...
254
248
 
255
249
  @overload
256
250
  def function(
@@ -262,8 +256,7 @@ class _BaseModule:
262
256
  on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None,
263
257
  mutates_first_arg: bool = False,
264
258
  unextractable: bool = False,
265
- ) -> Callable[[CALLABLE], CALLABLE]:
266
- ...
259
+ ) -> Callable[[CALLABLE], CALLABLE]: ...
267
260
 
268
261
  @overload
269
262
  def function(
@@ -276,8 +269,7 @@ class _BaseModule:
276
269
  on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
277
270
  mutates_first_arg: bool = False,
278
271
  unextractable: bool = False,
279
- ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
280
- ...
272
+ ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
281
273
 
282
274
  @deprecated("Use top level function `function` instead")
283
275
  def function(self, *args, **kwargs) -> Any:
@@ -300,24 +292,19 @@ class _BaseModule:
300
292
  @overload
301
293
  def relation(
302
294
  self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], tp4: type[E4], /
303
- ) -> Callable[[E1, E2, E3, E4], Unit]:
304
- ...
295
+ ) -> Callable[[E1, E2, E3, E4], Unit]: ...
305
296
 
306
297
  @overload
307
- def relation(self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]:
308
- ...
298
+ def relation(self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]: ...
309
299
 
310
300
  @overload
311
- def relation(self, name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]:
312
- ...
301
+ def relation(self, name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]: ...
313
302
 
314
303
  @overload
315
- def relation(self, name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]:
316
- ...
304
+ def relation(self, name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]: ...
317
305
 
318
306
  @overload
319
- def relation(self, name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]:
320
- ...
307
+ def relation(self, name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]: ...
321
308
 
322
309
  @deprecated("Use top level relation function instead")
323
310
  def relation(self, name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..., Unit]:
@@ -362,8 +349,7 @@ class _BaseModule:
362
349
  def method(
363
350
  *,
364
351
  preserve: Literal[True],
365
- ) -> Callable[[CALLABLE], CALLABLE]:
366
- ...
352
+ ) -> Callable[[CALLABLE], CALLABLE]: ...
367
353
 
368
354
 
369
355
  # We have to seperate method/function overloads for those that use the T params and those that don't
@@ -380,8 +366,7 @@ def method(
380
366
  on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None,
381
367
  mutates_self: bool = False,
382
368
  unextractable: bool = False,
383
- ) -> Callable[[CALLABLE], CALLABLE]:
384
- ...
369
+ ) -> Callable[[CALLABLE], CALLABLE]: ...
385
370
 
386
371
 
387
372
  @overload
@@ -394,8 +379,7 @@ def method(
394
379
  on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
395
380
  mutates_self: bool = False,
396
381
  unextractable: bool = False,
397
- ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
398
- ...
382
+ ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
399
383
 
400
384
 
401
385
  def method(
@@ -494,11 +478,11 @@ class _ClassDeclerationsConstructor:
494
478
  locals = self.frame.f_locals.copy()
495
479
  locals[self.cls_name] = self.current_cls
496
480
  for k, v in get_type_hints(_Dummytype, globalns=self.frame.f_globals, localns=locals).items():
497
- if v.__origin__ == ClassVar:
481
+ if getattr(v, "__origin__", None) == ClassVar:
498
482
  (inner_tp,) = v.__args__
499
483
  _register_constant(decls, ClassVariableRef(self.cls_name, k), inner_tp, None)
500
484
  else:
501
- msg = "The only supported annotations on class attributes are class vars"
485
+ msg = f"On class {self.cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
502
486
  raise NotImplementedError(msg)
503
487
 
504
488
  # Then register each of its methods
@@ -565,8 +549,7 @@ class _ClassDeclerationsConstructor:
565
549
 
566
550
 
567
551
  @overload
568
- def function(fn: CALLABLE, /) -> CALLABLE:
569
- ...
552
+ def function(fn: CALLABLE, /) -> CALLABLE: ...
570
553
 
571
554
 
572
555
  @overload
@@ -579,8 +562,7 @@ def function(
579
562
  mutates_first_arg: bool = False,
580
563
  unextractable: bool = False,
581
564
  builtin: bool = False,
582
- ) -> Callable[[CALLABLE], CALLABLE]:
583
- ...
565
+ ) -> Callable[[CALLABLE], CALLABLE]: ...
584
566
 
585
567
 
586
568
  @overload
@@ -593,8 +575,7 @@ def function(
593
575
  on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
594
576
  mutates_first_arg: bool = False,
595
577
  unextractable: bool = False,
596
- ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
597
- ...
578
+ ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
598
579
 
599
580
 
600
581
  def function(*args, **kwargs) -> Any:
@@ -760,28 +741,23 @@ def _register_function(
760
741
  @overload
761
742
  def relation(
762
743
  name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], tp4: type[E4], /
763
- ) -> Callable[[E1, E2, E3, E4], Unit]:
764
- ...
744
+ ) -> Callable[[E1, E2, E3, E4], Unit]: ...
765
745
 
766
746
 
767
747
  @overload
768
- def relation(name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]:
769
- ...
748
+ def relation(name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]: ...
770
749
 
771
750
 
772
751
  @overload
773
- def relation(name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]:
774
- ...
752
+ def relation(name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]: ...
775
753
 
776
754
 
777
755
  @overload
778
- def relation(name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]:
779
- ...
756
+ def relation(name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]: ...
780
757
 
781
758
 
782
759
  @overload
783
- def relation(name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]:
784
- ...
760
+ def relation(name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]: ...
785
761
 
786
762
 
787
763
  def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..., Unit]:
@@ -977,7 +953,7 @@ class EGraph(_BaseModule):
977
953
  # By default we want to split primitive outputs
978
954
  kwargs.setdefault("split_primitive_outputs", True)
979
955
  n_inline = kwargs.pop("n_inline_leaves", 0)
980
- serialized = self._egraph.serialize(**kwargs) # type: ignore[misc]
956
+ serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
981
957
  serialized.map_ops(self._state.decls.op_mapping())
982
958
  for _ in range(n_inline):
983
959
  serialized.inline_leaves()
@@ -1053,12 +1029,10 @@ class EGraph(_BaseModule):
1053
1029
  return cast(EXPR, RuntimeExpr(expr.__egg_decls__, TypedExprDecl(expr.__egg_typed_expr__.tp, VarDecl(name))))
1054
1030
 
1055
1031
  @overload
1056
- def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR:
1057
- ...
1032
+ def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ...
1058
1033
 
1059
1034
  @overload
1060
- def simplify(self, expr: EXPR, schedule: Schedule, /) -> EXPR:
1061
- ...
1035
+ def simplify(self, expr: EXPR, schedule: Schedule, /) -> EXPR: ...
1062
1036
 
1063
1037
  def simplify(
1064
1038
  self, expr: EXPR, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
@@ -1102,12 +1076,10 @@ class EGraph(_BaseModule):
1102
1076
  raise NotImplementedError(msg)
1103
1077
 
1104
1078
  @overload
1105
- def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport:
1106
- ...
1079
+ def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ...
1107
1080
 
1108
1081
  @overload
1109
- def run(self, schedule: Schedule, /) -> bindings.RunReport:
1110
- ...
1082
+ def run(self, schedule: Schedule, /) -> bindings.RunReport: ...
1111
1083
 
1112
1084
  def run(
1113
1085
  self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
@@ -1147,12 +1119,10 @@ class EGraph(_BaseModule):
1147
1119
  return bindings.Check(egg_facts)
1148
1120
 
1149
1121
  @overload
1150
- def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR:
1151
- ...
1122
+ def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: ...
1152
1123
 
1153
1124
  @overload
1154
- def extract(self, expr: EXPR, /, include_cost: Literal[True]) -> tuple[EXPR, int]:
1155
- ...
1125
+ def extract(self, expr: EXPR, /, include_cost: Literal[True]) -> tuple[EXPR, int]: ...
1156
1126
 
1157
1127
  def extract(self, expr: EXPR, include_cost: bool = False) -> EXPR | tuple[EXPR, int]:
1158
1128
  """
@@ -1229,24 +1199,19 @@ class EGraph(_BaseModule):
1229
1199
  self.pop()
1230
1200
 
1231
1201
  @overload
1232
- def eval(self, expr: i64) -> int:
1233
- ...
1202
+ def eval(self, expr: i64) -> int: ...
1234
1203
 
1235
1204
  @overload
1236
- def eval(self, expr: f64) -> float:
1237
- ...
1205
+ def eval(self, expr: f64) -> float: ...
1238
1206
 
1239
1207
  @overload
1240
- def eval(self, expr: Bool) -> bool:
1241
- ...
1208
+ def eval(self, expr: Bool) -> bool: ...
1242
1209
 
1243
1210
  @overload
1244
- def eval(self, expr: String) -> str:
1245
- ...
1211
+ def eval(self, expr: String) -> str: ...
1246
1212
 
1247
1213
  @overload
1248
- def eval(self, expr: PyObject) -> object:
1249
- ...
1214
+ def eval(self, expr: PyObject) -> object: ...
1250
1215
 
1251
1216
  def eval(self, expr: Expr) -> object:
1252
1217
  """
@@ -1347,8 +1312,7 @@ class Unit(Expr, egg_sort="Unit", builtin=True):
1347
1312
  The unit type. This is also used to reprsent if a value exists, if it is resolved or not.
1348
1313
  """
1349
1314
 
1350
- def __init__(self) -> None:
1351
- ...
1315
+ def __init__(self) -> None: ...
1352
1316
 
1353
1317
 
1354
1318
  def ruleset(
@@ -1501,6 +1465,7 @@ class Rewrite(Command):
1501
1465
  _lhs: RuntimeExpr
1502
1466
  _rhs: RuntimeExpr
1503
1467
  _conditions: tuple[Fact, ...]
1468
+ _subsume: bool
1504
1469
  _fn_name: ClassVar[str] = "rewrite"
1505
1470
 
1506
1471
  def __str__(self) -> str:
@@ -1509,7 +1474,7 @@ class Rewrite(Command):
1509
1474
 
1510
1475
  def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1511
1476
  return bindings.RewriteCommand(
1512
- self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
1477
+ self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite(), self._subsume
1513
1478
  )
1514
1479
 
1515
1480
  def _to_egg_rewrite(self) -> bindings.Rewrite:
@@ -1524,7 +1489,7 @@ class Rewrite(Command):
1524
1489
  return Declarations.create(self._lhs, self._rhs, *self._conditions)
1525
1490
 
1526
1491
  def with_ruleset(self, ruleset: Ruleset) -> Rewrite:
1527
- return Rewrite(ruleset, self._lhs, self._rhs, self._conditions)
1492
+ return Rewrite(ruleset, self._lhs, self._rhs, self._conditions, self._subsume)
1528
1493
 
1529
1494
 
1530
1495
  @dataclass
@@ -1688,21 +1653,23 @@ class ExprAction(Action):
1688
1653
 
1689
1654
 
1690
1655
  @dataclass
1691
- class Delete(Action):
1656
+ class Change(Action):
1692
1657
  """
1693
- Remove a function call from an EGraph.
1658
+ Change a function call in an EGraph.
1694
1659
  """
1695
1660
 
1661
+ change: Literal["delete", "subsume"]
1696
1662
  _call: RuntimeExpr
1697
1663
 
1698
1664
  def __str__(self) -> str:
1699
- return f"delete({self._call})"
1665
+ return f"{self.change}({self._call})"
1700
1666
 
1701
- def _to_egg_action(self) -> bindings.Delete:
1667
+ def _to_egg_action(self) -> bindings.Change:
1702
1668
  egg_call = self._call.__egg__
1703
1669
  if not isinstance(egg_call, bindings.Call):
1704
1670
  raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004
1705
- return bindings.Delete(egg_call.name, egg_call.args)
1671
+ change: bindings._Change = bindings.Delete() if self.change == "delete" else bindings.Subsume()
1672
+ return bindings.Change(change, egg_call.name, egg_call.args)
1706
1673
 
1707
1674
  @property
1708
1675
  def __egg_decls__(self) -> Declarations:
@@ -1836,29 +1803,25 @@ class Sequence(Schedule):
1836
1803
 
1837
1804
  @deprecated("Use <ruleset>.register(<rewrite>) instead of passing rulesets as arguments to rewrites.")
1838
1805
  @overload
1839
- def rewrite(lhs: EXPR, ruleset: Ruleset) -> _RewriteBuilder[EXPR]:
1840
- ...
1806
+ def rewrite(lhs: EXPR, ruleset: Ruleset, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...
1841
1807
 
1842
1808
 
1843
1809
  @overload
1844
- def rewrite(lhs: EXPR, ruleset: None = None) -> _RewriteBuilder[EXPR]:
1845
- ...
1810
+ def rewrite(lhs: EXPR, ruleset: None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...
1846
1811
 
1847
1812
 
1848
- def rewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _RewriteBuilder[EXPR]:
1813
+ def rewrite(lhs: EXPR, ruleset: Ruleset | None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]:
1849
1814
  """Rewrite the given expression to a new expression."""
1850
- return _RewriteBuilder(lhs, ruleset)
1815
+ return _RewriteBuilder(lhs, ruleset, subsume)
1851
1816
 
1852
1817
 
1853
1818
  @deprecated("Use <ruleset>.register(<birewrite>) instead of passing rulesets as arguments to birewrites.")
1854
1819
  @overload
1855
- def birewrite(lhs: EXPR, ruleset: Ruleset) -> _BirewriteBuilder[EXPR]:
1856
- ...
1820
+ def birewrite(lhs: EXPR, ruleset: Ruleset) -> _BirewriteBuilder[EXPR]: ...
1857
1821
 
1858
1822
 
1859
1823
  @overload
1860
- def birewrite(lhs: EXPR, ruleset: None = None) -> _BirewriteBuilder[EXPR]:
1861
- ...
1824
+ def birewrite(lhs: EXPR, ruleset: None = None) -> _BirewriteBuilder[EXPR]: ...
1862
1825
 
1863
1826
 
1864
1827
  def birewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _BirewriteBuilder[EXPR]:
@@ -1892,7 +1855,12 @@ def expr_action(expr: Expr) -> Action:
1892
1855
 
1893
1856
  def delete(expr: Expr) -> Action:
1894
1857
  """Create a delete expression."""
1895
- return Delete(to_runtime_expr(expr))
1858
+ return Change("delete", to_runtime_expr(expr))
1859
+
1860
+
1861
+ def subsume(expr: Expr) -> Action:
1862
+ """Subsume an expression."""
1863
+ return Change("subsume", to_runtime_expr(expr))
1896
1864
 
1897
1865
 
1898
1866
  def expr_fact(expr: Expr) -> Fact:
@@ -1911,13 +1879,11 @@ def set_(lhs: EXPR) -> _SetBuilder[EXPR]:
1911
1879
 
1912
1880
  @deprecated("Use <ruleset>.register(<rule>) instead of passing rulesets as arguments to rules.")
1913
1881
  @overload
1914
- def rule(*facts: FactLike, ruleset: Ruleset, name: str | None = None) -> _RuleBuilder:
1915
- ...
1882
+ def rule(*facts: FactLike, ruleset: Ruleset, name: str | None = None) -> _RuleBuilder: ...
1916
1883
 
1917
1884
 
1918
1885
  @overload
1919
- def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _RuleBuilder:
1920
- ...
1886
+ def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _RuleBuilder: ...
1921
1887
 
1922
1888
 
1923
1889
  def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = None) -> _RuleBuilder:
@@ -1947,10 +1913,11 @@ def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
1947
1913
  class _RewriteBuilder(Generic[EXPR]):
1948
1914
  lhs: EXPR
1949
1915
  ruleset: Ruleset | None
1916
+ subsume: bool
1950
1917
 
1951
1918
  def to(self, rhs: EXPR, *conditions: FactLike) -> Rewrite:
1952
1919
  lhs = to_runtime_expr(self.lhs)
1953
- rule = Rewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions))
1920
+ rule = Rewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions), self.subsume)
1954
1921
  if self.ruleset:
1955
1922
  self.ruleset.append(rule)
1956
1923
  return rule
@@ -1966,7 +1933,7 @@ class _BirewriteBuilder(Generic[EXPR]):
1966
1933
 
1967
1934
  def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
1968
1935
  lhs = to_runtime_expr(self.lhs)
1969
- rule = BiRewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions))
1936
+ rule = BiRewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions), False)
1970
1937
  if self.ruleset:
1971
1938
  self.ruleset.append(rule)
1972
1939
  return rule
egglog/examples/bool.py CHANGED
@@ -27,8 +27,7 @@ R = relation("R", i64)
27
27
 
28
28
 
29
29
  @function
30
- def f(i: i64Like) -> Bool:
31
- ...
30
+ def f(i: i64Like) -> Bool: ...
32
31
 
33
32
 
34
33
  i = var("i", i64)
@@ -3,6 +3,7 @@
3
3
  Basic equality saturation example.
4
4
  ==================================
5
5
  """
6
+
6
7
  from __future__ import annotations
7
8
 
8
9
  from egglog import *
@@ -11,18 +12,14 @@ egraph = EGraph()
11
12
 
12
13
 
13
14
  class Num(Expr):
14
- def __init__(self, value: i64Like) -> None:
15
- ...
15
+ def __init__(self, value: i64Like) -> None: ...
16
16
 
17
17
  @classmethod
18
- def var(cls, name: StringLike) -> Num:
19
- ...
18
+ def var(cls, name: StringLike) -> Num: ...
20
19
 
21
- def __add__(self, other: Num) -> Num:
22
- ...
20
+ def __add__(self, other: Num) -> Num: ...
23
21
 
24
- def __mul__(self, other: Num) -> Num:
25
- ...
22
+ def __mul__(self, other: Num) -> Num: ...
26
23
 
27
24
 
28
25
  expr1 = Num(2) * (Num.var("x") + Num(3))
egglog/examples/fib.py CHANGED
@@ -3,14 +3,14 @@
3
3
  Fibonacci numbers example
4
4
  =========================
5
5
  """
6
+
6
7
  from __future__ import annotations
7
8
 
8
9
  from egglog import *
9
10
 
10
11
 
11
12
  @function
12
- def fib(x: i64Like) -> i64:
13
- ...
13
+ def fib(x: i64Like) -> i64: ...
14
14
 
15
15
 
16
16
  f0, f1, x = vars_("f0 f1 x", i64)
@@ -4,6 +4,7 @@
4
4
  Lambda Calculus
5
5
  ===============
6
6
  """
7
+
7
8
  from __future__ import annotations
8
9
 
9
10
  from typing import TYPE_CHECKING, ClassVar
@@ -22,66 +23,53 @@ class Val(Expr):
22
23
  TRUE: ClassVar[Val]
23
24
  FALSE: ClassVar[Val]
24
25
 
25
- def __init__(self, v: i64Like) -> None:
26
- ...
26
+ def __init__(self, v: i64Like) -> None: ...
27
27
 
28
28
 
29
29
  class Var(Expr):
30
- def __init__(self, v: StringLike) -> None:
31
- ...
30
+ def __init__(self, v: StringLike) -> None: ...
32
31
 
33
32
 
34
33
  class Term(Expr):
35
34
  @classmethod
36
- def val(cls, v: Val) -> Term:
37
- ...
35
+ def val(cls, v: Val) -> Term: ...
38
36
 
39
37
  @classmethod
40
- def var(cls, v: Var) -> Term:
41
- ...
38
+ def var(cls, v: Var) -> Term: ...
42
39
 
43
- def __add__(self, other: Term) -> Term:
44
- ...
40
+ def __add__(self, other: Term) -> Term: ...
45
41
 
46
42
  def __eq__(self, other: Term) -> Term: # type: ignore[override]
47
43
  ...
48
44
 
49
- def __call__(self, other: Term) -> Term:
50
- ...
45
+ def __call__(self, other: Term) -> Term: ...
51
46
 
52
- def eval(self) -> Val:
53
- ...
47
+ def eval(self) -> Val: ...
54
48
 
55
- def v(self) -> Var:
56
- ...
49
+ def v(self) -> Var: ...
57
50
 
58
51
 
59
52
  @function
60
- def lam(x: Var, t: Term) -> Term:
61
- ...
53
+ def lam(x: Var, t: Term) -> Term: ...
62
54
 
63
55
 
64
56
  @function
65
- def let_(x: Var, t: Term, b: Term) -> Term:
66
- ...
57
+ def let_(x: Var, t: Term, b: Term) -> Term: ...
67
58
 
68
59
 
69
60
  @function
70
- def fix(x: Var, t: Term) -> Term:
71
- ...
61
+ def fix(x: Var, t: Term) -> Term: ...
72
62
 
73
63
 
74
64
  @function
75
- def if_(c: Term, t: Term, f: Term) -> Term:
76
- ...
65
+ def if_(c: Term, t: Term, f: Term) -> Term: ...
77
66
 
78
67
 
79
68
  StringSet = Set[Var]
80
69
 
81
70
 
82
71
  @function(merge=lambda old, new: old & new)
83
- def freer(t: Term) -> StringSet:
84
- ...
72
+ def freer(t: Term) -> StringSet: ...
85
73
 
86
74
 
87
75
  (v, v1, v2) = vars_("v v1 v2", Val)
egglog/examples/matrix.py CHANGED
@@ -2,6 +2,7 @@
2
2
  Matrix multiplication and Kronecker product.
3
3
  ============================================
4
4
  """
5
+
5
6
  from __future__ import annotations
6
7
 
7
8
  from egglog import *
@@ -18,8 +19,7 @@ class Dim(Expr):
18
19
  """
19
20
 
20
21
  @method(egg_fn="Lit")
21
- def __init__(self, value: i64Like) -> None:
22
- ...
22
+ def __init__(self, value: i64Like) -> None: ...
23
23
 
24
24
  @method(egg_fn="NamedDim")
25
25
  @classmethod