egglog 8.0.0__cp312-none-win_amd64.whl → 8.0.1__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +49 -35
- egglog/conversion.py +1 -1
- egglog/declarations.py +1 -2
- egglog/egraph.py +23 -10
- egglog/egraph_state.py +12 -5
- egglog/examples/higher_order_functions.py +1 -1
- egglog/exp/array_api.py +17 -14
- egglog/exp/array_api_jit.py +4 -5
- egglog/exp/array_api_loopnest.py +15 -11
- egglog/exp/array_api_program_gen.py +11 -13
- egglog/exp/program_gen.py +23 -17
- egglog/ipython_magic.py +1 -1
- egglog/pretty.py +4 -4
- egglog/runtime.py +1 -1
- egglog/visualizer.css +1 -1
- egglog/visualizer.js +1658 -1637
- {egglog-8.0.0.dist-info → egglog-8.0.1.dist-info}/METADATA +2 -2
- {egglog-8.0.0.dist-info → egglog-8.0.1.dist-info}/RECORD +21 -21
- {egglog-8.0.0.dist-info → egglog-8.0.1.dist-info}/WHEEL +0 -0
- {egglog-8.0.0.dist-info → egglog-8.0.1.dist-info}/licenses/LICENSE +0 -0
|
Binary file
|
egglog/bindings.pyi
CHANGED
|
@@ -5,8 +5,6 @@ from typing import TypeAlias
|
|
|
5
5
|
|
|
6
6
|
from typing_extensions import final
|
|
7
7
|
|
|
8
|
-
HIGH_COST: int
|
|
9
|
-
|
|
10
8
|
@final
|
|
11
9
|
class SerializedEGraph:
|
|
12
10
|
def inline_leaves(self) -> None: ...
|
|
@@ -14,13 +12,14 @@ class SerializedEGraph:
|
|
|
14
12
|
def to_dot(self) -> str: ...
|
|
15
13
|
def to_json(self) -> str: ...
|
|
16
14
|
def map_ops(self, map: dict[str, str]) -> None: ...
|
|
17
|
-
def
|
|
15
|
+
def split_classes(self, egraph: EGraph, ops: set[str]) -> None: ...
|
|
18
16
|
|
|
19
17
|
@final
|
|
20
18
|
class PyObjectSort:
|
|
21
19
|
def __init__(self) -> None: ...
|
|
22
20
|
def store(self, __o: object, /) -> _Expr: ...
|
|
23
21
|
|
|
22
|
+
def parse_program(__input: str, /, filename: str | None = None) -> list[_Command]: ...
|
|
24
23
|
@final
|
|
25
24
|
class EGraph:
|
|
26
25
|
def __init__(
|
|
@@ -29,11 +28,9 @@ class EGraph:
|
|
|
29
28
|
*,
|
|
30
29
|
fact_directory: str | Path | None = None,
|
|
31
30
|
seminaive: bool = True,
|
|
32
|
-
terms_encoding: bool = False,
|
|
33
31
|
record: bool = False,
|
|
34
32
|
) -> None: ...
|
|
35
33
|
def commands(self) -> str | None: ...
|
|
36
|
-
def parse_program(self, __input: str, /, filename: str | None = None) -> list[_Command]: ...
|
|
37
34
|
def run_program(self, *commands: _Command) -> list[str]: ...
|
|
38
35
|
def extract_report(self) -> _ExtractReport | None: ...
|
|
39
36
|
def run_report(self) -> RunReport | None: ...
|
|
@@ -250,6 +247,7 @@ _Action: TypeAlias = Let | Set | Change | Union | Panic | Expr_ | Extract
|
|
|
250
247
|
|
|
251
248
|
@final
|
|
252
249
|
class FunctionDecl:
|
|
250
|
+
span: Span
|
|
253
251
|
name: str
|
|
254
252
|
schema: Schema
|
|
255
253
|
default: _Expr | None
|
|
@@ -261,6 +259,7 @@ class FunctionDecl:
|
|
|
261
259
|
|
|
262
260
|
def __init__(
|
|
263
261
|
self,
|
|
262
|
+
span: Span,
|
|
264
263
|
name: str,
|
|
265
264
|
schema: Schema,
|
|
266
265
|
default: _Expr | None = None,
|
|
@@ -273,7 +272,8 @@ class FunctionDecl:
|
|
|
273
272
|
|
|
274
273
|
@final
|
|
275
274
|
class Variant:
|
|
276
|
-
def __init__(self, name: str, types: list[str], cost: int | None = None) -> None: ...
|
|
275
|
+
def __init__(self, span: Span, name: str, types: list[str], cost: int | None = None) -> None: ...
|
|
276
|
+
span: Span
|
|
277
277
|
name: str
|
|
278
278
|
types: list[str]
|
|
279
279
|
cost: int | None
|
|
@@ -379,6 +379,23 @@ class Sequence:
|
|
|
379
379
|
|
|
380
380
|
_Schedule: TypeAlias = Saturate | Repeat | Run | Sequence
|
|
381
381
|
|
|
382
|
+
##
|
|
383
|
+
# Subdatatypes
|
|
384
|
+
##
|
|
385
|
+
|
|
386
|
+
@final
|
|
387
|
+
class SubVariants:
|
|
388
|
+
def __init__(self, variants: list[Variant]) -> None: ...
|
|
389
|
+
variants: list[Variant]
|
|
390
|
+
|
|
391
|
+
@final
|
|
392
|
+
class NewSort:
|
|
393
|
+
def __init__(self, name: str, args: list[_Expr]) -> None: ...
|
|
394
|
+
name: str
|
|
395
|
+
args: list[_Expr]
|
|
396
|
+
|
|
397
|
+
_Subdatatypes: TypeAlias = SubVariants | NewSort
|
|
398
|
+
|
|
382
399
|
##
|
|
383
400
|
# Commands
|
|
384
401
|
##
|
|
@@ -391,22 +408,23 @@ class SetOption:
|
|
|
391
408
|
|
|
392
409
|
@final
|
|
393
410
|
class Datatype:
|
|
411
|
+
span: Span
|
|
394
412
|
name: str
|
|
395
413
|
variants: list[Variant]
|
|
396
|
-
def __init__(self, name: str, variants: list[Variant]) -> None: ...
|
|
414
|
+
def __init__(self, span: Span, name: str, variants: list[Variant]) -> None: ...
|
|
397
415
|
|
|
398
416
|
@final
|
|
399
|
-
class
|
|
417
|
+
class Datatypes:
|
|
400
418
|
span: Span
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
def __init__(self, span: Span, name: str, sort: str) -> None: ...
|
|
419
|
+
datatypes: list[tuple[Span, str, _Subdatatypes]]
|
|
420
|
+
def __init__(self, span: Span, datatypes: list[tuple[Span, str, _Subdatatypes]]) -> None: ...
|
|
404
421
|
|
|
405
422
|
@final
|
|
406
423
|
class Sort:
|
|
424
|
+
span: Span
|
|
407
425
|
name: str
|
|
408
426
|
presort_and_args: tuple[str, list[_Expr]] | None
|
|
409
|
-
def __init__(self, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
|
|
427
|
+
def __init__(self, span: Span, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
|
|
410
428
|
|
|
411
429
|
@final
|
|
412
430
|
class Function:
|
|
@@ -452,22 +470,17 @@ class RunSchedule:
|
|
|
452
470
|
|
|
453
471
|
@final
|
|
454
472
|
class Simplify:
|
|
473
|
+
span: Span
|
|
455
474
|
expr: _Expr
|
|
456
475
|
schedule: _Schedule
|
|
457
|
-
def __init__(self, expr: _Expr, schedule: _Schedule) -> None: ...
|
|
458
|
-
|
|
459
|
-
@final
|
|
460
|
-
class Calc:
|
|
461
|
-
span: Span
|
|
462
|
-
identifiers: list[IdentSort]
|
|
463
|
-
exprs: list[_Expr]
|
|
464
|
-
def __init__(self, span: Span, identifiers: list[IdentSort], exprs: list[_Expr]) -> None: ...
|
|
476
|
+
def __init__(self, span: Span, expr: _Expr, schedule: _Schedule) -> None: ...
|
|
465
477
|
|
|
466
478
|
@final
|
|
467
479
|
class QueryExtract:
|
|
480
|
+
span: Span
|
|
468
481
|
variants: int
|
|
469
482
|
expr: _Expr
|
|
470
|
-
def __init__(self, variants: int, expr: _Expr) -> None: ...
|
|
483
|
+
def __init__(self, span: Span, variants: int, expr: _Expr) -> None: ...
|
|
471
484
|
|
|
472
485
|
@final
|
|
473
486
|
class Check:
|
|
@@ -484,20 +497,23 @@ class PrintFunction:
|
|
|
484
497
|
|
|
485
498
|
@final
|
|
486
499
|
class PrintSize:
|
|
500
|
+
span: Span
|
|
487
501
|
name: str | None
|
|
488
|
-
def __init__(self, name: str | None) -> None: ...
|
|
502
|
+
def __init__(self, span: Span, name: str | None) -> None: ...
|
|
489
503
|
|
|
490
504
|
@final
|
|
491
505
|
class Output:
|
|
506
|
+
span: Span
|
|
492
507
|
file: str
|
|
493
508
|
exprs: list[_Expr]
|
|
494
|
-
def __init__(self, file: str, exprs: list[_Expr]) -> None: ...
|
|
509
|
+
def __init__(self, span: Span, file: str, exprs: list[_Expr]) -> None: ...
|
|
495
510
|
|
|
496
511
|
@final
|
|
497
512
|
class Input:
|
|
513
|
+
span: Span
|
|
498
514
|
name: str
|
|
499
515
|
file: str
|
|
500
|
-
def __init__(self, name: str, file: str) -> None: ...
|
|
516
|
+
def __init__(self, span: Span, name: str, file: str) -> None: ...
|
|
501
517
|
|
|
502
518
|
@final
|
|
503
519
|
class Push:
|
|
@@ -506,29 +522,29 @@ class Push:
|
|
|
506
522
|
|
|
507
523
|
@final
|
|
508
524
|
class Pop:
|
|
525
|
+
span: Span
|
|
509
526
|
length: int
|
|
510
|
-
def __init__(self, length: int) -> None: ...
|
|
527
|
+
def __init__(self, span: Span, length: int) -> None: ...
|
|
511
528
|
|
|
512
529
|
@final
|
|
513
530
|
class Fail:
|
|
531
|
+
span: Span
|
|
514
532
|
command: _Command
|
|
515
|
-
def __init__(self, command: _Command) -> None: ...
|
|
533
|
+
def __init__(self, span: Span, command: _Command) -> None: ...
|
|
516
534
|
|
|
517
535
|
@final
|
|
518
536
|
class Include:
|
|
537
|
+
span: Span
|
|
519
538
|
path: str
|
|
520
|
-
def __init__(self, path: str) -> None: ...
|
|
521
|
-
|
|
522
|
-
@final
|
|
523
|
-
class CheckProof:
|
|
524
|
-
def __init__(self) -> None: ...
|
|
539
|
+
def __init__(self, span: Span, path: str) -> None: ...
|
|
525
540
|
|
|
526
541
|
@final
|
|
527
542
|
class Relation:
|
|
543
|
+
span: Span
|
|
528
544
|
constructor: str
|
|
529
545
|
inputs: list[str]
|
|
530
546
|
|
|
531
|
-
def __init__(self, constructor: str, inputs: list[str]) -> None: ...
|
|
547
|
+
def __init__(self, span: Span, constructor: str, inputs: list[str]) -> None: ...
|
|
532
548
|
|
|
533
549
|
@final
|
|
534
550
|
class PrintOverallStatistics:
|
|
@@ -543,7 +559,7 @@ class UnstableCombinedRuleset:
|
|
|
543
559
|
_Command: TypeAlias = (
|
|
544
560
|
SetOption
|
|
545
561
|
| Datatype
|
|
546
|
-
|
|
|
562
|
+
| Datatypes
|
|
547
563
|
| Sort
|
|
548
564
|
| Function
|
|
549
565
|
| AddRuleset
|
|
@@ -552,7 +568,6 @@ _Command: TypeAlias = (
|
|
|
552
568
|
| BiRewriteCommand
|
|
553
569
|
| ActionCommand
|
|
554
570
|
| RunSchedule
|
|
555
|
-
| Calc
|
|
556
571
|
| Simplify
|
|
557
572
|
| QueryExtract
|
|
558
573
|
| Check
|
|
@@ -564,7 +579,6 @@ _Command: TypeAlias = (
|
|
|
564
579
|
| Pop
|
|
565
580
|
| Fail
|
|
566
581
|
| Include
|
|
567
|
-
| CheckProof
|
|
568
582
|
| Relation
|
|
569
583
|
| PrintOverallStatistics
|
|
570
584
|
| UnstableCombinedRuleset
|
egglog/conversion.py
CHANGED
|
@@ -195,6 +195,6 @@ def _get_tp(x: object) -> TypeName | type:
|
|
|
195
195
|
return TypeName(x.__egg_typed_expr__.tp.name)
|
|
196
196
|
tp = type(x)
|
|
197
197
|
# If this value has a custom metaclass, let's use that as our index instead of the type
|
|
198
|
-
if type(tp)
|
|
198
|
+
if type(tp) is not type:
|
|
199
199
|
return type(tp)
|
|
200
200
|
return tp
|
egglog/declarations.py
CHANGED
|
@@ -622,9 +622,8 @@ def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExp
|
|
|
622
622
|
res = replacements[typed_expr]
|
|
623
623
|
else:
|
|
624
624
|
match typed_expr.expr:
|
|
625
|
-
case (
|
|
625
|
+
case CallDecl(callable, args, bound_tp_params) | PartialCallDecl(
|
|
626
626
|
CallDecl(callable, args, bound_tp_params)
|
|
627
|
-
| PartialCallDecl(CallDecl(callable, args, bound_tp_params))
|
|
628
627
|
):
|
|
629
628
|
new_args = tuple(_inner(a) for a in args)
|
|
630
629
|
call_decl = CallDecl(callable, new_args, bound_tp_params)
|
egglog/egraph.py
CHANGED
|
@@ -486,7 +486,7 @@ class _ExprMetaclass(type):
|
|
|
486
486
|
return isinstance(instance, RuntimeExpr)
|
|
487
487
|
|
|
488
488
|
|
|
489
|
-
def _generate_class_decls( # noqa: C901
|
|
489
|
+
def _generate_class_decls( # noqa: C901,PLR0912
|
|
490
490
|
namespace: dict[str, Any],
|
|
491
491
|
frame: FrameType,
|
|
492
492
|
builtin: bool,
|
|
@@ -704,8 +704,12 @@ def _fn_decl(
|
|
|
704
704
|
# https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
|
|
705
705
|
if "Callable" not in hint_globals:
|
|
706
706
|
hint_globals["Callable"] = Callable
|
|
707
|
-
|
|
708
|
-
|
|
707
|
+
# Instead of passing both globals and locals, just pass the globals. Otherwise, for some reason forward references
|
|
708
|
+
# won't be resolved correctly
|
|
709
|
+
# We need this to be false so it returns "__forward_value__" https://github.com/python/cpython/blob/440ed18e08887b958ad50db1b823e692a747b671/Lib/typing.py#L919
|
|
710
|
+
# https://github.com/egraphs-good/egglog-python/issues/210
|
|
711
|
+
hint_globals.update(hint_locals)
|
|
712
|
+
hints = get_type_hints(fn, hint_globals)
|
|
709
713
|
|
|
710
714
|
params = list(signature(fn).parameters.values())
|
|
711
715
|
|
|
@@ -1021,7 +1025,7 @@ class EGraph(_BaseModule):
|
|
|
1021
1025
|
"""
|
|
1022
1026
|
Loads a CSV file and sets it as *input, output of the function.
|
|
1023
1027
|
"""
|
|
1024
|
-
self._egraph.run_program(bindings.Input(self._callable_to_egg(fn), path))
|
|
1028
|
+
self._egraph.run_program(bindings.Input(bindings.DUMMY_SPAN, self._callable_to_egg(fn), path))
|
|
1025
1029
|
|
|
1026
1030
|
def _callable_to_egg(self, fn: object) -> str:
|
|
1027
1031
|
ref, decls = resolve_callable(fn)
|
|
@@ -1063,7 +1067,7 @@ class EGraph(_BaseModule):
|
|
|
1063
1067
|
typed_expr = runtime_expr.__egg_typed_expr__
|
|
1064
1068
|
# Must also register type
|
|
1065
1069
|
egg_expr = self._state.typed_expr_to_egg(typed_expr)
|
|
1066
|
-
self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
|
|
1070
|
+
self._egraph.run_program(bindings.Simplify(bindings.DUMMY_SPAN, egg_expr, egg_schedule))
|
|
1067
1071
|
extract_report = self._egraph.extract_report()
|
|
1068
1072
|
if not isinstance(extract_report, bindings.Best):
|
|
1069
1073
|
msg = "No extract report saved"
|
|
@@ -1118,7 +1122,7 @@ class EGraph(_BaseModule):
|
|
|
1118
1122
|
"""
|
|
1119
1123
|
Checks that one of the facts is not true
|
|
1120
1124
|
"""
|
|
1121
|
-
self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
|
|
1125
|
+
self._egraph.run_program(bindings.Fail(bindings.DUMMY_SPAN, self._facts_to_check(facts)))
|
|
1122
1126
|
|
|
1123
1127
|
def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
|
|
1124
1128
|
facts = _fact_likes(fact_likes)
|
|
@@ -1191,7 +1195,7 @@ class EGraph(_BaseModule):
|
|
|
1191
1195
|
"""
|
|
1192
1196
|
Pop the current state of the egraph, reverting back to the previous state.
|
|
1193
1197
|
"""
|
|
1194
|
-
self._egraph.run_program(bindings.Pop(1))
|
|
1198
|
+
self._egraph.run_program(bindings.Pop(bindings.DUMMY_SPAN, 1))
|
|
1195
1199
|
self._state = self._state_stack.pop()
|
|
1196
1200
|
|
|
1197
1201
|
def __enter__(self) -> Self:
|
|
@@ -1204,7 +1208,7 @@ class EGraph(_BaseModule):
|
|
|
1204
1208
|
self.push()
|
|
1205
1209
|
return self
|
|
1206
1210
|
|
|
1207
|
-
def __exit__(self, exc_type, exc, exc_tb) -> None:
|
|
1211
|
+
def __exit__(self, exc_type, exc, exc_tb) -> None:
|
|
1208
1212
|
CURRENT_EGRAPH.reset(self._token_stack.pop())
|
|
1209
1213
|
self.pop()
|
|
1210
1214
|
|
|
@@ -1262,7 +1266,7 @@ class EGraph(_BaseModule):
|
|
|
1262
1266
|
)
|
|
1263
1267
|
if split_primitive_outputs or split_functions:
|
|
1264
1268
|
additional_ops = set(map(self._callable_to_egg, split_functions))
|
|
1265
|
-
serialized.
|
|
1269
|
+
serialized.split_classes(self._egraph, additional_ops)
|
|
1266
1270
|
serialized.map_ops(self._state.op_mapping())
|
|
1267
1271
|
|
|
1268
1272
|
for _ in range(n_inline_leaves):
|
|
@@ -1322,7 +1326,14 @@ class EGraph(_BaseModule):
|
|
|
1322
1326
|
serialized = self._serialize(**kwargs)
|
|
1323
1327
|
VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()
|
|
1324
1328
|
|
|
1325
|
-
def saturate(
|
|
1329
|
+
def saturate(
|
|
1330
|
+
self,
|
|
1331
|
+
schedule: Schedule | None = None,
|
|
1332
|
+
*,
|
|
1333
|
+
expr: Expr | None = None,
|
|
1334
|
+
max: int = 1000,
|
|
1335
|
+
**kwargs: Unpack[GraphvizKwargs],
|
|
1336
|
+
) -> None:
|
|
1326
1337
|
"""
|
|
1327
1338
|
Saturate the egraph, running the given schedule until the egraph is saturated.
|
|
1328
1339
|
It serializes the egraph at each step and returns a widget to visualize the egraph.
|
|
@@ -1330,6 +1341,8 @@ class EGraph(_BaseModule):
|
|
|
1330
1341
|
from .visualizer_widget import VisualizerWidget
|
|
1331
1342
|
|
|
1332
1343
|
def to_json() -> str:
|
|
1344
|
+
if expr:
|
|
1345
|
+
print(self.extract(expr))
|
|
1333
1346
|
return self._serialize(**kwargs).to_json()
|
|
1334
1347
|
|
|
1335
1348
|
egraphs = [to_json()]
|
egglog/egraph_state.py
CHANGED
|
@@ -206,18 +206,25 @@ class EGraphState:
|
|
|
206
206
|
self.egg_fn_to_callable_refs[egg_name].add(ref)
|
|
207
207
|
match decl:
|
|
208
208
|
case RelationDecl(arg_types, _, _):
|
|
209
|
-
self.egraph.run_program(
|
|
209
|
+
self.egraph.run_program(
|
|
210
|
+
bindings.Relation(bindings.DUMMY_SPAN, egg_name, [self.type_ref_to_egg(a) for a in arg_types])
|
|
211
|
+
)
|
|
210
212
|
case ConstantDecl(tp, _):
|
|
211
213
|
# Use function decleration instead of constant b/c constants cannot be extracted
|
|
212
214
|
# https://github.com/egraphs-good/egglog/issues/334
|
|
213
215
|
self.egraph.run_program(
|
|
214
|
-
bindings.Function(
|
|
216
|
+
bindings.Function(
|
|
217
|
+
bindings.FunctionDecl(
|
|
218
|
+
bindings.DUMMY_SPAN, egg_name, bindings.Schema([], self.type_ref_to_egg(tp))
|
|
219
|
+
)
|
|
220
|
+
)
|
|
215
221
|
)
|
|
216
222
|
case FunctionDecl():
|
|
217
223
|
if not decl.builtin:
|
|
218
224
|
signature = decl.signature
|
|
219
225
|
assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
|
|
220
226
|
egg_fn_decl = bindings.FunctionDecl(
|
|
227
|
+
bindings.DUMMY_SPAN,
|
|
221
228
|
egg_name,
|
|
222
229
|
bindings.Schema(
|
|
223
230
|
[self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
|
|
@@ -261,7 +268,7 @@ class EGraphState:
|
|
|
261
268
|
args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
|
|
262
269
|
else:
|
|
263
270
|
args = None
|
|
264
|
-
self.egraph.run_program(bindings.Sort(egg_name, args))
|
|
271
|
+
self.egraph.run_program(bindings.Sort(bindings.DUMMY_SPAN, egg_name, args))
|
|
265
272
|
# For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
|
|
266
273
|
# these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
|
|
267
274
|
# even if you never use that function.
|
|
@@ -323,7 +330,7 @@ class EGraphState:
|
|
|
323
330
|
@overload
|
|
324
331
|
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
325
332
|
|
|
326
|
-
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
|
|
333
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912,C901
|
|
327
334
|
"""
|
|
328
335
|
Convert an ExprDecl to an egg expression.
|
|
329
336
|
"""
|
|
@@ -486,7 +493,7 @@ class FromEggState:
|
|
|
486
493
|
if term.name == "py-object":
|
|
487
494
|
call = bindings.termdag_term_to_expr(self.termdag, term)
|
|
488
495
|
expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
|
|
489
|
-
|
|
496
|
+
elif term.name == "unstable-fn":
|
|
490
497
|
# Get function name
|
|
491
498
|
fn_term, *arg_terms = term.args
|
|
492
499
|
fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
|
|
@@ -26,7 +26,7 @@ class MathList(Expr):
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
@ruleset
|
|
29
|
-
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
|
|
29
|
+
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
|
|
30
30
|
yield rewrite(Math(i) + Math(j)).to(Math(i + j))
|
|
31
31
|
yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
|
|
32
32
|
yield rewrite(MathList().map(f)).to(MathList())
|
egglog/exp/array_api.py
CHANGED
|
@@ -1538,20 +1538,23 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool:
|
|
|
1538
1538
|
egraph.register(expr)
|
|
1539
1539
|
egraph.run(array_api_schedule)
|
|
1540
1540
|
try:
|
|
1541
|
-
|
|
1541
|
+
extracted = egraph.extract(prim_expr)
|
|
1542
1542
|
except EggSmolError as exc:
|
|
1543
|
-
|
|
1543
|
+
# Try giving some context, by showing the smallest version of the larger expression
|
|
1544
1544
|
try:
|
|
1545
|
-
|
|
1546
|
-
except EggSmolError:
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
|
|
1550
|
-
# egraph.as_egglog_string
|
|
1551
|
-
# + "\n"
|
|
1552
|
-
# + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
|
|
1553
|
-
# )
|
|
1554
|
-
# # save to "tmp.egg"
|
|
1555
|
-
# with open("tmp.egg", "w") as f:
|
|
1556
|
-
# f.write(string)
|
|
1545
|
+
expr_extracted = egraph.extract(expr)
|
|
1546
|
+
except EggSmolError as inner_exc:
|
|
1547
|
+
raise ValueError(f"Cannot simplify {expr}") from inner_exc
|
|
1548
|
+
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
|
|
1549
|
+
msg = f"Cannot simplify to primitive {expr_extracted}"
|
|
1557
1550
|
raise ValueError(msg) from exc
|
|
1551
|
+
return egraph.eval(extracted)
|
|
1552
|
+
|
|
1553
|
+
# string = (
|
|
1554
|
+
# egraph.as_egglog_string
|
|
1555
|
+
# + "\n"
|
|
1556
|
+
# + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
|
|
1557
|
+
# )
|
|
1558
|
+
# # save to "tmp.egg"
|
|
1559
|
+
# with open("tmp.egg", "w") as f:
|
|
1560
|
+
# f.write(string)
|
egglog/exp/array_api_jit.py
CHANGED
|
@@ -17,18 +17,17 @@ def jit(fn: X) -> X:
|
|
|
17
17
|
# 1. Create variables for each of the two args in the functions
|
|
18
18
|
sig = inspect.signature(fn)
|
|
19
19
|
arg1, arg2 = sig.parameters.keys()
|
|
20
|
-
|
|
21
|
-
with
|
|
20
|
+
egraph = EGraph()
|
|
21
|
+
with egraph:
|
|
22
22
|
res = fn(NDArray.var(arg1), NDArray.var(arg2))
|
|
23
23
|
egraph.register(res)
|
|
24
24
|
egraph.run(array_api_numba_schedule)
|
|
25
25
|
res_optimized = egraph.extract(res)
|
|
26
|
-
egraph.display(split_primitive_outputs=True, n_inline_leaves=3)
|
|
26
|
+
# egraph.display(split_primitive_outputs=True, n_inline_leaves=3)
|
|
27
27
|
|
|
28
|
-
egraph = EGraph()
|
|
29
28
|
fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
|
|
30
29
|
egraph.register(fn_program)
|
|
31
30
|
egraph.run(array_api_program_gen_schedule)
|
|
32
|
-
fn = cast(X, egraph.eval(fn_program.py_object))
|
|
31
|
+
fn = cast(X, egraph.eval(egraph.extract(fn_program.py_object)))
|
|
33
32
|
fn.expr = res_optimized # type: ignore[attr-defined]
|
|
34
33
|
return fn
|
egglog/exp/array_api_loopnest.py
CHANGED
|
@@ -28,7 +28,7 @@ class ShapeAPI(Expr):
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
@array_api_ruleset.register
|
|
31
|
-
def shape_api_ruleset(dims: TupleInt, axis: TupleInt):
|
|
31
|
+
def shape_api_ruleset(dims: TupleInt, axis: TupleInt):
|
|
32
32
|
s = ShapeAPI(dims)
|
|
33
33
|
yield rewrite(s.deselect(axis)).to(
|
|
34
34
|
ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
|
|
@@ -108,7 +108,7 @@ def _loopnest_api_ruleset(
|
|
|
108
108
|
|
|
109
109
|
|
|
110
110
|
@function(ruleset=array_api_ruleset, unextractable=True)
|
|
111
|
-
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
|
|
111
|
+
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
|
|
112
112
|
# peel off the outer shape for result array
|
|
113
113
|
outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
|
|
114
114
|
# get only the inner shape for reduction
|
|
@@ -126,20 +126,24 @@ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray: # noqa: N803
|
|
|
126
126
|
|
|
127
127
|
|
|
128
128
|
# %%
|
|
129
|
-
egraph = EGraph(save_egglog_string=True)
|
|
129
|
+
# egraph = EGraph(save_egglog_string=True)
|
|
130
|
+
|
|
131
|
+
# egraph.register(val.shape)
|
|
132
|
+
# egraph.run(array_api_ruleset.saturate())
|
|
133
|
+
# egraph.extract_multiple(val.shape, 10)
|
|
134
|
+
|
|
135
|
+
# %%
|
|
130
136
|
|
|
131
137
|
X = NDArray.var("X")
|
|
132
138
|
assume_shape(X, (3, 2, 3, 4))
|
|
133
139
|
val = linalg_norm(X, (0, 1))
|
|
134
|
-
egraph.register(val.shape)
|
|
135
|
-
egraph.run(array_api_ruleset.saturate())
|
|
136
|
-
egraph.extract_multiple(val.shape, 10)
|
|
137
|
-
|
|
138
|
-
# %%
|
|
139
140
|
egraph = EGraph()
|
|
140
|
-
egraph.
|
|
141
|
-
egraph.
|
|
142
|
-
egraph.
|
|
141
|
+
x = egraph.let("x", val.shape[2])
|
|
142
|
+
# egraph.display(n_inline_leaves=0)
|
|
143
|
+
# egraph.extract(x)
|
|
144
|
+
# egraph.saturate(array_api_ruleset, expr=val.shape[2], split_functions=[Int, TRUE, FALSE], n_inline_leaves=2)
|
|
145
|
+
# egraph.run(array_api_ruleset.saturate())
|
|
146
|
+
# egraph.display()
|
|
143
147
|
|
|
144
148
|
|
|
145
149
|
# %%
|
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
# mypy: disable-error-code="empty-body"
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
|
|
4
|
-
import numpy as np
|
|
5
|
-
|
|
6
4
|
from egglog import *
|
|
7
5
|
|
|
8
6
|
from .array_api import *
|
|
@@ -13,9 +11,12 @@ from .program_gen import *
|
|
|
13
11
|
# Depends on `np` as a global variable.
|
|
14
12
|
##
|
|
15
13
|
|
|
16
|
-
array_api_program_gen_ruleset = ruleset()
|
|
14
|
+
array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
|
|
15
|
+
array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")
|
|
17
16
|
|
|
18
|
-
array_api_program_gen_schedule =
|
|
17
|
+
array_api_program_gen_schedule = (
|
|
18
|
+
array_api_program_gen_ruleset | program_gen_ruleset | array_api_program_gen_eval_ruleset | eval_program_rulseset
|
|
19
|
+
).saturate()
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
@function
|
|
@@ -98,17 +99,14 @@ def _tuple_int_program(i: Int, ti: TupleInt, k: i64, idx_fn: Callable[[Int], Int
|
|
|
98
99
|
def ndarray_program(x: NDArray) -> Program: ...
|
|
99
100
|
|
|
100
101
|
|
|
101
|
-
@function
|
|
102
|
-
def
|
|
102
|
+
@function(ruleset=array_api_program_gen_ruleset)
|
|
103
|
+
def ndarray_function_two_program(res: NDArray, l: NDArray, r: NDArray) -> Program:
|
|
104
|
+
return ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))
|
|
103
105
|
|
|
104
106
|
|
|
105
|
-
@
|
|
106
|
-
def
|
|
107
|
-
|
|
108
|
-
yield rule(eq(f).to(ndarray_function_two(res, l, r))).then(
|
|
109
|
-
union(f).with_(ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))),
|
|
110
|
-
f.eval_py_object({"np": np}),
|
|
111
|
-
)
|
|
107
|
+
@function(ruleset=array_api_program_gen_eval_ruleset)
|
|
108
|
+
def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> EvalProgram:
|
|
109
|
+
return EvalProgram(ndarray_function_two_program(res, l, r), {"np": np})
|
|
112
110
|
|
|
113
111
|
|
|
114
112
|
@function
|
egglog/exp/program_gen.py
CHANGED
|
@@ -83,8 +83,18 @@ class Program(Expr):
|
|
|
83
83
|
Only keeps the original parent, not any additional ones, so that each set of statements is only added once.
|
|
84
84
|
"""
|
|
85
85
|
|
|
86
|
-
@
|
|
87
|
-
def
|
|
86
|
+
@property
|
|
87
|
+
def is_identifer(self) -> Bool:
|
|
88
|
+
"""
|
|
89
|
+
Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
converter(String, Program, Program)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class EvalProgram(Expr):
|
|
97
|
+
def __init__(self, program: Program, globals: object) -> None:
|
|
88
98
|
"""
|
|
89
99
|
Evaluates the program and saves as the py_object
|
|
90
100
|
"""
|
|
@@ -98,38 +108,34 @@ class Program(Expr):
|
|
|
98
108
|
"""
|
|
99
109
|
|
|
100
110
|
@property
|
|
101
|
-
def
|
|
111
|
+
def statements(self) -> String:
|
|
102
112
|
"""
|
|
103
|
-
Returns
|
|
113
|
+
Returns the statements of the program, if it's been compiled
|
|
104
114
|
"""
|
|
105
115
|
|
|
106
116
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
program_gen_ruleset = ruleset()
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
@program_gen_ruleset.register
|
|
113
|
-
def _py_object(p: Program, expr: String, statements: String, g: PyObject):
|
|
117
|
+
@ruleset
|
|
118
|
+
def eval_program_rulseset(ep: EvalProgram, p: Program, expr: String, statements: String, g: PyObject):
|
|
114
119
|
# When we evaluate a program, we first want to compile to a string
|
|
115
|
-
yield rule(p
|
|
120
|
+
yield rule(EvalProgram(p, g)).then(p.compile())
|
|
116
121
|
# Then we want to evaluate the statements/expr
|
|
117
122
|
yield rule(
|
|
118
|
-
|
|
123
|
+
eq(ep).to(EvalProgram(p, g)),
|
|
119
124
|
eq(p.statements).to(statements),
|
|
120
125
|
eq(p.expr).to(expr),
|
|
121
126
|
).then(
|
|
122
|
-
set_(
|
|
127
|
+
set_(ep.py_object).to(
|
|
123
128
|
py_eval(
|
|
124
129
|
"l['___res']",
|
|
125
130
|
PyObject.dict(PyObject.from_string("l"), py_exec(join(statements, "\n", "___res = ", expr), g)),
|
|
126
131
|
)
|
|
127
|
-
)
|
|
132
|
+
),
|
|
133
|
+
set_(ep.statements).to(statements),
|
|
128
134
|
)
|
|
129
135
|
|
|
130
136
|
|
|
131
|
-
@
|
|
132
|
-
def
|
|
137
|
+
@ruleset
|
|
138
|
+
def program_gen_ruleset(
|
|
133
139
|
s: String,
|
|
134
140
|
s1: String,
|
|
135
141
|
s2: String,
|