egglog 10.0.1__cp311-cp311-win_amd64.whl → 11.0.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/builtins.py CHANGED
@@ -6,18 +6,20 @@ Builtin sorts and function to egg.
6
6
  from __future__ import annotations
7
7
 
8
8
  from collections.abc import Callable
9
+ from dataclasses import dataclass
9
10
  from fractions import Fraction
10
11
  from functools import partial, reduce
12
+ from inspect import signature
11
13
  from types import FunctionType, MethodType
12
- from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload
14
+ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, cast, overload
13
15
 
14
- from typing_extensions import TypeVarTuple, Unpack
16
+ from typing_extensions import TypeVarTuple, Unpack, deprecated
15
17
 
16
- from .conversion import convert, converter, get_type_args
18
+ from .conversion import convert, converter, get_type_args, resolve_literal
17
19
  from .declarations import *
18
- from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
19
- from .functionalize import functionalize
20
- from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
20
+ from .deconstruct import get_callable_args, get_literal_value
21
+ from .egraph import BaseExpr, BuiltinExpr, _add_default_rewrite_inner, expr_fact, function, get_current_ruleset, method
22
+ from .runtime import RuntimeExpr, RuntimeFunction, resolve_type_annotation_mutate
21
23
  from .thunk import Thunk
22
24
 
23
25
  if TYPE_CHECKING:
