egglog 8.0.0__cp311-none-win_amd64.whl → 8.0.1__cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Binary file
egglog/bindings.pyi CHANGED
@@ -5,8 +5,6 @@ from typing import TypeAlias
5
5
 
6
6
  from typing_extensions import final
7
7
 
8
- HIGH_COST: int
9
-
10
8
  @final
11
9
  class SerializedEGraph:
12
10
  def inline_leaves(self) -> None: ...
@@ -14,13 +12,14 @@ class SerializedEGraph:
14
12
  def to_dot(self) -> str: ...
15
13
  def to_json(self) -> str: ...
16
14
  def map_ops(self, map: dict[str, str]) -> None: ...
17
- def split_e_classes(self, egraph: EGraph, ops: set[str]) -> None: ...
15
+ def split_classes(self, egraph: EGraph, ops: set[str]) -> None: ...
18
16
 
19
17
  @final
20
18
  class PyObjectSort:
21
19
  def __init__(self) -> None: ...
22
20
  def store(self, __o: object, /) -> _Expr: ...
23
21
 
22
+ def parse_program(__input: str, /, filename: str | None = None) -> list[_Command]: ...
24
23
  @final
25
24
  class EGraph:
26
25
  def __init__(
@@ -29,11 +28,9 @@ class EGraph:
29
28
  *,
30
29
  fact_directory: str | Path | None = None,
31
30
  seminaive: bool = True,
32
- terms_encoding: bool = False,
33
31
  record: bool = False,
34
32
  ) -> None: ...
35
33
  def commands(self) -> str | None: ...
36
- def parse_program(self, __input: str, /, filename: str | None = None) -> list[_Command]: ...
37
34
  def run_program(self, *commands: _Command) -> list[str]: ...
38
35
  def extract_report(self) -> _ExtractReport | None: ...
39
36
  def run_report(self) -> RunReport | None: ...
@@ -250,6 +247,7 @@ _Action: TypeAlias = Let | Set | Change | Union | Panic | Expr_ | Extract
250
247
 
251
248
  @final
252
249
  class FunctionDecl:
250
+ span: Span
253
251
  name: str
254
252
  schema: Schema
255
253
  default: _Expr | None
@@ -261,6 +259,7 @@ class FunctionDecl:
261
259
 
