egglog 11.1.0__cp312-cp312-manylinux_2_17_ppc64.manylinux2014_ppc64.whl → 11.3.0__cp312-cp312-manylinux_2_17_ppc64.manylinux2014_ppc64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. (The original page linked to more details at this point; that link was not preserved in this copy.)

egglog/bindings.pyi CHANGED
@@ -7,15 +7,16 @@ from typing_extensions import final
7
7
  __all__ = [
8
8
  "ActionCommand",
9
9
  "AddRuleset",
10
- "Best",
11
10
  "BiRewriteCommand",
12
11
  "Bool",
12
+ "CSVPrintFunctionMode",
13
13
  "Call",
14
14
  "Change",
15
15
  "Check",
16
16
  "Constructor",
17
17
  "Datatype",
18
18
  "Datatypes",
19
+ "DefaultPrintFunctionMode",
19
20
  "Delete",
20
21
  "EGraph",
21
22
  "EggSmolError",
@@ -23,10 +24,13 @@ __all__ = [
23
24
  "Eq",
24
25
  "Expr_",
25
26
  "Extract",
27
+ "ExtractBest",
28
+ "ExtractVariants",
26
29
  "Fact",
27
30
  "Fail",
28
31
  "Float",
29
32
  "Function",
33
+ "FunctionCommand",
30
34
  "IdentSort",
31
35
  "Include",
32
36
  "Input",
@@ -35,10 +39,14 @@ __all__ = [
35
39
  "Lit",
36
40
  "NewSort",
37
41
  "Output",
42
+ "OverallStatistics",
38
43
  "Panic",
39
44
  "PanicSpan",
40
45
  "Pop",
46
+ "PrintAllFunctionsSize",
41
47
  "PrintFunction",
48
+ "PrintFunctionOutput",
49
+ "PrintFunctionSize",
42
50
  "PrintOverallStatistics",
43
51
  "PrintSize",
44
52
  "Push",
@@ -53,13 +61,13 @@ __all__ = [
53
61
  "RunConfig",
54
62
  "RunReport",
55
63
  "RunSchedule",
64
+ "RunScheduleOutput",
56
65
  "RustSpan",
57
66
  "Saturate",
58
67
  "Schema",
59
68
  "Sequence",
60
69
  "SerializedEGraph",
61
70
  "Set",
62
- "SetOption",
63
71
  "Sort",
64
72
  "SrcFile",
65
73
  "String",
@@ -73,13 +81,18 @@ __all__ = [
73
81
  "Unit",
74
82
  "UnstableCombinedRuleset",
75
83
  "UserDefined",
84
+ "UserDefinedCommandOutput",
85
+ "UserDefinedOutput",
76
86
  "Var",
77
87
  "Variant",
78
- "Variants",
79
88
  ]
80
89
 
81
90
  @final
82
91
  class SerializedEGraph:
92
+ @property
93
+ def truncated_functions(self) -> list[str]: ...
94
+ @property
95
+ def discarded_functions(self) -> list[str]: ...
83
96
  def inline_leaves(self) -> None: ...
84
97
  def saturate_inline_leaves(self) -> None: ...
85
98
  def to_dot(self) -> str: ...
@@ -106,9 +119,7 @@ class EGraph:
106
119
  ) -> None: ...
107
120
  def parse_program(self, __input: str, /, filename: str | None = None) -> list[_Command]: ...
108
121
  def commands(self) -> str | None: ...
109
- def run_program(self, *commands: _Command) -> list[str]: ...
110
- def extract_report(self) -> _ExtractReport | None: ...
111
- def run_report(self) -> RunReport | None: ...
122
+ def run_program(self, *commands: _Command) -> list[_CommandOutput]: ...
112
123
  def serialize(
113
124
  self,
114
125
  root_eclasses: list[_Expr],
@@ -356,6 +367,13 @@ class IdentSort:
356
367
  sort: str
357
368
  def __init__(self, ident: str, sort: str) -> None: ...
358
369
 
370
+ @final
371
+ class UserDefinedCommandOutput: ...
372
+
373
+ @final
374
+ class Function:
375
+ name: str
376
+
359
377
  @final
360
378
  class RunReport:
361
379
  updated: bool
@@ -375,20 +393,80 @@ class RunReport:
375
393
  rebuild_time_per_ruleset: dict[str, timedelta],
376
394
  ) -> None: ...
377
395
 
396
+ ##
397
+ # Command Outputs
398
+ ##
399
+
400
+ @final
401
+ class PrintFunctionSize:
402
+ size: int
403
+ def __init__(self, size: int) -> None: ...
404
+
405
+ @final
406
+ class PrintAllFunctionsSize:
407
+ sizes: list[tuple[str, int]]
408
+ def __init__(self, sizes: list[tuple[str, int]]) -> None: ...
409
+
378
410
  @final
379
- class Variants:
411
+ class ExtractVariants:
380
412
  termdag: TermDag
381
413
  terms: list[_Term]
382
414
  def __init__(self, termdag: TermDag, terms: list[_Term]) -> None: ...
383
415
 
384
416
  @final
385
- class Best:
417
+ class ExtractBest:
386
418
  termdag: TermDag
387
419
  cost: int
388
420
  term: _Term
389
421
  def __init__(self, termdag: TermDag, cost: int, term: _Term) -> None: ...
390
422
 
391
- _ExtractReport: TypeAlias = Variants | Best
423
+ @final
424
+ class OverallStatistics:
425
+ report: RunReport
426
+ def __init__(self, report: RunReport) -> None: ...
427
+
428
+ @final
429
+ class RunScheduleOutput:
430
+ report: RunReport
431
+ def __init__(self, report: RunReport) -> None: ...
432
+
433
+ @final
434
+ class PrintFunctionOutput:
435
+ function: Function
436
+ termdag: TermDag
437
+ terms: list[tuple[_Term, _Term]]
438
+ mode: _PrintFunctionMode
439
+ def __init__(
440
+ self, function: Function, termdag: TermDag, terms: list[tuple[_Term, _Term]], mode: _PrintFunctionMode
441
+ ) -> None: ...
442
+
443
+ @final
444
+ class UserDefinedOutput:
445
+ output: UserDefinedCommandOutput
446
+ def __init__(self, output: UserDefinedCommandOutput) -> None: ...
447
+
448
+ _CommandOutput: TypeAlias = (
449
+ PrintFunctionSize
450
+ | PrintAllFunctionsSize
451
+ | ExtractVariants
452
+ | ExtractBest
453
+ | OverallStatistics
454
+ | RunScheduleOutput
455
+ | PrintFunctionOutput
456
+ | UserDefinedOutput
457
+ )
458
+
459
+ ##
460
+ # Print Function Modes
461
+ ##
462
+
463
+ @final
464
+ class DefaultPrintFunctionMode: ...
465
+
466
+ @final
467
+ class CSVPrintFunctionMode: ...
468
+
469
+ _PrintFunctionMode: TypeAlias = DefaultPrintFunctionMode | CSVPrintFunctionMode
392
470
 
393
471
  ##
394
472
  # Schedules
@@ -442,12 +520,6 @@ _Subdatatypes: TypeAlias = SubVariants | NewSort
442
520
  # Commands
443
521
  ##
444
522
 
445
- @final
446
- class SetOption:
447
- name: str
448
- value: _Expr
449
- def __init__(self, name: str, value: _Expr) -> None: ...
450
-
451
523
  @final
452
524
  class Datatype:
453
525
  span: _Span
@@ -469,7 +541,7 @@ class Sort:
469
541
  def __init__(self, span: _Span, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
470
542
 
471
543
  @final
472
- class Function:
544
+ class FunctionCommand:
473
545
  span: _Span
474
546
  name: str
475
547
  schema: Schema
@@ -531,8 +603,12 @@ class Check:
531
603
  class PrintFunction:
532
604
  span: _Span
533
605
  name: str
534
- length: int
535
- def __init__(self, span: _Span, name: str, length: int) -> None: ...
606
+ length: int | None
607
+ filename: str | None
608
+ mode: _PrintFunctionMode
609
+ def __init__(
610
+ self, span: _Span, name: str, length: int | None, filename: str | None, mode: _PrintFunctionMode
611
+ ) -> None: ...
536
612
 
537
613
  @final
538
614
  class PrintSize:
@@ -613,11 +689,10 @@ class UnstableCombinedRuleset:
613
689
  def __init__(self, span: _Span, name: str, rulesets: list[str]) -> None: ...
614
690
 
615
691
  _Command: TypeAlias = (
616
- SetOption
617
- | Datatype
692
+ Datatype
618
693
  | Datatypes
619
694
  | Sort
620
- | Function
695
+ | FunctionCommand
621
696
  | AddRuleset
622
697
  | RuleCommand
623
698
  | RewriteCommand
egglog/builtins.py CHANGED
@@ -103,6 +103,10 @@ class String(BuiltinExpr):
103
103
  @method(egg_fn="replace")
104
104
  def replace(self, old: StringLike, new: StringLike) -> String: ...
105
105
 
106
+ @method(preserve=True)
107
+ def __add__(self, other: StringLike) -> String:
108
+ return join(self, other)
109
+
106
110
 
107
111
  StringLike: TypeAlias = String | str
108
112
 
egglog/declarations.py CHANGED
@@ -9,6 +9,7 @@ from __future__ import annotations
9
9
  from dataclasses import dataclass, field
10
10
  from functools import cached_property
11
11
  from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable
12
+ from uuid import UUID
12
13
  from weakref import WeakValueDictionary
13
14
 
14
15
  from typing_extensions import Self, assert_never
@@ -20,6 +21,7 @@ if TYPE_CHECKING:
20
21
  __all__ = [
21
22
  "ActionCommandDecl",
22
23
  "ActionDecl",
24
+ "BackOffDecl",
23
25
  "BiRewriteDecl",
24
26
  "CallDecl",
25
27
  "CallableDecl",
@@ -52,6 +54,7 @@ __all__ = [
52
54
  "JustTypeRef",
53
55
  "LetDecl",
54
56
  "LetRefDecl",
57
+ "LetSchedulerDecl",
55
58
  "LitDecl",
56
59
  "LitType",
57
60
  "MethodRef",
@@ -69,6 +72,7 @@ __all__ = [
69
72
  "SaturateDecl",
70
73
  "ScheduleDecl",
71
74
  "SequenceDecl",
75
+ "SetCostDecl",
72
76
  "SetDecl",
73
77
  "SpecialFunctions",
74
78
  "TypeOrVarRef",
@@ -789,9 +793,24 @@ class SequenceDecl:
789
793
  class RunDecl:
790
794
  ruleset: str
791
795
  until: tuple[FactDecl, ...] | None
796
+ scheduler: BackOffDecl | None = None
792
797
 
793
798
 
794
- ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
799
+ @dataclass(frozen=True)
800
+ class LetSchedulerDecl:
801
+ scheduler: BackOffDecl
802
+ inner: ScheduleDecl
803
+
804
+
805
+ ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl | LetSchedulerDecl
806
+
807
+
808
+ @dataclass(frozen=True)
809
+ class BackOffDecl:
810
+ id: UUID
811
+ match_limit: int | None
812
+ ban_length: int | None
813
+
795
814
 
796
815
  ##
797
816
  # Facts
@@ -854,7 +873,14 @@ class PanicDecl:
854
873
  msg: str
855
874
 
856
875
 
857
- ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl
876
+ @dataclass(frozen=True)
877
+ class SetCostDecl:
878
+ tp: JustTypeRef
879
+ expr: CallDecl
880
+ cost: ExprDecl
881
+
882
+
883
+ ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl | SetCostDecl
858
884
 
859
885
 
860
886
  ##
egglog/egraph.py CHANGED
@@ -23,6 +23,8 @@ from typing import (
23
23
  get_type_hints,
24
24
  overload,
25
25
  )
26
+ from uuid import uuid4
27
+ from warnings import warn
26
28
 
27
29
  import graphviz
28
30
  from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
@@ -39,11 +41,12 @@ from .thunk import *
39
41
  from .version_compat import *
40
42
 
41
43
  if TYPE_CHECKING:
42
- from .builtins import String, Unit
44
+ from .builtins import String, Unit, i64Like
43
45
 
44
46
 
45
47
  __all__ = [
46
48
  "Action",
49
+ "BackOff",
47
50
  "BaseExpr",
48
51
  "BuiltinExpr",
49
52
  "Command",
@@ -62,6 +65,7 @@ __all__ = [
62
65
  "_RewriteBuilder",
63
66
  "_SetBuilder",
64
67
  "_UnionBuilder",
68
+ "back_off",
65
69
  "birewrite",
66
70
  "check",
67
71
  "check_eq",
@@ -83,6 +87,7 @@ __all__ = [
83
87
  "run",
84
88
  "seq",
85
89
  "set_",
90
+ "set_cost",
86
91
  "subsume",
87
92
  "union",
88
93
  "unstable_combine_rulesets",
@@ -804,7 +809,7 @@ class GraphvizKwargs(TypedDict, total=False):
804
809
  max_calls_per_function: int | None
805
810
  n_inline_leaves: int
806
811
  split_primitive_outputs: bool
807
- split_functions: list[object]
812
+ split_functions: list[ExprCallable]
808
813
  include_temporary_functions: bool
809
814
 
810
815
 
@@ -851,12 +856,12 @@ class EGraph:
851
856
  """
852
857
  Loads a CSV file and sets it as *input, output of the function.
853
858
  """
854
- self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn), path))
859
+ self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn)[1], path))
855
860
 
856
- def _callable_to_egg(self, fn: object) -> str:
861
+ def _callable_to_egg(self, fn: ExprCallable) -> tuple[CallableRef, str]:
857
862
  ref, decls = resolve_callable(fn)
858
863
  self._add_decls(decls)
859
- return self._state.callable_ref_to_egg(ref)[0]
864
+ return ref, self._state.callable_ref_to_egg(ref)[0]
860
865
 
861
866
  # TODO: Change let to be action...
862
867
  def let(self, name: str, expr: BASE_EXPR) -> BASE_EXPR:
@@ -903,13 +908,18 @@ class EGraph:
903
908
 
904
909
  def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
905
910
  self._add_decls(schedule)
906
- egg_schedule = self._state.schedule_to_egg(schedule.schedule)
907
- self._egraph.run_program(bindings.RunSchedule(egg_schedule))
908
- run_report = self._egraph.run_report()
909
- if not run_report:
910
- msg = "No run report saved"
911
- raise ValueError(msg)
912
- return run_report
911
+ cmd = self._state.run_schedule_to_egg(schedule.schedule)
912
+ (command_output,) = self._egraph.run_program(cmd)
913
+ assert isinstance(command_output, bindings.RunScheduleOutput)
914
+ return command_output.report
915
+
916
+ def stats(self) -> bindings.RunReport:
917
+ """
918
+ Returns the overall run report for the egraph.
919
+ """
920
+ (output,) = self._egraph.run_program(bindings.PrintOverallStatistics())
921
+ assert isinstance(output, bindings.OverallStatistics)
922
+ return output.report
913
923
 
914
924
  def check_bool(self, *facts: FactLike) -> bool:
915
925
  """
@@ -954,45 +964,41 @@ class EGraph:
954
964
  """
955
965
  runtime_expr = to_runtime_expr(expr)
956
966
  extract_report = self._run_extract(runtime_expr, 0)
957
-
958
- if not isinstance(extract_report, bindings.Best):
959
- msg = "No extract report saved"
960
- raise ValueError(msg) # noqa: TRY004
961
- (new_typed_expr,) = self._state.exprs_from_egg(
962
- extract_report.termdag, [extract_report.term], runtime_expr.__egg_typed_expr__.tp
963
- )
964
-
965
- res = cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
967
+ assert isinstance(extract_report, bindings.ExtractBest)
968
+ res = self._from_termdag(extract_report.termdag, extract_report.term, runtime_expr.__egg_typed_expr__.tp)
966
969
  if include_cost:
967
970
  return res, extract_report.cost
968
971
  return res
969
972
 
973
+ def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any:
974
+ (new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp)
975
+ return RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr)
976
+
970
977
  def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
971
978
  """
972
979
  Extract multiple expressions from the egraph.
973
980
  """
974
981
  runtime_expr = to_runtime_expr(expr)
975
982
  extract_report = self._run_extract(runtime_expr, n)
976
- if not isinstance(extract_report, bindings.Variants):
977
- msg = "Wrong extract report type"
978
- raise ValueError(msg) # noqa: TRY004
983
+ assert isinstance(extract_report, bindings.ExtractVariants)
979
984
  new_exprs = self._state.exprs_from_egg(
980
985
  extract_report.termdag, extract_report.terms, runtime_expr.__egg_typed_expr__.tp
981
986
  )
982
987
  return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
983
988
 
984
- def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._ExtractReport:
989
+ def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
985
990
  self._add_decls(expr)
986
991
  expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
992
+ # If we have defined any cost tables use the custom extraction
993
+ args = (expr, bindings.Lit(span(2), bindings.Int(n)))
994
+ if self._state.cost_callables:
995
+ cmd: bindings._Command = bindings.UserDefined(span(2), "extract", list(args))
996
+ else:
997
+ cmd = bindings.Extract(span(2), *args)
987
998
  try:
988
- self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
999
+ return self._egraph.run_program(cmd)[0]
989
1000
  except BaseException as e:
990
1001
  raise add_note("Extracting: " + str(expr), e) # noqa: B904
991
- extract_report = self._egraph.extract_report()
992
- if not extract_report:
993
- msg = "No extract report saved"
994
- raise ValueError(msg)
995
- return extract_report
996
1002
 
997
1003
  def push(self) -> None:
998
1004
  """
@@ -1030,15 +1036,21 @@ class EGraph:
1030
1036
  split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
1031
1037
  split_functions = kwargs.pop("split_functions", [])
1032
1038
  include_temporary_functions = kwargs.pop("include_temporary_functions", False)
1033
- n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
1039
+ n_inline_leaves = kwargs.pop("n_inline_leaves", 0)
1034
1040
  serialized = self._egraph.serialize(
1035
1041
  [],
1036
1042
  max_functions=max_functions,
1037
1043
  max_calls_per_function=max_calls_per_function,
1038
1044
  include_temporary_functions=include_temporary_functions,
1039
1045
  )
1046
+ if serialized.discarded_functions:
1047
+ msg = ", ".join(set(self._state.possible_egglog_functions(serialized.discarded_functions)))
1048
+ warn(f"Omitted: {msg}", stacklevel=3)
1049
+ if serialized.truncated_functions:
1050
+ msg = ", ".join(set(self._state.possible_egglog_functions(serialized.truncated_functions)))
1051
+ warn(f"Truncated: {msg}", stacklevel=3)
1040
1052
  if split_primitive_outputs or split_functions:
1041
- additional_ops = set(map(self._callable_to_egg, split_functions))
1053
+ additional_ops = {self._callable_to_egg(f)[1] for f in split_functions}
1042
1054
  serialized.split_classes(self._egraph, additional_ops)
1043
1055
  serialized.map_ops(self._state.op_mapping())
1044
1056
 
@@ -1185,6 +1197,58 @@ class EGraph:
1185
1197
  assert_never(cmd)
1186
1198
  return self._state.command_to_egg(cmd_decl, ruleset_name)
1187
1199
 
1200
+ def function_size(self, fn: ExprCallable) -> int:
1201
+ """
1202
+ Returns the number of rows in a certain function
1203
+ """
1204
+ egg_name = self._callable_to_egg(fn)[1]
1205
+ (output,) = self._egraph.run_program(bindings.PrintSize(span(1), egg_name))
1206
+ assert isinstance(output, bindings.PrintFunctionSize)
1207
+ return output.size
1208
+
1209
+ def all_function_sizes(self) -> list[tuple[ExprCallable, int]]:
1210
+ """
1211
+ Returns a list of all functions and their sizes.
1212
+ """
1213
+ (output,) = self._egraph.run_program(bindings.PrintSize(span(1), None))
1214
+ assert isinstance(output, bindings.PrintAllFunctionsSize)
1215
+ return [
1216
+ (
1217
+ cast(
1218
+ "ExprCallable",
1219
+ create_callable(self._state.__egg_decls__, next(iter(refs))),
1220
+ ),
1221
+ size,
1222
+ )
1223
+ for (name, size) in output.sizes
1224
+ if (refs := self._state.egg_fn_to_callable_refs[name])
1225
+ ]
1226
+
1227
+ def function_values(
1228
+ self, fn: Callable[..., BASE_EXPR] | BASE_EXPR, length: int | None = None
1229
+ ) -> dict[BASE_EXPR, BASE_EXPR]:
1230
+ """
1231
+ Given a callable that is a "function", meaning it returns a primitive or has a merge set,
1232
+ returns a mapping of the function applied with its arguments to its values
1233
+
1234
+ If length is specified, only the first `length` values will be returned.
1235
+ """
1236
+ ref, egg_name = self._callable_to_egg(fn)
1237
+ cmd = bindings.PrintFunction(span(1), egg_name, length, None, bindings.DefaultPrintFunctionMode())
1238
+ (output,) = self._egraph.run_program(cmd)
1239
+ assert isinstance(output, bindings.PrintFunctionOutput)
1240
+ signature = self.__egg_decls__.get_callable_decl(ref).signature
1241
+ assert isinstance(signature, FunctionSignature)
1242
+ tp = signature.semantic_return_type.to_just()
1243
+ return {
1244
+ self._from_termdag(output.termdag, call, tp): self._from_termdag(output.termdag, res, tp)
1245
+ for (call, res) in output.terms
1246
+ }
1247
+
1248
+
1249
+ # Either a constant or a function.
1250
+ ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr
1251
+
1188
1252
 
1189
1253
  @dataclass(frozen=True)
1190
1254
  class _WrappedMethod:
@@ -1406,10 +1470,13 @@ class Fact:
1406
1470
  """
1407
1471
  Returns True if the two sides of an equality are structurally equal.
1408
1472
  """
1409
- if not isinstance(self.fact, EqDecl):
1410
- msg = "Can only check equality facts"
1411
- raise TypeError(msg)
1412
- return self.fact.left == self.fact.right
1473
+ match self.fact:
1474
+ case EqDecl(_, left, right):
1475
+ return left == right
1476
+ case ExprFactDecl(TypedExprDecl(_, CallDecl(FunctionRef("!="), (left_tp, right_tp)))):
1477
+ return left_tp != right_tp
1478
+ msg = f"Can only check equality for == or != not {self}"
1479
+ raise ValueError(msg)
1413
1480
 
1414
1481
 
1415
1482
  @dataclass
@@ -1457,6 +1524,18 @@ def panic(message: str) -> Action:
1457
1524
  return Action(Declarations(), PanicDecl(message))
1458
1525
 
1459
1526
 
1527
+ def set_cost(expr: BaseExpr, cost: i64Like) -> Action:
1528
+ """Set the cost of the given expression."""
1529
+ from .builtins import i64 # noqa: PLC0415
1530
+
1531
+ expr_runtime = to_runtime_expr(expr)
1532
+ typed_expr_decl = expr_runtime.__egg_typed_expr__
1533
+ expr_decl = typed_expr_decl.expr
1534
+ assert isinstance(expr_decl, CallDecl), "Can only set cost of calls, not literals or vars"
1535
+ cost_decl = to_runtime_expr(convert(cost, i64)).__egg_typed_expr__.expr
1536
+ return Action(expr_runtime.__egg_decls__, SetCostDecl(typed_expr_decl.tp, expr_decl, cost_decl))
1537
+
1538
+
1460
1539
  def let(name: str, expr: BaseExpr) -> Action:
1461
1540
  """Create a let binding."""
1462
1541
  runtime_expr = to_runtime_expr(expr)
@@ -1710,17 +1789,51 @@ def to_runtime_expr(expr: BaseExpr) -> RuntimeExpr:
1710
1789
  return expr
1711
1790
 
1712
1791
 
1713
- def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule:
1792
+ def run(ruleset: Ruleset | None = None, *until: FactLike, scheduler: BackOff | None = None) -> Schedule:
1714
1793
  """
1715
1794
  Create a run configuration.
1716
1795
  """
1717
1796
  facts = _fact_likes(until)
1718
1797
  return Schedule(
1719
1798
  Thunk.fn(Declarations.create, ruleset, *facts),
1720
- RunDecl(ruleset.__egg_name__ if ruleset else "", tuple(f.fact for f in facts) or None),
1799
+ RunDecl(
1800
+ ruleset.__egg_name__ if ruleset else "",
1801
+ tuple(f.fact for f in facts) or None,
1802
+ scheduler.scheduler if scheduler else None,
1803
+ ),
1721
1804
  )
1722
1805
 
1723
1806
 
1807
+ def back_off(match_limit: None | int = None, ban_length: None | int = None) -> BackOff:
1808
+ """
1809
+ Create a backoff scheduler configuration.
1810
+
1811
+ ```python
1812
+ schedule = run(analysis_ruleset).saturate() + run(ruleset, scheduler=back_off(match_limit=1000, ban_length=5)) * 10
1813
+ ```
1814
+ This will run the `analysis_ruleset` until saturation, then run `ruleset` 10 times, using a backoff scheduler.
1815
+ """
1816
+ return BackOff(BackOffDecl(id=uuid4(), match_limit=match_limit, ban_length=ban_length))
1817
+
1818
+
1819
+ @dataclass(frozen=True)
1820
+ class BackOff:
1821
+ scheduler: BackOffDecl
1822
+
1823
+ def scope(self, schedule: Schedule) -> Schedule:
1824
+ """
1825
+ Defines the scheduler to be created directly before the inner schedule, instead of the default which is at the
1826
+ most outer scope.
1827
+ """
1828
+ return Schedule(schedule.__egg_decls_thunk__, LetSchedulerDecl(self.scheduler, schedule.schedule))
1829
+
1830
+ def __str__(self) -> str:
1831
+ return pretty_decl(Declarations(), self.scheduler)
1832
+
1833
+ def __repr__(self) -> str:
1834
+ return str(self)
1835
+
1836
+
1724
1837
  def seq(*schedules: Schedule) -> Schedule:
1725
1838
  """
1726
1839
  Run a sequence of schedules.
egglog/egraph_state.py CHANGED
@@ -6,8 +6,9 @@ from __future__ import annotations
6
6
 
7
7
  import re
8
8
  from collections import defaultdict
9
- from dataclasses import dataclass, field
9
+ from dataclasses import dataclass, field, replace
10
10
  from typing import TYPE_CHECKING, Literal, overload
11
+ from uuid import UUID
11
12
 
12
13
  from typing_extensions import assert_never
13
14
 
@@ -71,6 +72,9 @@ class EGraphState:
71
72
  # Cache of egg expressions for converting to egg
72
73
  expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
73
74
 
75
+ # Callables which have cost tables associated with them
76
+ cost_callables: set[CallableRef] = field(default_factory=set)
77
+
74
78
  def copy(self) -> EGraphState:
75
79
  """
76
80
  Returns a copy of the state. Th egraph reference is kept the same. Used for pushing/popping.
@@ -83,20 +87,143 @@ class EGraphState:
83
87
  callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
84
88
  type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
85
89
  expr_to_egg_cache=self.expr_to_egg_cache.copy(),
90
+ cost_callables=self.cost_callables.copy(),
86
91
  )
87
92
 
88
- def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
93
+ def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command:
94
+ """
95
+ Turn a run schedule into an egg command.
96
+
97
+ If there exists any custom schedulers in the schedule, it will be turned into a custom extract command otherwise
98
+ will be a normal run command.
99
+ """
100
+ processed_schedule = self._process_schedule(schedule)
101
+ if processed_schedule is None:
102
+ return bindings.RunSchedule(self._schedule_to_egg(schedule))
103
+ top_level_schedules = self._schedule_with_scheduler_to_egg(processed_schedule, [])
104
+ if len(top_level_schedules) == 1:
105
+ schedule_expr = top_level_schedules[0]
106
+ else:
107
+ schedule_expr = bindings.Call(span(), "seq", top_level_schedules)
108
+ return bindings.UserDefined(span(), "run-schedule", [schedule_expr])
109
+
110
+ def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None:
111
+ """
112
+ Processes a schedule to determine if it contains any custom schedulers.
113
+
114
+ If it does, it returns a new schedule with all the required let bindings added to the other scope.
115
+ If not, returns none.
116
+
117
+ Also processes all rulesets in the schedule to make sure they are registered.
118
+ """
119
+ bound_schedulers: list[UUID] = []
120
+ unbound_schedulers: list[BackOffDecl] = []
121
+
122
+ def helper(s: ScheduleDecl) -> None:
123
+ match s:
124
+ case LetSchedulerDecl(scheduler, inner):
125
+ bound_schedulers.append(scheduler.id)
126
+ return helper(inner)
127
+ case RunDecl(ruleset_name, _, scheduler):
128
+ self.ruleset_to_egg(ruleset_name)
129
+ if scheduler and scheduler.id not in bound_schedulers:
130
+ unbound_schedulers.append(scheduler)
131
+ case SaturateDecl(inner) | RepeatDecl(inner, _):
132
+ return helper(inner)
133
+ case SequenceDecl(schedules):
134
+ for sc in schedules:
135
+ helper(sc)
136
+ case _:
137
+ assert_never(s)
138
+ return None
139
+
140
+ helper(schedule)
141
+ if not bound_schedulers and not unbound_schedulers:
142
+ return None
143
+ for scheduler in unbound_schedulers:
144
+ schedule = LetSchedulerDecl(scheduler, schedule)
145
+ return schedule
146
+
147
+ def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
148
+ msg = "Should never reach this, let schedulers should be handled by custom scheduler"
89
149
  match schedule:
90
150
  case SaturateDecl(schedule):
91
- return bindings.Saturate(span(), self.schedule_to_egg(schedule))
151
+ return bindings.Saturate(span(), self._schedule_to_egg(schedule))
92
152
  case RepeatDecl(schedule, times):
93
- return bindings.Repeat(span(), times, self.schedule_to_egg(schedule))
153
+ return bindings.Repeat(span(), times, self._schedule_to_egg(schedule))
94
154
  case SequenceDecl(schedules):
95
- return bindings.Sequence(span(), [self.schedule_to_egg(s) for s in schedules])
96
- case RunDecl(ruleset_name, until):
97
- self.ruleset_to_egg(ruleset_name)
155
+ return bindings.Sequence(span(), [self._schedule_to_egg(s) for s in schedules])
156
+ case RunDecl(ruleset_name, until, scheduler):
157
+ if scheduler is not None:
158
+ raise ValueError(msg)
98
159
  config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
99
160
  return bindings.Run(span(), config)
161
+ case LetSchedulerDecl():
162
+ raise ValueError(msg)
163
+ case _:
164
+ assert_never(schedule)
165
+
166
+ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
167
+ self, schedule: ScheduleDecl, bound_schedulers: list[UUID]
168
+ ) -> list[bindings._Expr]:
169
+ """
170
+ Turns a scheduler into an egg expression, to be used with a custom extract command.
171
+
172
+ The bound_schedulers is a list of all the schedulers that have been bound. We can lookup their name as `_scheduler_{index}`.
173
+ """
174
+ match schedule:
175
+ case LetSchedulerDecl(BackOffDecl(id, match_limit, ban_length), inner):
176
+ name = f"_scheduler_{len(bound_schedulers)}"
177
+ bound_schedulers.append(id)
178
+ args: list[bindings._Expr] = []
179
+ if match_limit is not None:
180
+ args.append(bindings.Var(span(), ":match-limit"))
181
+ args.append(bindings.Lit(span(), bindings.Int(match_limit)))
182
+ if ban_length is not None:
183
+ args.append(bindings.Var(span(), ":ban-length"))
184
+ args.append(bindings.Lit(span(), bindings.Int(ban_length)))
185
+ back_off_decl = bindings.Call(span(), "back-off", args)
186
+ let_decl = bindings.Call(span(), "let-scheduler", [bindings.Var(span(), name), back_off_decl])
187
+ return [let_decl, *self._schedule_with_scheduler_to_egg(inner, bound_schedulers)]
188
+ case RunDecl(ruleset_name, until, scheduler):
189
+ args = [bindings.Var(span(), ruleset_name)]
190
+ if scheduler:
191
+ name = "run-with"
192
+ scheduler_name = f"_scheduler_{bound_schedulers.index(scheduler.id)}"
193
+ args.insert(0, bindings.Var(span(), scheduler_name))
194
+ else:
195
+ name = "run"
196
+ if until:
197
+ if len(until) > 1:
198
+ msg = "Can only have one until fact with custom scheduler"
199
+ raise ValueError(msg)
200
+ args.append(bindings.Var(span(), ":until"))
201
+ fact_egg = self.fact_to_egg(until[0])
202
+ if isinstance(fact_egg, bindings.Eq):
203
+ msg = "Cannot use equality fact with custom scheduler"
204
+ raise ValueError(msg)
205
+ args.append(fact_egg.expr)
206
+ return [bindings.Call(span(), name, args)]
207
+ case SaturateDecl(inner):
208
+ return [
209
+ bindings.Call(span(), "saturate", self._schedule_with_scheduler_to_egg(inner, bound_schedulers))
210
+ ]
211
+ case RepeatDecl(inner, times):
212
+ return [
213
+ bindings.Call(
214
+ span(),
215
+ "repeat",
216
+ [
217
+ bindings.Lit(span(), bindings.Int(times)),
218
+ *self._schedule_with_scheduler_to_egg(inner, bound_schedulers),
219
+ ],
220
+ )
221
+ ]
222
+ case SequenceDecl(schedules):
223
+ res = []
224
+ for s in schedules:
225
+ res.extend(self._schedule_with_scheduler_to_egg(s, bound_schedulers))
226
+ return res
100
227
  case _:
101
228
  assert_never(schedule)
102
229
 
@@ -212,9 +339,32 @@ class EGraphState:
212
339
  return bindings.Union(span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs))
213
340
  case PanicDecl(name):
214
341
  return bindings.Panic(span(), name)
342
+ case SetCostDecl(tp, expr, cost):
343
+ self.type_ref_to_egg(tp)
344
+ cost_table = self.create_cost_table(expr.callable)
345
+ args_egg = [self.typed_expr_to_egg(x, False) for x in expr.args]
346
+ return bindings.Set(span(), cost_table, args_egg, self._expr_to_egg(cost))
215
347
  case _:
216
348
  assert_never(action)
217
349
 
350
+ def create_cost_table(self, ref: CallableRef) -> str:
351
+ """
352
+ Creates the egg cost table if needed and gets the name of the table.
353
+ """
354
+ name = self.cost_table_name(ref)
355
+ if ref not in self.cost_callables:
356
+ self.cost_callables.add(ref)
357
+ signature = self.__egg_decls__.get_callable_decl(ref).signature
358
+ assert isinstance(signature, FunctionSignature), "Can only add cost tables for functions"
359
+ signature = replace(signature, return_type=TypeRefWithVars("i64"))
360
+ self.egraph.run_program(
361
+ bindings.FunctionCommand(span(), name, self._signature_to_egg_schema(signature), None)
362
+ )
363
+ return name
364
+
365
+ def cost_table_name(self, ref: CallableRef) -> str:
366
+ return f"cost_table_{self.callable_ref_to_egg(ref)[0]}"
367
+
218
368
  def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
219
369
  match fact:
220
370
  case EqDecl(tp, left, right):
@@ -225,7 +375,7 @@ class EGraphState:
225
375
  case _:
226
376
  assert_never(fact)
227
377
 
228
- def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
378
+ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: # noqa: C901, PLR0912
229
379
  """
230
380
  Returns the egg function name for a callable reference, registering it if it is not already registered.
231
381
 
@@ -245,9 +395,12 @@ class EGraphState:
245
395
  case ConstantDecl(tp, _):
246
396
  # Use constructor decleration instead of constant b/c constants cannot be extracted
247
397
  # https://github.com/egraphs-good/egglog/issues/334
248
- self.egraph.run_program(
249
- bindings.Constructor(span(), egg_name, bindings.Schema([], self.type_ref_to_egg(tp)), None, False)
250
- )
398
+ is_function = self.__egg_decls__._classes[tp.name].builtin
399
+ schema = bindings.Schema([], self.type_ref_to_egg(tp))
400
+ if is_function:
401
+ self.egraph.run_program(bindings.FunctionCommand(span(), egg_name, schema, None))
402
+ else:
403
+ self.egraph.run_program(bindings.Constructor(span(), egg_name, schema, None, False))
251
404
  case FunctionDecl(signature, builtin, _, merge):
252
405
  if isinstance(signature, FunctionSignature):
253
406
  reverse_args = signature.reverse_args
@@ -263,7 +416,7 @@ class EGraphState:
263
416
  self.egraph.run_program(bindings.Relation(span(), egg_name, schema.input))
264
417
  else:
265
418
  self.egraph.run_program(
266
- bindings.Function(
419
+ bindings.FunctionCommand(
267
420
  span(),
268
421
  egg_name,
269
422
  self._signature_to_egg_schema(signature),
@@ -347,13 +500,26 @@ class EGraphState:
347
500
  """
348
501
  Create a mapping of egglog function name to Python function name, for use in the serialized format
349
502
  for better visualization.
503
+
504
+ Also includes the cost tables, mapped to `cost(...)` of the callable they belong to.
350
505
  """
351
506
  return {
352
507
  k: pretty_callable_ref(self.__egg_decls__, next(iter(v)))
353
508
  for k, v in self.egg_fn_to_callable_refs.items()
354
509
  if len(v) == 1
510
+ } | {
511
+ self.cost_table_name(ref): f"cost({pretty_callable_ref(self.__egg_decls__, ref, include_all_args=True)})"
512
+ for ref in self.cost_callables
355
513
  }
356
514
 
515
+ def possible_egglog_functions(self, names: list[str]) -> Iterable[str]:
516
+ """
517
+ Given a list of egglog function names, yields the pretty-printed Python callable for each callable ref registered under those names.
518
+ """
519
+ for name in names:
520
+ for c in self.egg_fn_to_callable_refs[name]:
521
+ yield pretty_callable_ref(self.__egg_decls__, c)
522
+
357
523
  def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
358
524
  # transform all expressions with multiple parents into a let binding, so that less expressions
359
525
  # are sent to egglog. Only for performance reasons.
@@ -0,0 +1,64 @@
1
+ # mypy: disable-error-code="empty-body"
2
+
3
+ """
4
+ Join Tree (custom costs)
5
+ ========================
6
+
7
+ Example of using custom cost functions (via `set_cost`) so that extraction picks a cheap join order for a join tree.
8
+
9
+ From https://egraphs.zulipchat.com/#narrow/stream/328972-general/topic/How.20can.20I.20find.20the.20tree.20associated.20with.20an.20extraction.3F
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from egglog import *
15
+
16
+
17
+ class JoinTree(Expr):
18
+ def __init__(self, name: StringLike) -> None: ...
19
+
20
+ def join(self, other: JoinTree) -> JoinTree: ...
21
+
22
+ @method(merge=lambda old, new: old.min(new)) # type:ignore[prop-decorator]
23
+ @property
24
+ def size(self) -> i64: ...
25
+
26
+
27
+ ra = JoinTree("a")
28
+ rb = JoinTree("b")
29
+ rc = JoinTree("c")
30
+ rd = JoinTree("d")
31
+ re = JoinTree("e")
32
+ rf = JoinTree("f")
33
+
34
+ query = ra.join(rb).join(rc).join(rd).join(re).join(rf)
35
+
36
+ egraph = EGraph()
37
+ egraph.register(
38
+ set_(ra.size).to(50),
39
+ set_(rb.size).to(200),
40
+ set_(rc.size).to(10),
41
+ set_(rd.size).to(123),
42
+ set_(re.size).to(10000),
43
+ set_(rf.size).to(1),
44
+ )
45
+
46
+
47
+ @egraph.register
48
+ def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize: i64):
49
+ # cost of relation is its size minus 1, since the string arg will have a cost of 1 as well
50
+ yield rule(JoinTree(s).size == asize).then(set_cost(JoinTree(s), asize - 1))
51
+ # cost/size of join is product of sizes
52
+ yield rule(a.join(b), a.size == asize, b.size == bsize).then(
53
+ set_(a.join(b).size).to(asize * bsize), set_cost(a.join(b), asize * bsize)
54
+ )
55
+ # commutativity
56
+ yield rewrite(a.join(b)).to(b.join(a))
57
+ # associativity
58
+ yield rewrite(a.join(b).join(c)).to(a.join(b.join(c)))
59
+
60
+
61
+ egraph.register(query)
62
+ egraph.run(1000)
63
+ print(egraph.extract(query))
64
+ print(egraph.extract(query.size))
egglog/exp/array_api.py CHANGED
@@ -729,7 +729,7 @@ int64 = DType.int64
729
729
  _DTYPES = [float64, float32, int32, int64, DType.object]
730
730
 
731
731
  converter(type, DType, lambda x: convert(np.dtype(x), DType))
732
- converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload]
732
+ converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type:ignore[call-overload]
733
733
 
734
734
 
735
735
  @array_api_ruleset.register
egglog/pretty.py CHANGED
@@ -67,7 +67,9 @@ UNARY_METHODS = {
67
67
  "__invert__": "~",
68
68
  }
69
69
 
70
- AllDecls: TypeAlias = RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
70
+ AllDecls: TypeAlias = (
71
+ RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl | BackOffDecl
72
+ )
71
73
 
72
74
 
73
75
  def pretty_decl(
@@ -98,6 +100,7 @@ def pretty_callable_ref(
98
100
  ref: CallableRef,
99
101
  first_arg: ExprDecl | None = None,
100
102
  bound_tp_params: tuple[JustTypeRef, ...] | None = None,
103
+ include_all_args: bool = False,
101
104
  ) -> str:
102
105
  """
103
106
  Pretty print a callable reference, using a dummy value for
@@ -115,6 +118,13 @@ def pretty_callable_ref(
115
118
  # Either returns a function or a function with args. If args are provided, they would just be called,
116
119
  # on the function, so return them, because they are dummies
117
120
  if isinstance(res, tuple):
121
+ # To include all args as ARG_STR placeholders, we need to know how many arguments the function takes.
122
+ # This is used for set_cost so that `cost(E(...))` shows up as a call.
123
+ if include_all_args:
124
+ signature = decls.get_callable_decl(ref).signature
125
+ assert isinstance(signature, FunctionSignature)
126
+ correct_args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * len(signature.arg_types)
127
+ return f"{res[0]}({', '.join(context(a, parens=False, unwrap_lit=True) for a in correct_args)})"
118
128
  return res[0]
119
129
  return res
120
130
 
@@ -180,16 +190,27 @@ class TraverseContext:
180
190
  case _:
181
191
  for e in exprs:
182
192
  self(e.expr)
183
- case RunDecl(_, until):
193
+ case RunDecl(_, until, scheduler):
184
194
  if until:
185
195
  for f in until:
186
196
  self(f)
197
+ if scheduler:
198
+ self(scheduler)
187
199
  case PartialCallDecl(c):
188
200
  self(c)
189
201
  case CombinedRulesetDecl(_):
190
202
  pass
191
203
  case DefaultRewriteDecl():
192
204
  pass
205
+ case SetCostDecl(_, e, c):
206
+ self(e)
207
+ self(c)
208
+ case BackOffDecl():
209
+ pass
210
+ case LetSchedulerDecl(scheduler, schedule):
211
+ self(scheduler)
212
+ self(schedule)
213
+
193
214
  case _:
194
215
  assert_never(decl)
195
216
 
@@ -227,7 +248,11 @@ class PrettyContext:
227
248
  # it would take up is > than some constant (~ line length).
228
249
  line_diff: int = len(expr) - LINE_DIFFERENCE
229
250
  n_parents = self.parents[decl]
230
- if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH:
251
+ if n_parents > 1 and (
252
+ n_parents * line_diff > MAX_LINE_LENGTH
253
+ # Schedulers with multiple parents need to be the same object, because they are created with hidden UUIDs
254
+ or tp_name == "scheduler"
255
+ ):
231
256
  self.names[decl] = expr_name = self._name_expr(tp_name, expr, copy_identifier=False)
232
257
  return expr_name
233
258
  return expr
@@ -285,6 +310,8 @@ class PrettyContext:
285
310
  return f"{change}({self(expr)})", "action"
286
311
  case PanicDecl(s):
287
312
  return f"panic({s!r})", "action"
313
+ case SetCostDecl(_, expr, cost):
314
+ return f"set_cost({self(expr)}, {self(cost, unwrap_lit=True)})", "action"
288
315
  case EqDecl(_, left, right):
289
316
  return f"eq({self(left)}).to({self(right)})", "fact"
290
317
  case RulesetDecl(rules):
@@ -305,16 +332,27 @@ class PrettyContext:
305
332
  return f"{self(schedules[0], parens=True)} + {self(schedules[1], parens=True)}", "schedule"
306
333
  args = ", ".join(map(self, schedules))
307
334
  return f"seq({args})", "schedule"
308
- case RunDecl(ruleset_name, until):
335
+ case LetSchedulerDecl(scheduler, schedule):
336
+ return f"{self(scheduler, parens=True)}.scope({self(schedule, parens=True)})", "schedule"
337
+ case RunDecl(ruleset_name, until, scheduler):
309
338
  ruleset = self.decls._rulesets[ruleset_name]
310
339
  ruleset_str = self(ruleset, ruleset_name=ruleset_name)
311
- if not until:
340
+ if not until and not scheduler:
312
341
  return ruleset_str, "schedule"
313
- args = ", ".join(map(self, until))
314
- return f"run({ruleset_str}, {args})", "schedule"
342
+ arg_lst = list(map(self, until or []))
343
+ if scheduler:
344
+ arg_lst.append(f"scheduler={self(scheduler)}")
345
+ return f"run({ruleset_str}, {', '.join(arg_lst)})", "schedule"
315
346
  case DefaultRewriteDecl():
316
347
  msg = "default rewrites should not be pretty printed"
317
348
  raise TypeError(msg)
349
+ case BackOffDecl(_, match_limit, ban_length):
350
+ list_args: list[str] = []
351
+ if match_limit is not None:
352
+ list_args.append(f"match_limit={match_limit}")
353
+ if ban_length is not None:
354
+ list_args.append(f"ban_length={ban_length}")
355
+ return f"back_off({', '.join(list_args)})", "scheduler"
318
356
  assert_never(decl)
319
357
 
320
358
  def _call(
egglog/runtime.py CHANGED
@@ -20,6 +20,8 @@ from inspect import Parameter, Signature
20
20
  from itertools import zip_longest
21
21
  from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin
22
22
 
23
+ from typing_extensions import assert_never
24
+
23
25
  from .declarations import *
24
26
  from .pretty import *
25
27
  from .thunk import Thunk
@@ -36,6 +38,7 @@ __all__ = [
36
38
  "RuntimeClass",
37
39
  "RuntimeExpr",
38
40
  "RuntimeFunction",
41
+ "create_callable",
39
42
  "define_expr_method",
40
43
  "resolve_callable",
41
44
  "resolve_type_annotation",
@@ -340,7 +343,7 @@ class RuntimeClass(DelayedDeclerations, metaclass=ClassFactory):
340
343
 
341
344
  # Make hashable so can go in Union
342
345
  def __hash__(self) -> int:
343
- return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
346
+ return hash(self.__egg_tp__)
344
347
 
345
348
  def __eq__(self, other: object) -> bool:
346
349
  """
@@ -478,6 +481,9 @@ class RuntimeFunction(DelayedDeclerations):
478
481
  bound_tp_params = args
479
482
  return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
480
483
 
484
+ def __repr__(self) -> str:
485
+ return str(self)
486
+
481
487
 
482
488
  def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
483
489
  """
@@ -629,14 +635,22 @@ for name in TYPE_DEFINED_METHODS:
629
635
 
630
636
 
631
637
  for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
638
+ method_name = f"__r{name[2:]}" if r_method else name
632
639
 
633
- def _numeric_binary_method(self: object, other: object, name: str = name, r_method: bool = r_method) -> object:
640
+ def _numeric_binary_method(
641
+ self: object, other: object, name: str = name, r_method: bool = r_method, method_name: str = method_name
642
+ ) -> object:
634
643
  """
635
644
  Implements numeric binary operations.
636
645
 
637
646
  Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either
638
647
  the LHS or the RHS as exactly the right type and then upcasting the other to that type.
639
648
  """
649
+ # First check if we have a preserved method for this:
650
+ if isinstance(self, RuntimeExpr) and (
651
+ (preserved_method := self.__egg_class_decl__.preserved_methods.get(method_name)) is not None
652
+ ):
653
+ return preserved_method.__get__(self)(other)
640
654
  # 1. switch if reversed method
641
655
  if r_method:
642
656
  self, other = other, self
@@ -662,7 +676,6 @@ for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
662
676
  fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)
663
677
  return fn(other)
664
678
 
665
- method_name = f"__r{name[2:]}" if r_method else name
666
679
  setattr(RuntimeExpr, method_name, _numeric_binary_method)
667
680
 
668
681
 
@@ -670,16 +683,39 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
670
683
  """
671
684
  Resolves a runtime callable into a ref
672
685
  """
686
+ # TODO: Make runtime class work with __match_args__
687
+ if isinstance(callable, RuntimeClass):
688
+ return InitRef(callable.__egg_tp__.name), callable.__egg_decls__
673
689
  match callable:
674
690
  case RuntimeFunction(decls, ref, _):
675
691
  return ref(), decls()
676
- case RuntimeClass(thunk, tp):
677
- return InitRef(tp.name), thunk()
678
692
  case RuntimeExpr(decl_thunk, expr_thunk):
679
693
  if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(
680
694
  expr.callable, ConstantRef | ClassVariableRef
681
695
  ):
682
696
  raise NotImplementedError(f"Can only turn constants or classvars into callable refs, not {expr}")
683
697
  return expr.callable, decl_thunk()
698
+ case types.MethodWrapperType() if isinstance((slf := callable.__self__), RuntimeClass):
699
+ return MethodRef(slf.__egg_tp__.name, callable.__name__), slf.__egg_decls__
684
700
  case _:
685
701
  raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
702
+
703
+
704
+ def create_callable(decls: Declarations, ref: CallableRef) -> RuntimeClass | RuntimeFunction | RuntimeExpr:
705
+ """
706
+ Creates a callable object from a callable ref. The result might not actually be callable: if the ref is a constant
707
+ or classvar, then it is a value (a RuntimeExpr).
708
+ """
709
+ match ref:
710
+ case InitRef(name):
711
+ return RuntimeClass(Thunk.value(decls), TypeRefWithVars(name))
712
+ case FunctionRef() | MethodRef() | ClassMethodRef() | PropertyRef() | UnnamedFunctionRef():
713
+ bound = JustTypeRef(ref.class_name) if isinstance(ref, ClassMethodRef) else None
714
+ return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), bound)
715
+ case ConstantRef(name):
716
+ tp = decls._constants[name].type_ref
717
+ case ClassVariableRef(cls_name, var_name):
718
+ tp = decls._classes[cls_name].class_variables[var_name].type_ref
719
+ case _:
720
+ assert_never(ref)
721
+ return RuntimeExpr.__from_values__(decls, TypedExprDecl(tp, CallDecl(ref)))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: egglog
3
- Version: 11.1.0
3
+ Version: 11.3.0
4
4
  Classifier: Environment :: MacOS X
5
5
  Classifier: Environment :: Win32 (MS Windows)
6
6
  Classifier: Intended Audience :: Developers
@@ -51,6 +51,7 @@ Requires-Dist: egglog[array] ; extra == 'docs'
51
51
  Requires-Dist: line-profiler ; extra == 'docs'
52
52
  Requires-Dist: sphinxcontrib-mermaid ; extra == 'docs'
53
53
  Requires-Dist: ablog ; extra == 'docs'
54
+ Requires-Dist: jupytext ; extra == 'docs'
54
55
  Provides-Extra: array
55
56
  Provides-Extra: dev
56
57
  Provides-Extra: test
@@ -72,3 +73,20 @@ Please see the [documentation](https://egglog-python.readthedocs.io/) for more i
72
73
 
73
74
  Come say hello [on the e-graphs Zulip](https://egraphs.zulipchat.com/#narrow/stream/375765-egglog/) or [open an issue](https://github.com/egraphs-good/egglog-python/issues/new/choose)!
74
75
 
76
+ ## How to cite
77
+
78
+ If you use **egglog-python** in academic work, please cite the paper:
79
+
80
+ ```bibtex
81
+ @misc{Shanabrook2023EgglogPython,
82
+ title = {Egglog Python: A Pythonic Library for E-graphs},
83
+ author = {Saul Shanabrook},
84
+ year = {2023},
85
+ eprint = {2305.04311},
86
+ archivePrefix = {arXiv},
87
+ primaryClass = {cs.PL},
88
+ doi = {10.48550/arXiv.2305.04311},
89
+ url = {https://arxiv.org/abs/2305.04311},
90
+ note = {Presented at EGRAPHS@PLDI 2023}
91
+ }
92
+ ```
@@ -1,16 +1,16 @@
1
- egglog-11.1.0.dist-info/METADATA,sha256=u53IcPcs6w0Rry3b4tw2IDFN0qwIOIgbGmj_3Op0FPo,4009
2
- egglog-11.1.0.dist-info/WHEEL,sha256=Ez8BfzzV1AeIEg8pRvkaJ17Xz6wwwvCZlob2fTcZ8v0,127
3
- egglog-11.1.0.dist-info/licenses/LICENSE,sha256=w7VlVv5O_FPZRo8Z-4Zb_q7D5ac3YDs8JUkMZ4Gq9CE,1070
1
+ egglog-11.3.0.dist-info/METADATA,sha256=_7z3OUAk7vWdO7bwhHHUwhirFqvbHsguTnkGY9qvidQ,4554
2
+ egglog-11.3.0.dist-info/WHEEL,sha256=2zSSD3q1g7ifuISAz_to-1QnVz3YSdCfiH5A6NWI9so,127
3
+ egglog-11.3.0.dist-info/licenses/LICENSE,sha256=w7VlVv5O_FPZRo8Z-4Zb_q7D5ac3YDs8JUkMZ4Gq9CE,1070
4
4
  egglog/__init__.py,sha256=0r3MzQbU-9U0fSCeAoJ3deVhZ77tI-1tf8A_WFOhbJs,344
5
- egglog/bindings.cpython-312-powerpc64-linux-gnu.so,sha256=TiW9sRvhUe13SijXYNrhC5CgDSZMYHcF4VC5VqV4lzc,204517056
6
- egglog/bindings.pyi,sha256=QDgs0hKl2xl4Q6f7hXjfDANoMADRrGTYTdGYdri7Vqg,13642
7
- egglog/builtins.py,sha256=bF3GanSmsrIIutvuMUmr4SnSNF0zAEwjmDYWbPEEayo,30204
5
+ egglog/bindings.cpython-312-powerpc64-linux-gnu.so,sha256=X3bwWPHcfAyJa5AZwD2SFOYcc4zkrE75pbn8j4kj7_o,170295976
6
+ egglog/bindings.pyi,sha256=Y_YpdAKmVHZ0nIHTTPeg0sigBEPiS8z_U-Z161zKSK4,15330
7
+ egglog/builtins.py,sha256=qXbBOtT1qwgR9uQu9yb1gUp4dm2L6BgvJIWYU4zCzuw,30317
8
8
  egglog/config.py,sha256=yM3FIcVCKnhWZmHD0pxkzx7ah7t5PxZx3WUqKtA9tjU,168
9
9
  egglog/conversion.py,sha256=DO76lxRbbTqHs6hRo_Lckvtwu0c6LaKoX7k5_B2AfuY,11238
10
- egglog/declarations.py,sha256=CA8gWolJChVt6-xDL7XvV5ORjNj3O8PREobYCzVFxDI,25784
10
+ egglog/declarations.py,sha256=pc2KEYwyKNQXuKndbBCC6iuVROgHkaSKJJf_s9liZi8,26260
11
11
  egglog/deconstruct.py,sha256=CovORrpROMIwOLgERPUw8doqRUDUehj6LJEB5FMbpmI,5635
12
- egglog/egraph.py,sha256=niBx_EAbQhzjupJs_wZ9oIbjF7W_ttx9D0u1iB059HM,60601
13
- egglog/egraph_state.py,sha256=8-u0BLAHlWgoo_KiyW_aqC6nj_X0XgMFRlXWIyPjF64,28356
12
+ egglog/egraph.py,sha256=zJpAoC6JXXqnRsp24CvQN5M5EZ0PrOj93R9U4w6bqlw,65417
13
+ egglog/egraph_state.py,sha256=3VLwkAsR3oCydHLx_BXmFw4UHXgdZ9jooQdWUcQeUD0,36375
14
14
  egglog/examples/README.rst,sha256=ztTvpofR0eotSqGoCy_C1fPLDPCncjvcqDanXtLHNNU,232
15
15
  egglog/examples/__init__.py,sha256=wm9evUbMPfbtylXIjbDdRTAVMLH4OjT4Z77PCBFyaPU,31
16
16
  egglog/examples/bignum.py,sha256=jfL57XXpQqIqizQQ3sSUCCjTrkdjtB71BmjrQIQorQk,535
@@ -18,6 +18,7 @@ egglog/examples/bool.py,sha256=e0z2YoYJsLlhpSADZK1yRYHzilyxSZWGiYAaM0DQ_Gw,695
18
18
  egglog/examples/eqsat_basic.py,sha256=2xtM81gG9Br72mr58N-2BUeksR7C_UXnZJ4MvzSPplc,869
19
19
  egglog/examples/fib.py,sha256=BOHxKWA7jGx4FURBmfmuZKfLo6xq9-uXAwAXjYid7LU,492
20
20
  egglog/examples/higher_order_functions.py,sha256=DNLIQfPJCX_DOLbHNiaYsfvcFIYCYOsRUqp99r9bpc8,1063
21
+ egglog/examples/jointree.py,sha256=TLlH7TsQzWfadqDo7qeTprFhLdQmj59AQaGse81RIKk,1714
21
22
  egglog/examples/lambda_.py,sha256=iQvwaXVhp2VNOMS7j1WwceZaiq3dqqilwUkMcW5GFBE,8194
22
23
  egglog/examples/matrix.py,sha256=7_mPcMcgE-t_GJDyf76-nv3xhPIeN2mvFkc_p_Gnr8g,4961
23
24
  egglog/examples/multiset.py,sha256=IBOsB80DkXQ07dYnk1odi27q77LH80Z8zysuLE-Q8F8,1445
@@ -25,7 +26,7 @@ egglog/examples/ndarrays.py,sha256=mfr410eletH8gfdg-P8L90vlF6TUifvYV_-ryOwvZZE,4
25
26
  egglog/examples/resolution.py,sha256=BJd5JClA3DBVGfiVRa-H0gbbFvIqeP3uYbhCXHblSQc,2119
26
27
  egglog/examples/schedule_demo.py,sha256=JbXdPII7_adxtgyKVAiqCyV2sj88VZ-DhomYrdn8vuc,618
27
28
  egglog/exp/__init__.py,sha256=nPtzrH1bz1LVZhZCuS0S9Qild8m5gEikjOVqWAFIa88,49
28
- egglog/exp/array_api.py,sha256=B4tr5DyC8jkKxdax7k98LK-iXOrsBceSaijwYTEnG-Q,65548
29
+ egglog/exp/array_api.py,sha256=dKgEufUIyoT7J_RvnyGtOkg_DK25ZnxIgt7olVygaH8,65547
29
30
  egglog/exp/array_api_jit.py,sha256=Ak4QhmfYLKimjPf8ffUvPv62OhxOneJ9NEWQJuMxKJc,1680
30
31
  egglog/exp/array_api_loopnest.py,sha256=-kbyorlGxvlaNsLx1nmLfEZHQM7VMEBwSKtV0l-bs0g,2444
31
32
  egglog/exp/array_api_numba.py,sha256=X3H1TnCjPL92uVm6OvcWMJ11IeorAE58zWiOX6huPv4,2696
@@ -33,13 +34,13 @@ egglog/exp/array_api_program_gen.py,sha256=qnve8iqklRQVyGChllG8ZAjAffRpezmdxc3Id
33
34
  egglog/exp/program_gen.py,sha256=CavsD70x0ERS87V4OU9zkgMvLXswGEpb1ZZFK0WyN_g,13033
34
35
  egglog/exp/siu_examples.py,sha256=yZ-sgH2Y12iTdwBUumP7D2OtCGL83M6pPW7PMobVFXc,719
35
36
  egglog/ipython_magic.py,sha256=2hs3g2cSiyDmbCvE2t1OINmu17Bb8MWV--2DpEWwO7I,1189
36
- egglog/pretty.py,sha256=TCkwZIecdFT5eMIy0GGAZuKvPUt0fQeh5K1srgnCJLo,20136
37
+ egglog/pretty.py,sha256=Sv3H9e0CJcZv3-ylijP58ApCQ5w1BOdXl2VDw6Hst4Y,22061
37
38
  egglog/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
- egglog/runtime.py,sha256=8y1YQ5vFuoHr9B8LGTeQBHZin7S8LYHlsR93LlNl2uo,27979
39
+ egglog/runtime.py,sha256=NUA0O-_urneP54RqXRcPLQIlFzNwPacPKIMxGpwAkus,29672
39
40
  egglog/thunk.py,sha256=MrAlPoGK36VQrUrq8PWSaJFu42sPL0yupwiH18lNips,2271
40
41
  egglog/type_constraint_solver.py,sha256=U2GjLgbebTLv5QY8_TU0As5wMKL5_NxkHLen9rpfMwI,4518
41
42
  egglog/version_compat.py,sha256=EaKRMIOPcatrx9XjCofxZD6Nr5WOooiWNdoapkKleww,3512
42
43
  egglog/visualizer.css,sha256=eL0POoThQRc0P4OYnDT-d808ln9O5Qy6DizH9Z5LgWc,259398
43
44
  egglog/visualizer.js,sha256=2qZZ-9W_INJx4gZMYjnVXl27IjT_JNuQyEeI2dbjWoU,3753315
44
45
  egglog/visualizer_widget.py,sha256=LtVfzOtv2WeKtNuILQQ_9SOHWvRr8YdBYQDKQSgry_s,1319
45
- egglog-11.1.0.dist-info/RECORD,,
46
+ egglog-11.3.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: maturin (1.9.3)
2
+ Generator: maturin (1.9.4)
3
3
  Root-Is-Purelib: false
4
4
  Tag: cp312-cp312-manylinux_2_17_ppc64.manylinux2014_ppc64