@@ -31,7 +33,7 @@ __all__ = [
31
33
  "BigRatLike",
32
34
  "Bool",
33
35
  "BoolLike",
34
- "BuiltinEvalError",
36
+ "ExprValueError",
35
37
  "Map",
36
38
  "MapLike",
37
39
  "MultiSet",
@@ -56,15 +58,17 @@ __all__ = [
56
58
  ]
57
59
 
58
60
 
59
- class BuiltinEvalError(Exception):
61
+ @dataclass
62
+ class ExprValueError(AttributeError):
60
63
  """
61
- Raised when an builtin cannot be evaluated into a Python primitive because it is complex.
62
-
63
- Try extracting this expression first.
64
+ Raised when an expression cannot be converted to a Python value because the value is not a constructor.
64
65
  """
65
66
 
67
+ expr: BaseExpr
68
+ allowed: str
69
+
66
70
  def __str__(self) -> str:
67
- return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: {super().__str__()}"
71
+ return f"Cannot get Python value of {self.expr}, must be of form {self.allowed}. Try calling `extract` on it to get the underlying value."
68
72
 
69
73
 
70
74
  class Unit(BuiltinExpr, egg_sort="Unit"):
@@ -80,13 +84,21 @@ class Unit(BuiltinExpr, egg_sort="Unit"):
80
84
 
81
85
 
82
86
  class String(BuiltinExpr):
87
+ def __init__(self, value: str) -> None: ...
88
+
83
89
  @method(preserve=True)
90
+ @deprecated("use .value")
84
91
  def eval(self) -> str:
85
- value = _extract_lit(self)
86
- assert isinstance(value, str)
87
- return value
92
+ return self.value
88
93
 
89
- def __init__(self, value: str) -> None: ...
94
+ @method(preserve=True) # type: ignore[misc]
95
+ @property
96
+ def value(self) -> str:
97
+ if (value := get_literal_value(self)) is not None:
98
+ return value
99
+ raise ExprValueError(self, "String")
100
+
101
+ __match_args__ = ("value",)
90
102
 
91
103
  @method(egg_fn="replace")
92
104
  def replace(self, old: StringLike, new: StringLike) -> String: ...
@@ -101,21 +113,27 @@ def join(*strings: StringLike) -> String: ...
101
113
 
102
114
  converter(str, String, String)
103
115
 
104
- BoolLike: TypeAlias = Union["Bool", bool]
105
-
106
116
 
107
117
  class Bool(BuiltinExpr, egg_sort="bool"):
118
+ def __init__(self, value: bool) -> None: ...
119
+
108
120
  @method(preserve=True)
121
+ @deprecated("use .value")
109
122
  def eval(self) -> bool:
110
- value = _extract_lit(self)
111
- assert isinstance(value, bool)
112
- return value
123
+ return self.value
124
+
125
+ @method(preserve=True) # type: ignore[misc]
126
+ @property
127
+ def value(self) -> bool:
128
+ if (value := get_literal_value(self)) is not None:
129
+ return value
130
+ raise ExprValueError(self, "Bool")
131
+
132
+ __match_args__ = ("value",)
113
133
 
114
134
  @method(preserve=True)
115
135
  def __bool__(self) -> bool:
116
- return self.eval()
117
-
118
- def __init__(self, value: bool) -> None: ...
136
+ return self.value
119
137
 
120
138
  @method(egg_fn="not")
121
139
  def __invert__(self) -> Bool: ...
@@ -133,28 +151,36 @@ class Bool(BuiltinExpr, egg_sort="bool"):
133
151
  def implies(self, other: BoolLike) -> Bool: ...
134
152
 
135
153
 
136
- converter(bool, Bool, Bool)
154
+ BoolLike: TypeAlias = Bool | bool
137
155
 
138
- # The types which can be convertered into an i64
139
- i64Like: TypeAlias = Union["i64", int] # noqa: N816, PYI042
156
+
157
+ converter(bool, Bool, Bool)
140
158
 
141
159
 
142
160
  class i64(BuiltinExpr): # noqa: N801
161
+ def __init__(self, value: int) -> None: ...
162
+
143
163
  @method(preserve=True)
164
+ @deprecated("use .value")
144
165
  def eval(self) -> int:
145
- value = _extract_lit(self)
146
- assert isinstance(value, int)
147
- return value
166
+ return self.value
167
+
168
+ @method(preserve=True) # type: ignore[misc]
169
+ @property
170
+ def value(self) -> int:
171
+ if (value := get_literal_value(self)) is not None:
172
+ return value
173
+ raise ExprValueError(self, "i64")
174
+
175
+ __match_args__ = ("value",)
148
176
 
149
177
  @method(preserve=True)
150
178
  def __index__(self) -> int:
151
- return self.eval()
179
+ return self.value
152
180
 
153
181
  @method(preserve=True)
154
182
  def __int__(self) -> int:
155
- return self.eval()
156
-
157
- def __init__(self, value: int) -> None: ...
183
+ return self.value
158
184
 
159
185
  @method(egg_fn="+")
160
186
  def __add__(self, other: i64Like) -> i64: ...
@@ -248,6 +274,9 @@ class i64(BuiltinExpr): # noqa: N801
248
274
  def bool_ge(self, other: i64Like) -> Bool: ...
249
275
 
250
276
 
277
+ # The types which can be convertered into an i64
278
+ i64Like: TypeAlias = i64 | int # noqa: N816, PYI042
279
+
251
280
  converter(int, i64, i64)
252
281
 
253
282
 
@@ -255,25 +284,30 @@ converter(int, i64, i64)
255
284
  def count_matches(s: StringLike, pattern: StringLike) -> i64: ...
256
285
 
257
286
 
258
- f64Like: TypeAlias = Union["f64", float] # noqa: N816, PYI042
259
-
260
-
261
287
  class f64(BuiltinExpr): # noqa: N801
288
+ def __init__(self, value: float) -> None: ...
289
+
262
290
  @method(preserve=True)
291
+ @deprecated("use .value")
263
292
  def eval(self) -> float:
264
- value = _extract_lit(self)
265
- assert isinstance(value, float)
266
- return value
293
+ return self.value
294
+
295
+ @method(preserve=True) # type: ignore[misc]
296
+ @property
297
+ def value(self) -> float:
298
+ if (value := get_literal_value(self)) is not None:
299
+ return value
300
+ raise ExprValueError(self, "f64")
301
+
302
+ __match_args__ = ("value",)
267
303
 
268
304
  @method(preserve=True)
269
305
  def __float__(self) -> float:
270
- return self.eval()
306
+ return self.value
271
307
 
272
308
  @method(preserve=True)
273
309
  def __int__(self) -> int:
274
- return int(self.eval())
275
-
276
- def __init__(self, value: float) -> None: ...
310
+ return int(self.value)
277
311
 
278
312
  @method(egg_fn="neg")
279
313
  def __neg__(self) -> f64: ...
@@ -337,6 +371,9 @@ class f64(BuiltinExpr): # noqa: N801
337
371
  def to_string(self) -> String: ...
338
372
 
339
373
 
374
+ f64Like: TypeAlias = f64 | float # noqa: N816, PYI042
375
+
376
+
340
377
  converter(float, f64, f64)
341
378
 
342
379
 
@@ -346,34 +383,34 @@ V = TypeVar("V", bound=BaseExpr)
346
383
 
347
384
  class Map(BuiltinExpr, Generic[T, V]):
348
385
  @method(preserve=True)
386
+ @deprecated("use .value")
349
387
  def eval(self) -> dict[T, V]:
350
- call = _extract_call(self)
351
- expr = cast("RuntimeExpr", self)
388
+ return self.value
389
+
390
+ @method(preserve=True) # type: ignore[misc]
391
+ @property
392
+ def value(self) -> dict[T, V]:
352
393
  d = {}
353
- while call.callable != ClassMethodRef("Map", "empty"):
354
- msg = "Map can only be evaluated if it is empty or a series of inserts."
355
- if call.callable != MethodRef("Map", "insert"):
356
- raise BuiltinEvalError(msg)
357
- call_typed, k_typed, v_typed = call.args
358
- if not isinstance(call_typed.expr, CallDecl):
359
- raise BuiltinEvalError(msg)
360
- k = cast("T", expr.__with_expr__(k_typed))
361
- v = cast("V", expr.__with_expr__(v_typed))
394
+ while args := get_callable_args(self, Map[T, V].insert):
395
+ self, k, v = args # noqa: PLW0642
362
396
  d[k] = v
363
- call = call_typed.expr
397
+ if get_callable_args(self, Map.empty) is None:
398
+ raise ExprValueError(self, "Map.empty or Map.insert")
364
399
  return d
365
400
 
401
+ __match_args__ = ("value",)
402
+
366
403
  @method(preserve=True)
367
404
  def __iter__(self) -> Iterator[T]:
368
- return iter(self.eval())
405
+ return iter(self.value)
369
406
 
370
407
  @method(preserve=True)
371
408
  def __len__(self) -> int:
372
- return len(self.eval())
409
+ return len(self.value)
373
410
 
374
411
  @method(preserve=True)
375
412
  def __contains__(self, key: T) -> bool:
376
- return key in self.eval()
413
+ return key in self.value
377
414
 
378
415
  @method(egg_fn="map-empty")
379
416
  @classmethod
@@ -416,24 +453,30 @@ MapLike: TypeAlias = Map[T, V] | dict[TO, VO]
416
453
 
417
454
  class Set(BuiltinExpr, Generic[T]):
418
455
  @method(preserve=True)
456
+ @deprecated("use .value")
419
457
  def eval(self) -> set[T]:
420
- call = _extract_call(self)
421
- if call.callable != InitRef("Set"):
422
- msg = "Set can only be initialized with the Set constructor."
423
- raise BuiltinEvalError(msg)
424
- return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args}
458
+ return self.value
459
+
460
+ @method(preserve=True) # type: ignore[misc]
461
+ @property
462
+ def value(self) -> set[T]:
463
+ if (args := get_callable_args(self, Set[T])) is not None:
464
+ return set(args)
465
+ raise ExprValueError(self, "Set(*xs)")
466
+
467
+ __match_args__ = ("value",)
425
468
 