262
260
  def __init__(
263
261
  self,
262
+ span: Span,
264
263
  name: str,
265
264
  schema: Schema,
266
265
  default: _Expr | None = None,
@@ -273,7 +272,8 @@ class FunctionDecl:
273
272
 
274
273
  @final
275
274
  class Variant:
276
- def __init__(self, name: str, types: list[str], cost: int | None = None) -> None: ...
275
+ def __init__(self, span: Span, name: str, types: list[str], cost: int | None = None) -> None: ...
276
+ span: Span
277
277
  name: str
278
278
  types: list[str]
279
279
  cost: int | None
@@ -379,6 +379,23 @@ class Sequence:
379
379
 
380
380
  _Schedule: TypeAlias = Saturate | Repeat | Run | Sequence
381
381
 
382
+ ##
383
+ # Subdatatypes
384
+ ##
385
+
386
+ @final
387
+ class SubVariants:
388
+ def __init__(self, variants: list[Variant]) -> None: ...
389
+ variants: list[Variant]
390
+
391
+ @final
392
+ class NewSort:
393
+ def __init__(self, name: str, args: list[_Expr]) -> None: ...
394
+ name: str
395
+ args: list[_Expr]
396
+
397
+ _Subdatatypes: TypeAlias = SubVariants | NewSort
398
+
382
399
  ##
383
400
  # Commands
384
401
  ##
@@ -391,22 +408,23 @@ class SetOption:
391
408
 
392
409
  @final
393
410
  class Datatype:
411
+ span: Span
394
412
  name: str
395
413
  variants: list[Variant]
396
- def __init__(self, name: str, variants: list[Variant]) -> None: ...
414
+ def __init__(self, span: Span, name: str, variants: list[Variant]) -> None: ...
397
415
 
398
416
  @final
399
- class Declare:
417
+ class Datatypes:
400
418
  span: Span
401
- name: str
402
- sort: str
403
- def __init__(self, span: Span, name: str, sort: str) -> None: ...
419
+ datatypes: list[tuple[Span, str, _Subdatatypes]]
420
+ def __init__(self, span: Span, datatypes: list[tuple[Span, str, _Subdatatypes]]) -> None: ...
404
421
 
405
422
  @final
406
423
  class Sort:
424
+ span: Span
407
425
  name: str
408
426
  presort_and_args: tuple[str, list[_Expr]] | None
409
- def __init__(self, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
427
+ def __init__(self, span: Span, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
410
428
 
411
429
  @final
412
430
  class Function:
@@ -452,22 +470,17 @@ class RunSchedule:
452
470
 
453
471
  @final
454
472
  class Simplify:
473
+ span: Span
455
474
  expr: _Expr
456
475
  schedule: _Schedule
457
- def __init__(self, expr: _Expr, schedule: _Schedule) -> None: ...
458
-
459
- @final
460
- class Calc:
461
- span: Span
462
- identifiers: list[IdentSort]
463
- exprs: list[_Expr]
464
- def __init__(self, span: Span, identifiers: list[IdentSort], exprs: list[_Expr]) -> None: ...
476
+ def __init__(self, span: Span, expr: _Expr, schedule: _Schedule) -> None: ...
465
477
 
466
478
  @final
467
479
  class QueryExtract:
480
+ span: Span
468
481
  variants: int
469
482
  expr: _Expr
470
- def __init__(self, variants: int, expr: _Expr) -> None: ...
483
+ def __init__(self, span: Span, variants: int, expr: _Expr) -> None: ...
471
484
 
472
485
  @final
473
486
  class Check:
@@ -484,20 +497,23 @@ class PrintFunction:
484
497
 
485
498
  @final
486
499
  class PrintSize:
500
+ span: Span
487
501
  name: str | None
488
- def __init__(self, name: str | None) -> None: ...
502
+ def __init__(self, span: Span, name: str | None) -> None: ...
489
503
 
490
504
  @final
491
505
  class Output:
506
+ span: Span
492
507
  file: str
493
508
  exprs: list[_Expr]
494
- def __init__(self, file: str, exprs: list[_Expr]) -> None: ...
509
+ def __init__(self, span: Span, file: str, exprs: list[_Expr]) -> None: ...
495
510
 
496
511
  @final
497
512
  class Input:
513
+ span: Span
498
514
  name: str
499
515
  file: str
500
- def __init__(self, name: str, file: str) -> None: ...
516
+ def __init__(self, span: Span, name: str, file: str) -> None: ...
501
517
 
502
518
  @final
503
519
  class Push:
@@ -506,29 +522,29 @@ class Push:
506
522
 
507
523
  @final
508
524
  class Pop:
525
+ span: Span
509
526
  length: int
510
- def __init__(self, length: int) -> None: ...
527
+ def __init__(self, span: Span, length: int) -> None: ...
511
528
 
512
529
  @final
513
530
  class Fail:
531
+ span: Span
514
532
  command: _Command
515
- def __init__(self, command: _Command) -> None: ...
533
+ def __init__(self, span: Span, command: _Command) -> None: ...
516
534
 
517
535
  @final
518
536
  class Include:
537
+ span: Span
519
538
  path: str
520
- def __init__(self, path: str) -> None: ...
521
-
522
- @final
523
- class CheckProof:
524
- def __init__(self) -> None: ...
539
+ def __init__(self, span: Span, path: str) -> None: ...
525
540
 
526
541
  @final
527
542
  class Relation:
543
+ span: Span
528
544
  constructor: str
529
545
  inputs: list[str]
530
546
 
531
- def __init__(self, constructor: str, inputs: list[str]) -> None: ...
547
+ def __init__(self, span: Span, constructor: str, inputs: list[str]) -> None: ...
532
548
 
533
549
  @final
534
550
  class PrintOverallStatistics:
@@ -543,7 +559,7 @@ class UnstableCombinedRuleset:
543
559
  _Command: TypeAlias = (
544
560
  SetOption
545
561
  | Datatype
546
- | Declare
562
+ | Datatypes
547
563
  | Sort
548
564
  | Function
549
565
  | AddRuleset
@@ -552,7 +568,6 @@ _Command: TypeAlias = (
552
568
  | BiRewriteCommand
553
569
  | ActionCommand
554
570
  | RunSchedule
555
- | Calc
556
571
  | Simplify
557
572
  | QueryExtract
558
573
  | Check
@@ -564,7 +579,6 @@ _Command: TypeAlias = (
564
579
  | Pop
565
580
  | Fail
566
581
  | Include
567
- | CheckProof
568
582
  | Relation
569
583
  | PrintOverallStatistics
570
584
  | UnstableCombinedRuleset
egglog/conversion.py CHANGED
@@ -195,6 +195,6 @@ def _get_tp(x: object) -> TypeName | type:
195
195
  return TypeName(x.__egg_typed_expr__.tp.name)
196
196
  tp = type(x)
197
197
  # If this value has a custom metaclass, let's use that as our index instead of the type
198
- if type(tp) != type:
198
+ if type(tp) is not type:
199
199
  return type(tp)
200
200
  return tp
egglog/declarations.py CHANGED
@@ -622,9 +622,8 @@ def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExp
622
622
  res = replacements[typed_expr]
623
623
  else:
624
624
  match typed_expr.expr:
625
- case (
625
+ case CallDecl(callable, args, bound_tp_params) | PartialCallDecl(
626
626
  CallDecl(callable, args, bound_tp_params)
627
- | PartialCallDecl(CallDecl(callable, args, bound_tp_params))
628
627
  ):
629
628
  new_args = tuple(_inner(a) for a in args)
630
629
  call_decl = CallDecl(callable, new_args, bound_tp_params)
egglog/egraph.py CHANGED
@@ -486,7 +486,7 @@ class _ExprMetaclass(type):
486
486
  return isinstance(instance, RuntimeExpr)
487
487
 
488
488
 
489
- def _generate_class_decls( # noqa: C901
489
+ def _generate_class_decls( # noqa: C901,PLR0912
490
490
  namespace: dict[str, Any],
491
491
  frame: FrameType,
492
492
  builtin: bool,
@@ -704,8 +704,12 @@ def _fn_decl(
704
704
  # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
705
705
  if "Callable" not in hint_globals:
706
706
  hint_globals["Callable"] = Callable
707
-
708
- hints = get_type_hints(fn, hint_globals, hint_locals)
707
+ # Instead of passing both globals and locals, just pass the globals. Otherwise, for some reason forward references
708
+ # won't be resolved correctly
709
+ # We need this to be false so it returns "__forward_value__" https://github.com/python/cpython/blob/440ed18e08887b958ad50db1b823e692a747b671/Lib/typing.py#L919
710
+ # https://github.com/egraphs-good/egglog-python/issues/210
711
+ hint_globals.update(hint_locals)
712
+ hints = get_type_hints(fn, hint_globals)
709
713
 
710
714
  params = list(signature(fn).parameters.values())
711
715
 
@@ -1021,7 +1025,7 @@ class EGraph(_BaseModule):
1021
1025
  """
1022
1026
  Loads a CSV file and sets it as *input, output of the function.
1023
1027
  """
1024
- self._egraph.run_program(bindings.Input(self._callable_to_egg(fn), path))
1028
+ self._egraph.run_program(bindings.Input(bindings.DUMMY_SPAN, self._callable_to_egg(fn), path))
1025
1029
 
1026
1030
  def _callable_to_egg(self, fn: object) -> str:
1027
1031
  ref, decls = resolve_callable(fn)
@@ -1063,7 +1067,7 @@ class EGraph(_BaseModule):
1063
1067
  typed_expr = runtime_expr.__egg_typed_expr__
1064
1068
  # Must also register type
1065
1069
  egg_expr = self._state.typed_expr_to_egg(typed_expr)
1066
- self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
1070
+ self._egraph.run_program(bindings.Simplify(bindings.DUMMY_SPAN, egg_expr, egg_schedule))
1067
1071
  extract_report = self._egraph.extract_report()
1068
1072
  if not isinstance(extract_report, bindings.Best):
1069
1073
  msg = "No extract report saved"
@@ -1118,7 +1122,7 @@ class EGraph(_BaseModule):
1118
1122
  """
1119
1123
  Checks that one of the facts is not true
1120
1124
  """
1121
- self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
1125
+ self._egraph.run_program(bindings.Fail(bindings.DUMMY_SPAN, self._facts_to_check(facts)))
1122
1126
 
1123
1127
  def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
1124
1128
  facts = _fact_likes(fact_likes)
@@ -1191,7 +1195,7 @@ class EGraph(_BaseModule):
1191
1195
  """
1192
1196
  Pop the current state of the egraph, reverting back to the previous state.
1193
1197
  """
1194
- self._egraph.run_program(bindings.Pop(1))
1198
+ self._egraph.run_program(bindings.Pop(bindings.DUMMY_SPAN, 1))
1195
1199
  self._state = self._state_stack.pop()
1196
1200
 
1197
1201
  def __enter__(self) -> Self:
@@ -1204,7 +1208,7 @@ class EGraph(_BaseModule):
1204
1208
  self.push()
1205
1209
  return self
1206
1210
 
1207
- def __exit__(self, exc_type, exc, exc_tb) -> None: # noqa: ANN001
1211
+ def __exit__(self, exc_type, exc, exc_tb) -> None:
1208
1212
  CURRENT_EGRAPH.reset(self._token_stack.pop())
1209
1213
  self.pop()
1210
1214
 
@@ -1262,7 +1266,7 @@ class EGraph(_BaseModule):
1262
1266
  )
1263
1267
  if split_primitive_outputs or split_functions:
1264
1268
  additional_ops = set(map(self._callable_to_egg, split_functions))
1265
- serialized.split_e_classes(self._egraph, additional_ops)
1269
+ serialized.split_classes(self._egraph, additional_ops)
1266
1270
  serialized.map_ops(self._state.op_mapping())
1267
1271
 
1268
1272
  for _ in range(n_inline_leaves):
@@ -1322,7 +1326,14 @@ class EGraph(_BaseModule):
1322
1326
  serialized = self._serialize(**kwargs)
1323
1327
  VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()
1324
1328
 
1325
- def saturate(self, schedule: Schedule | None = None, *, max: int = 1000, **kwargs: Unpack[GraphvizKwargs]) -> None:
1329
+ def saturate(
1330
+ self,
1331
+ schedule: Schedule | None = None,
1332
+ *,
1333
+ expr: Expr | None = None,
1334
+ max: int = 1000,
1335
+ **kwargs: Unpack[GraphvizKwargs],
1336
+ ) -> None:
1326
1337
  """
1327
1338
  Saturate the egraph, running the given schedule until the egraph is saturated.
1328
1339
  It serializes the egraph at each step and returns a widget to visualize the egraph.
@@ -1330,6 +1341,8 @@ class EGraph(_BaseModule):
1330
1341
  from .visualizer_widget import VisualizerWidget
1331
1342
 
1332
1343
  def to_json() -> str:
1344
+ if expr:
1345
+ print(self.extract(expr))
1333
1346
  return self._serialize(**kwargs).to_json()
1334
1347
 
1335
1348
  egraphs = [to_json()]
egglog/egraph_state.py CHANGED
@@ -206,18 +206,25 @@ class EGraphState:
206
206
  self.egg_fn_to_callable_refs[egg_name].add(ref)
207
207
  match decl:
208
208
  case RelationDecl(arg_types, _, _):
209
- self.egraph.run_program(bindings.Relation(egg_name, [self.type_ref_to_egg(a) for a in arg_types]))
209
+ self.egraph.run_program(
210
+ bindings.Relation(bindings.DUMMY_SPAN, egg_name, [self.type_ref_to_egg(a) for a in arg_types])
211
+ )
210
212
  case ConstantDecl(tp, _):
211
213
  # Use function decleration instead of constant b/c constants cannot be extracted
212
214
  # https://github.com/egraphs-good/egglog/issues/334
213
215
  self.egraph.run_program(
214
- bindings.Function(bindings.FunctionDecl(egg_name, bindings.Schema([], self.type_ref_to_egg(tp))))
216
+ bindings.Function(
217
+ bindings.FunctionDecl(
218
+ bindings.DUMMY_SPAN, egg_name, bindings.Schema([], self.type_ref_to_egg(tp))
219
+ )
220
+ )
215
221
  )
216
222
  case FunctionDecl():
217
223
  if not decl.builtin:
218
224
  signature = decl.signature
219
225
  assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
220
226
  egg_fn_decl = bindings.FunctionDecl(
227
+ bindings.DUMMY_SPAN,
221
228
  egg_name,
222
229
  bindings.Schema(
223
230
  [self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
@@ -261,7 +268,7 @@ class EGraphState:
261
268
  args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
262
269
  else:
263
270
  args = None
264
- self.egraph.run_program(bindings.Sort(egg_name, args))
271
+ self.egraph.run_program(bindings.Sort(bindings.DUMMY_SPAN, egg_name, args))
265
272
  # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
266
273
  # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
267
274
  # even if you never use that function.
@@ -323,7 +330,7 @@ class EGraphState:
323
330
  @overload
324
331
  def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
325
332
 
326
- def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
333
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912,C901
327
334
  """
328
335
  Convert an ExprDecl to an egg expression.
329
336
  """
@@ -486,7 +493,7 @@ class FromEggState:
486
493
  if term.name == "py-object":
487
494
  call = bindings.termdag_term_to_expr(self.termdag, term)
488
495
  expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
489
- if term.name == "unstable-fn":
496
+ elif term.name == "unstable-fn":
490
497
  # Get function name
491
498
  fn_term, *arg_terms = term.args
492
499
  fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
@@ -26,7 +26,7 @@ class MathList(Expr):
26
26
 
27
27
 
28
28
  @ruleset
29
- def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]): # noqa: ANN201
29
+ def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
30
30
  yield rewrite(Math(i) + Math(j)).to(Math(i + j))
31
31
  yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
32
32
  yield rewrite(MathList().map(f)).to(MathList())
egglog/exp/array_api.py CHANGED
@@ -1538,20 +1538,23 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool:
1538
1538
  egraph.register(expr)
1539
1539
  egraph.run(array_api_schedule)
1540
1540
  try:
1541
- return egraph.eval(prim_expr)
1541
+ extracted = egraph.extract(prim_expr)
1542
1542
  except EggSmolError as exc:
1543
- egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1543
+ # Try giving some context, by showing the smallest version of the larger expression
1544
1544
  try:
1545
- msg = f"Cannot simplify to primitive {egraph.extract(expr)}"
1546
- except EggSmolError:
1547
- msg = f"Cannot simplify to primitive or extract {expr}"
1548
-
1549
- # string = (
1550
- # egraph.as_egglog_string
1551
- # + "\n"
1552
- # + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
1553
- # )
1554
- # # save to "tmp.egg"
1555
- # with open("tmp.egg", "w") as f:
1556
- # f.write(string)
1545
+ expr_extracted = egraph.extract(expr)
1546
+ except EggSmolError as inner_exc:
1547
+ raise ValueError(f"Cannot simplify {expr}") from inner_exc
1548
+ egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1549
+ msg = f"Cannot simplify to primitive {expr_extracted}"
1557
1550
  raise ValueError(msg) from exc
1551
+ return egraph.eval(extracted)
1552
+
1553
+ # string = (
1554
+ # egraph.as_egglog_string
1555
+ # + "\n"
1556
+ # + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
1557
+ # )
1558
+ # # save to "tmp.egg"
1559
+ # with open("tmp.egg", "w") as f:
1560
+ # f.write(string)
@@ -17,18 +17,17 @@ def jit(fn: X) -> X:
17
17
  # 1. Create variables for each of the two args in the functions
18
18
  sig = inspect.signature(fn)
19
19
  arg1, arg2 = sig.parameters.keys()
20
-
21
- with EGraph() as egraph:
20
+ egraph = EGraph()
21
+ with egraph:
22
22
  res = fn(NDArray.var(arg1), NDArray.var(arg2))
23
23
  egraph.register(res)
24
24
  egraph.run(array_api_numba_schedule)
25
25
  res_optimized = egraph.extract(res)
26
- egraph.display(split_primitive_outputs=True, n_inline_leaves=3)
26
+ # egraph.display(split_primitive_outputs=True, n_inline_leaves=3)
27
27
 
28
- egraph = EGraph()
29
28
  fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
30
29
  egraph.register(fn_program)
31
30
  egraph.run(array_api_program_gen_schedule)
32
- fn = cast(X, egraph.eval(fn_program.py_object))
31
+ fn = cast(X, egraph.eval(egraph.extract(fn_program.py_object)))
33
32
  fn.expr = res_optimized # type: ignore[attr-defined]
34
33
  return fn
@@ -28,7 +28,7 @@ class ShapeAPI(Expr):
28
28
 
29
29
 
30
30
  @array_api_ruleset.register
31
- def shape_api_ruleset(dims: TupleInt, axis: TupleInt): # noqa: ANN201
31
+ def shape_api_ruleset(dims: TupleInt, axis: TupleInt):
32
32
  s = ShapeAPI(dims)
33
33
  yield rewrite(s.deselect(axis)).to(
34
34
  ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
@@ -108,7 +108,7 @@ def _loopnest_api_ruleset(
108
108
 
109
109
 
110
110
  @function(ruleset=array_api_ruleset, unextractable=True)
111
- def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray: # noqa: N803
111
+ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
112
112
  # peel off the outer shape for result array
113
113
  outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
114
114
  # get only the inner shape for reduction
@@ -126,20 +126,24 @@ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray: # noqa: N803
126
126
 
127
127
 
128
128
  # %%
129
- egraph = EGraph(save_egglog_string=True)
129
+ # egraph = EGraph(save_egglog_string=True)
130
+
131
+ # egraph.register(val.shape)
132
+ # egraph.run(array_api_ruleset.saturate())
133
+ # egraph.extract_multiple(val.shape, 10)
134
+
135
+ # %%
130
136
 
131
137
  X = NDArray.var("X")
132
138
  assume_shape(X, (3, 2, 3, 4))
133
139
  val = linalg_norm(X, (0, 1))
134
- egraph.register(val.shape)
135
- egraph.run(array_api_ruleset.saturate())
136
- egraph.extract_multiple(val.shape, 10)
137
-
138
- # %%
139
140
  egraph = EGraph()
140
- egraph.register(val.shape[2])
141
- egraph.run(array_api_ruleset.saturate())
142
- egraph.display(split_functions=[Int, TRUE, FALSE], n_inline_leaves=2)
141
+ x = egraph.let("x", val.shape[2])
142
+ # egraph.display(n_inline_leaves=0)
143
+ # egraph.extract(x)
144
+ # egraph.saturate(array_api_ruleset, expr=val.shape[2], split_functions=[Int, TRUE, FALSE], n_inline_leaves=2)
145
+ # egraph.run(array_api_ruleset.saturate())
146
+ # egraph.display()
143
147
 
144
148
 
145
149
  # %%
@@ -1,8 +1,6 @@
1
1
  # mypy: disable-error-code="empty-body"
2
2
  from __future__ import annotations
3
3
 
4
- import numpy as np
5
-
6
4
  from egglog import *
7
5
 
8
6
  from .array_api import *
@@ -13,9 +11,12 @@ from .program_gen import *
13
11
  # Depends on `np` as a global variable.
14
12
  ##
15
13
 
16
- array_api_program_gen_ruleset = ruleset()
14
+ array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
15
+ array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")
17
16
 
18
- array_api_program_gen_schedule = array_api_program_gen_ruleset.saturate() + program_gen_ruleset.saturate()
17
+ array_api_program_gen_schedule = (
18
+ array_api_program_gen_ruleset | program_gen_ruleset | array_api_program_gen_eval_ruleset | eval_program_rulseset
19
+ ).saturate()
19
20
 
20
21
 
21
22
  @function
@@ -98,17 +99,14 @@ def _tuple_int_program(i: Int, ti: TupleInt, k: i64, idx_fn: Callable[[Int], Int
98
99
  def ndarray_program(x: NDArray) -> Program: ...
99
100
 
100
101
 
101
- @function
102
- def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> Program: ...
102
+ @function(ruleset=array_api_program_gen_ruleset)
103
+ def ndarray_function_two_program(res: NDArray, l: NDArray, r: NDArray) -> Program:
104
+ return ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))
103
105
 
104
106
 
105
- @array_api_program_gen_ruleset.register
106
- def _ndarray_function_two(f: Program, res: NDArray, l: NDArray, r: NDArray, o: PyObject):
107
- # When we have function, set the program and trigger it to be compiled
108
- yield rule(eq(f).to(ndarray_function_two(res, l, r))).then(
109
- union(f).with_(ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))),
110
- f.eval_py_object({"np": np}),
111
- )
107
+ @function(ruleset=array_api_program_gen_eval_ruleset)
108
+ def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> EvalProgram:
109
+ return EvalProgram(ndarray_function_two_program(res, l, r), {"np": np})
112
110
 
113
111
 
114
112
  @function
egglog/exp/program_gen.py CHANGED
@@ -83,8 +83,18 @@ class Program(Expr):
83
83
  Only keeps the original parent, not any additional ones, so that each set of statements is only added once.
84
84
  """
85
85
 
86
- @method(default=Unit())
87
- def eval_py_object(self, globals: object) -> Unit:
86
+ @property
87
+ def is_identifer(self) -> Bool:
88
+ """
89
+ Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
90
+ """
91
+
92
+
93
+ converter(String, Program, Program)
94
+
95
+
96
+ class EvalProgram(Expr):
97
+ def __init__(self, program: Program, globals: object) -> None:
88
98
  """
89
99
  Evaluates the program and saves as the py_object
90
100
  """
@@ -98,38 +108,34 @@ class Program(Expr):
98
108
  """
99
109
 
100
110
  @property
101
- def is_identifer(self) -> Bool:
111
+ def statements(self) -> String:
102
112
  """
103
- Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
113
+ Returns the statements of the program, if it's been compiled
104
114
  """
105
115
 
106
116
 
107
- converter(String, Program, Program)
108
-
109
- program_gen_ruleset = ruleset()
110
-
111
-
112
- @program_gen_ruleset.register
113
- def _py_object(p: Program, expr: String, statements: String, g: PyObject):
117
+ @ruleset
118
+ def eval_program_rulseset(ep: EvalProgram, p: Program, expr: String, statements: String, g: PyObject):
114
119
  # When we evaluate a program, we first want to compile to a string
115
- yield rule(p.eval_py_object(g)).then(p.compile())
120
+ yield rule(EvalProgram(p, g)).then(p.compile())
116
121
  # Then we want to evaluate the statements/expr
117
122
  yield rule(
118
- p.eval_py_object(g),
123
+ eq(ep).to(EvalProgram(p, g)),
119
124
  eq(p.statements).to(statements),
120
125
  eq(p.expr).to(expr),
121
126
  ).then(
122
- set_(p.py_object).to(
127
+ set_(ep.py_object).to(
123
128
  py_eval(
124
129
  "l['___res']",
125
130
  PyObject.dict(PyObject.from_string("l"), py_exec(join(statements, "\n", "___res = ", expr), g)),
126
131
  )
127
- )
132
+ ),
133
+ set_(ep.statements).to(statements),
128
134
  )
129
135
 
130
136
 
131
- @program_gen_ruleset.register
132
- def _compile(
137
+ @ruleset
138
+ def program_gen_ruleset(
133
139
  s: String,
134
140
  s1: String,
135
141
  s2: String,
egglog/ipython_magic.py CHANGED
@@ -14,7 +14,7 @@ if IN_IPYTHON:
14
14
 
15
15
  @needs_local_scope
16
16
  @register_cell_magic
17
- def egglog(line, cell, local_ns): # noqa: ANN001, ANN201
17
+ def egglog(line, cell, local_ns):
18
18
  """
19
19
  Run an egglog program.
20
20