426
469
  @method(preserve=True)
427
470
  def __iter__(self) -> Iterator[T]:
428
- return iter(self.eval())
471
+ return iter(self.value)
429
472
 
430
473
  @method(preserve=True)
431
474
  def __len__(self) -> int:
432
- return len(self.eval())
475
+ return len(self.value)
433
476
 
434
477
  @method(preserve=True)
435
478
  def __contains__(self, key: T) -> bool:
436
- return key in self.eval()
479
+ return key in self.value
437
480
 
438
481
  @method(egg_fn="set-of")
439
482
  def __init__(self, *args: T) -> None: ...
@@ -480,24 +523,30 @@ SetLike: TypeAlias = Set[T] | set[TO]
480
523
 
481
524
  class MultiSet(BuiltinExpr, Generic[T]):
482
525
  @method(preserve=True)
526
+ @deprecated("use .value")
483
527
  def eval(self) -> list[T]:
484
- call = _extract_call(self)
485
- if call.callable != InitRef("MultiSet"):
486
- msg = "MultiSet can only be initialized with the MultiSet constructor."
487
- raise BuiltinEvalError(msg)
488
- return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args]
528
+ return self.value
529
+
530
+ @method(preserve=True) # type: ignore[misc]
531
+ @property
532
+ def value(self) -> list[T]:
533
+ if (args := get_callable_args(self, MultiSet[T])) is not None:
534
+ return list(args)
535
+ raise ExprValueError(self, "MultiSet")
536
+
537
+ __match_args__ = ("value",)
489
538
 
490
539
  @method(preserve=True)
491
540
  def __iter__(self) -> Iterator[T]:
492
- return iter(self.eval())
541
+ return iter(self.value)
493
542
 
494
543
  @method(preserve=True)
495
544
  def __len__(self) -> int:
496
- return len(self.eval())
545
+ return len(self.value)
497
546
 
498
547
  @method(preserve=True)
499
548
  def __contains__(self, key: T) -> bool:
500
- return key in self.eval()
549
+ return key in self.value
501
550
 
502
551
  @method(egg_fn="multiset-of")
503
552
  def __init__(self, *args: T) -> None: ...
@@ -529,30 +578,27 @@ class MultiSet(BuiltinExpr, Generic[T]):
529
578
 
530
579
  class Rational(BuiltinExpr):
531
580
  @method(preserve=True)
581
+ @deprecated("use .value")
532
582
  def eval(self) -> Fraction:
533
- call = _extract_call(self)
534
- if call.callable != InitRef("Rational"):
535
- msg = "Rational can only be initialized with the Rational constructor."
536
- raise BuiltinEvalError(msg)
537
-
538
- def _to_int(e: TypedExprDecl) -> int:
539
- expr = e.expr
540
- if not isinstance(expr, LitDecl):
541
- msg = "Rational can only be initialized with literals"
542
- raise BuiltinEvalError(msg)
543
- assert isinstance(expr.value, int)
544
- return expr.value
545
-
546
- num, den = call.args
547
- return Fraction(_to_int(num), _to_int(den))
583
+ return self.value
584
+
585
+ @method(preserve=True) # type: ignore[misc]
586
+ @property
587
+ def value(self) -> Fraction:
588
+ match get_callable_args(self, Rational):
589
+ case (i64(num), i64(den)):
590
+ return Fraction(num, den)
591
+ raise ExprValueError(self, "Rational(i64(num), i64(den))")
592
+
593
+ __match_args__ = ("value",)
548
594
 
549
595
  @method(preserve=True)
550
596
  def __float__(self) -> float:
551
- return float(self.eval())
597
+ return float(self.value)
552
598
 
553
599
  @method(preserve=True)
554
600
  def __int__(self) -> int:
555
- return int(self.eval())
601
+ return int(self.value)
556
602
 
557
603
  @method(egg_fn="rational")
558
604
  def __init__(self, num: i64Like, den: i64Like) -> None: ...
@@ -616,25 +662,27 @@ class Rational(BuiltinExpr):
616
662
 
617
663
  class BigInt(BuiltinExpr):
618
664
  @method(preserve=True)
665
+ @deprecated("use .value")
619
666
  def eval(self) -> int:
620
- call = _extract_call(self)
621
- if call.callable != ClassMethodRef("BigInt", "from_string"):
622
- msg = "BigInt can only be initialized with the BigInt constructor."
623
- raise BuiltinEvalError(msg)
624
- (s,) = call.args
625
- if not isinstance(s.expr, LitDecl):
626
- msg = "BigInt can only be initialized with literals"
627
- raise BuiltinEvalError(msg)
628
- assert isinstance(s.expr.value, str)
629
- return int(s.expr.value)
667
+ return self.value
668
+
669
+ @method(preserve=True) # type: ignore[misc]
670
+ @property
671
+ def value(self) -> int:
672
+ match get_callable_args(self, BigInt.from_string):
673
+ case (String(s),):
674
+ return int(s)
675
+ raise ExprValueError(self, "BigInt.from_string(String(s))")
676
+
677
+ __match_args__ = ("value",)
630
678
 
631
679
  @method(preserve=True)
632
680
  def __index__(self) -> int:
633
- return self.eval()
681
+ return self.value
634
682
 
635
683
  @method(preserve=True)
636
684
  def __int__(self) -> int:
637
- return self.eval()
685
+ return self.value
638
686
 
639
687
  @method(egg_fn="from-string")
640
688
  @classmethod
@@ -741,34 +789,27 @@ BigIntLike: TypeAlias = BigInt | i64Like
741
789
 
742
790
  class BigRat(BuiltinExpr):
743
791
  @method(preserve=True)
792
+ @deprecated("use .value")
744
793
  def eval(self) -> Fraction:
745
- call = _extract_call(self)
746
- if call.callable != InitRef("BigRat"):
747
- msg = "BigRat can only be initialized with the BigRat constructor."
748
- raise BuiltinEvalError(msg)
749
-
750
- def _to_fraction(e: TypedExprDecl) -> Fraction:
751
- expr = e.expr
752
- if not isinstance(expr, CallDecl) or expr.callable != ClassMethodRef("BigInt", "from_string"):
753
- msg = "BigRat can only be initialized BigInt strings"
754
- raise BuiltinEvalError(msg)
755
- (s,) = expr.args
756
- if not isinstance(s.expr, LitDecl):
757
- msg = "BigInt can only be initialized with literals"
758
- raise BuiltinEvalError(msg)
759
- assert isinstance(s.expr.value, str)
760
- return Fraction(s.expr.value)
761
-
762
- num, den = call.args
763
- return Fraction(_to_fraction(num), _to_fraction(den))
794
+ return self.value
795
+
796
+ @method(preserve=True) # type: ignore[misc]
797
+ @property
798
+ def value(self) -> Fraction:
799
+ match get_callable_args(self, BigRat):
800
+ case (BigInt(num), BigInt(den)):
801
+ return Fraction(num, den)
802
+ raise ExprValueError(self, "BigRat(BigInt(num), BigInt(den))")
803
+
804
+ __match_args__ = ("value",)
764
805
 
765
806
  @method(preserve=True)
766
807
  def __float__(self) -> float:
767
- return float(self.eval())
808
+ return float(self.value)
768
809
 
769
810
  @method(preserve=True)
770
811
  def __int__(self) -> int:
771
- return int(self.eval())
812
+ return int(self.value)
772
813
 
773
814
  @method(egg_fn="bigrat")
774
815
  def __init__(self, num: BigIntLike, den: BigIntLike) -> None: ...
@@ -848,27 +889,32 @@ BigRatLike: TypeAlias = BigRat | Fraction
848
889
 
849
890
  class Vec(BuiltinExpr, Generic[T]):
850
891
  @method(preserve=True)
892
+ @deprecated("use .value")
851
893
  def eval(self) -> tuple[T, ...]:
852
- call = _extract_call(self)
853
- if call.callable == ClassMethodRef("Vec", "empty"):
894
+ return self.value
895
+
896
+ @method(preserve=True) # type: ignore[misc]
897
+ @property
898
+ def value(self) -> tuple[T, ...]:
899
+ if get_callable_args(self, Vec.empty) is not None:
854
900
  return ()
901
+ if (args := get_callable_args(self, Vec[T])) is not None:
902
+ return args
903
+ raise ExprValueError(self, "Vec(*xs) or Vec.empty()")
855
904
 
856
- if call.callable != InitRef("Vec"):
857
- msg = "Vec can only be initialized with the Vec constructor."
858
- raise BuiltinEvalError(msg)
859
- return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args)
905
+ __match_args__ = ("value",)
860
906
 
861
907
  @method(preserve=True)
862
908
  def __iter__(self) -> Iterator[T]:
863
- return iter(self.eval())
909
+ return iter(self.value)
864
910
 
865
911
  @method(preserve=True)
866
912
  def __len__(self) -> int:
867
- return len(self.eval())
913
+ return len(self.value)
868
914
 
869
915
  @method(preserve=True)
870
916
  def __contains__(self, key: T) -> bool:
871
- return key in self.eval()
917
+ return key in self.value
872
918
 
873
919
  @method(egg_fn="vec-of")
874
920
  def __init__(self, *args: T) -> None: ...
@@ -922,13 +968,20 @@ VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO]
922
968
 
923
969
  class PyObject(BuiltinExpr):
924
970
  @method(preserve=True)
971
+ @deprecated("use .value")
925
972
  def eval(self) -> object:
973
+ return self.value
974
+
975
+ @method(preserve=True) # type: ignore[misc]
976
+ @property
977
+ def value(self) -> object:
926
978
  expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr
927
979
  if not isinstance(expr, PyObjectDecl):
928
- msg = "PyObject can only be evaluated if it is a PyObject literal"
929
- raise BuiltinEvalError(msg)
980
+ raise ExprValueError(self, "PyObject(x)")
930
981
  return expr.value
931
982
 
983
+ __match_args__ = ("value",)
984
+
932
985
  def __init__(self, value: object) -> None: ...
933
986
 
934
987
  @method(egg_fn="py-from-string")
@@ -1018,6 +1071,23 @@ class UnstableFn(BuiltinExpr, Generic[T, Unpack[TS]]):
1018
1071
  @method(egg_fn="unstable-fn")
1019
1072
  def __init__(self, f, *partial) -> None: ...
1020
1073
 
1074
+ @method(preserve=True)
1075
+ @deprecated("use .value")
1076
+ def eval(self) -> Callable[[Unpack[TS]], T]:
1077
+ return self.value
1078
+
1079
+ @method(preserve=True) # type: ignore[prop-decorator]
1080
+ @property
1081
+ def value(self) -> Callable[[Unpack[TS]], T]:
1082
+ """
1083
+ If this is a constructor, returns either the callable directly or a `functools.partial` function if args are provided.
1084
+ """
1085
+ if (fn := get_literal_value(self)) is not None:
1086
+ return fn
1087
+ raise ExprValueError(self, "UnstableFn(f, *args)")
1088
+
1089
+ __match_args__ = ("value",)
1090
+
1021
1091
  @method(egg_fn="unstable-app")
1022
1092
  def __call__(self, *args: Unpack[TS]) -> T: ...
1023
1093
 
@@ -1028,57 +1098,36 @@ converter(RuntimeFunction, UnstableFn, UnstableFn)
1028
1098
  converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
1029
1099
 
1030
1100
 
1031
- def _convert_function(a: FunctionType) -> UnstableFn:
1101
+ def _convert_function(fn: FunctionType) -> UnstableFn:
1032
1102
  """
1033
- Converts a function type to an unstable function
1103
+ Converts a function type to an unstable function. This function will be an anon function in egglog.
1104
+
1105
+ Would just be UnstableFn(function(a)) but we have to account for unbound vars within the body.
1034
1106
 
1035
- Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals
1036
- which are runtime expressions with `var`s in them and add them as args to the function
1107
+ This means that we have to turn all of those unbound vars into args to the function, and then
1108
+ partially apply them, alongside creating a default rewrite for the function.
1037
1109
  """
1038
- # Update annotations of a to be the type we are trying to convert to
1039
- return_tp, *arg_tps = get_type_args()
1040
- a.__annotations__ = {
1041
- "return": return_tp,
1042
- # The first varnames should always be the arg names
1043
- **dict(zip(a.__code__.co_varnames, arg_tps, strict=False)),
1044
- }
1045
- # Modify name to make it unique
1046
- # a.__name__ = f"{a.__name__} {hash(a.__code__)}"
1047
- transformed_fn = functionalize(a, value_to_annotation)
1048
- assert isinstance(transformed_fn, partial)
1049
- return UnstableFn(
1050
- function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func),
1051
- *transformed_fn.args,
1110
+ decls = Declarations()
1111
+ return_type, *arg_types = [resolve_type_annotation_mutate(decls, tp) for tp in get_type_args()]
1112
+ arg_names = [p.name for p in signature(fn).parameters.values()]
1113
+ arg_decls = [
1114
+ TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True)
1115
+ ]
1116
+ res = resolve_literal(
1117
+ return_type, fn(*(RuntimeExpr.__from_values__(decls, a) for a in arg_decls)), Thunk.value(decls)
1052
1118
  )
1119
+ res_expr = res.__egg_typed_expr__
1120
+ decls |= res
1121
+ # these are all the args that appear in the body that are not bound by the args of the function
1122
+ unbound_vars = list(collect_unbound_vars(res_expr) - set(arg_decls))
1123
+ # prefix the args with them
1124
+ fn_ref = UnnamedFunctionRef(tuple(unbound_vars + arg_decls), res_expr)
1125
+ rewrite_decl = DefaultRewriteDecl(fn_ref, res_expr.expr, subsume=True)
1126
+ ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, get_current_ruleset())
1127
+ ruleset_decls |= res
1053
1128
 
1054
-
1055
- def value_to_annotation(a: object) -> type | None:
1056
- # only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
1057
- if not isinstance(a, RuntimeExpr):
1058
- return None
1059
- return cast("type", RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
1129
+ fn = RuntimeFunction(Thunk.value(decls), Thunk.value(fn_ref))
1130
+ return UnstableFn(fn, *(RuntimeExpr.__from_values__(decls, v) for v in unbound_vars))
1060
1131
 
1061
1132
 
1062
1133
  converter(FunctionType, UnstableFn, _convert_function)
1063
-
1064
-
1065
- def _extract_lit(e: BaseExpr) -> LitType:
1066
- """
1067
- Special case extracting literals to make this faster by using termdag directly.
1068
- """
1069
- expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1070
- if not isinstance(expr, LitDecl):
1071
- msg = "Expected a literal"
1072
- raise BuiltinEvalError(msg)
1073
- return expr.value
1074
-
1075
-
1076
- def _extract_call(e: BaseExpr) -> CallDecl:
1077
- """
1078
- Extracts the call form of an expression
1079
- """
1080
- expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1081
- if not isinstance(expr, CallDecl):
1082
- msg = "Expected a call expression"
1083
- raise BuiltinEvalError(msg)
1084
- return expr