egglog 7.1.0__cp312-none-win_amd64.whl → 8.0.0__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Binary file
egglog/bindings.pyi CHANGED
@@ -14,6 +14,7 @@ class SerializedEGraph:
14
14
  def to_dot(self) -> str: ...
15
15
  def to_json(self) -> str: ...
16
16
  def map_ops(self, map: dict[str, str]) -> None: ...
17
+ def split_e_classes(self, egraph: EGraph, ops: set[str]) -> None: ...
17
18
 
18
19
  @final
19
20
  class PyObjectSort:
@@ -32,7 +33,7 @@ class EGraph:
32
33
  record: bool = False,
33
34
  ) -> None: ...
34
35
  def commands(self) -> str | None: ...
35
- def parse_program(self, __input: str, /) -> list[_Command]: ...
36
+ def parse_program(self, __input: str, /, filename: str | None = None) -> list[_Command]: ...
36
37
  def run_program(self, *commands: _Command) -> list[str]: ...
37
38
  def extract_report(self) -> _ExtractReport | None: ...
38
39
  def run_report(self) -> RunReport | None: ...
@@ -43,7 +44,6 @@ class EGraph:
43
44
  max_functions: int | None = None,
44
45
  max_calls_per_function: int | None = None,
45
46
  include_temporary_functions: bool = False,
46
- split_primitive_outputs: bool = False,
47
47
  ) -> SerializedEGraph: ...
48
48
  def eval_py_object(self, __expr: _Expr) -> object: ...
49
49
  def eval_i64(self, __expr: _Expr) -> int: ...
@@ -56,6 +56,25 @@ class EGraph:
56
56
  class EggSmolError(Exception):
57
57
  context: str
58
58
 
59
+ ##
60
+ # Spans
61
+ ##
62
+
63
+ @final
64
+ class SrcFile:
65
+ def __init__(self, name: str, contents: str | None = None) -> None: ...
66
+ name: str
67
+ contents: str | None
68
+
69
+ @final
70
+ class Span:
71
+ def __init__(self, file: SrcFile, start: int, end: int) -> None: ...
72
+ file: SrcFile
73
+ start: int
74
+ end: int
75
+
76
+ DUMMY_SPAN: Span = ...
77
+
59
78
  ##
60
79
  # Literals
61
80
  ##
@@ -92,17 +111,20 @@ _Literal: TypeAlias = Int | F64 | String | Bool | Unit
92
111
 
93
112
  @final
94
113
  class Lit:
95
- def __init__(self, value: _Literal) -> None: ...
114
+ def __init__(self, span: Span, value: _Literal) -> None: ...
115
+ span: Span
96
116
  value: _Literal
97
117
 
98
118
  @final
99
119
  class Var:
100
- def __init__(self, name: str) -> None: ...
120
+ def __init__(self, span: Span, name: str) -> None: ...
121
+ span: Span
101
122
  name: str
102
123
 
103
124
  @final
104
125
  class Call:
105
- def __init__(self, name: str, args: list[_Expr]) -> None: ...
126
+ def __init__(self, span: Span, name: str, args: list[_Expr]) -> None: ...
127
+ span: Span
106
128
  name: str
107
129
  args: list[_Expr]
108
130
 
@@ -142,7 +164,8 @@ class TermDag:
142
164
 
143
165
  @final
144
166
  class Eq:
145
- def __init__(self, exprs: list[_Expr]) -> None: ...
167
+ def __init__(self, span: Span, exprs: list[_Expr]) -> None: ...
168
+ span: Span
146
169
  exprs: list[_Expr]
147
170
 
148
171
  @final
@@ -172,43 +195,50 @@ _Change: TypeAlias = Delete | Subsume
172
195
 
173
196
  @final
174
197
  class Let:
175
- def __init__(self, lhs: str, rhs: _Expr) -> None: ...
198
+ def __init__(self, span: Span, lhs: str, rhs: _Expr) -> None: ...
199
+ span: Span
176
200
  lhs: str
177
201
  rhs: _Expr
178
202
 
179
203
  @final
180
204
  class Set:
181
- def __init__(self, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
205
+ def __init__(self, span: Span, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
206
+ span: Span
182
207
  lhs: str
183
208
  args: list[_Expr]
184
209
  rhs: _Expr
185
210
 
186
211
  @final
187
212
  class Change:
213
+ span: Span
188
214
  change: _Change
189
215
  sym: str
190
216
  args: list[_Expr]
191
- def __init__(self, change: _Change, sym: str, args: list[_Expr]) -> None: ...
217
+ def __init__(self, span: Span, change: _Change, sym: str, args: list[_Expr]) -> None: ...
192
218
 
193
219
  @final
194
220
  class Union:
195
- def __init__(self, lhs: _Expr, rhs: _Expr) -> None: ...
221
+ def __init__(self, span: Span, lhs: _Expr, rhs: _Expr) -> None: ...
222
+ span: Span
196
223
  lhs: _Expr
197
224
  rhs: _Expr
198
225
 
199
226
  @final
200
227
  class Panic:
201
- def __init__(self, msg: str) -> None: ...
228
+ def __init__(self, span: Span, msg: str) -> None: ...
229
+ span: Span
202
230
  msg: str
203
231
 
204
232
  @final
205
233
  class Expr_: # noqa: N801
206
- def __init__(self, expr: _Expr) -> None: ...
234
+ def __init__(self, span: Span, expr: _Expr) -> None: ...
235
+ span: Span
207
236
  expr: _Expr
208
237
 
209
238
  @final
210
239
  class Extract:
211
- def __init__(self, expr: _Expr, variants: _Expr) -> None: ...
240
+ def __init__(self, span: Span, expr: _Expr, variants: _Expr) -> None: ...
241
+ span: Span
212
242
  expr: _Expr
213
243
  variants: _Expr
214
244
 
@@ -256,17 +286,19 @@ class Schema:
256
286
 
257
287
  @final
258
288
  class Rule:
289
+ span: Span
259
290
  head: list[_Action]
260
291
  body: list[_Fact]
261
- def __init__(self, head: list[_Action], body: list[_Fact]) -> None: ...
292
+ def __init__(self, span: Span, head: list[_Action], body: list[_Fact]) -> None: ...
262
293
 
263
294
  @final
264
295
  class Rewrite:
296
+ span: Span
265
297
  lhs: _Expr
266
298
  rhs: _Expr
267
299
  conditions: list[_Fact]
268
300
 
269
- def __init__(self, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... # noqa: B006
301
+ def __init__(self, span: Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... # noqa: B006
270
302
 
271
303
  @final
272
304
  class RunConfig:
@@ -322,24 +354,28 @@ _ExtractReport: TypeAlias = Variants | Best
322
354
 
323
355
  @final
324
356
  class Saturate:
357
+ span: Span
325
358
  schedule: _Schedule
326
- def __init__(self, schedule: _Schedule) -> None: ...
359
+ def __init__(self, span: Span, schedule: _Schedule) -> None: ...
327
360
 
328
361
  @final
329
362
  class Repeat:
363
+ span: Span
330
364
  length: int
331
365
  schedule: _Schedule
332
- def __init__(self, length: int, schedule: _Schedule) -> None: ...
366
+ def __init__(self, span: Span, length: int, schedule: _Schedule) -> None: ...
333
367
 
334
368
  @final
335
369
  class Run:
370
+ span: Span
336
371
  config: RunConfig
337
- def __init__(self, config: RunConfig) -> None: ...
372
+ def __init__(self, span: Span, config: RunConfig) -> None: ...
338
373
 
339
374
  @final
340
375
  class Sequence:
376
+ span: Span
341
377
  schedules: list[_Schedule]
342
- def __init__(self, schedules: list[_Schedule]) -> None: ...
378
+ def __init__(self, span: Span, schedules: list[_Schedule]) -> None: ...
343
379
 
344
380
  _Schedule: TypeAlias = Saturate | Repeat | Run | Sequence
345
381
 
@@ -361,9 +397,10 @@ class Datatype:
361
397
 
362
398
  @final
363
399
  class Declare:
400
+ span: Span
364
401
  name: str
365
402
  sort: str
366
- def __init__(self, name: str, sort: str) -> None: ...
403
+ def __init__(self, span: Span, name: str, sort: str) -> None: ...
367
404
 
368
405
  @final
369
406
  class Sort:
@@ -421,9 +458,10 @@ class Simplify:
421
458
 
422
459
  @final
423
460
  class Calc:
461
+ span: Span
424
462
  identifiers: list[IdentSort]
425
463
  exprs: list[_Expr]
426
- def __init__(self, identifiers: list[IdentSort], exprs: list[_Expr]) -> None: ...
464
+ def __init__(self, span: Span, identifiers: list[IdentSort], exprs: list[_Expr]) -> None: ...
427
465
 
428
466
  @final
429
467
  class QueryExtract:
@@ -433,14 +471,16 @@ class QueryExtract:
433
471
 
434
472
  @final
435
473
  class Check:
474
+ span: Span
436
475
  facts: list[_Fact]
437
- def __init__(self, facts: list[_Fact]) -> None: ...
476
+ def __init__(self, span: Span, facts: list[_Fact]) -> None: ...
438
477
 
439
478
  @final
440
479
  class PrintFunction:
480
+ span: Span
441
481
  name: str
442
482
  length: int
443
- def __init__(self, name: str, length: int) -> None: ...
483
+ def __init__(self, span: Span, name: str, length: int) -> None: ...
444
484
 
445
485
  @final
446
486
  class PrintSize:
egglog/builtins.py CHANGED
@@ -6,17 +6,21 @@ Builtin sorts and function to egg.
6
6
  from __future__ import annotations
7
7
 
8
8
  from functools import partial
9
- from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, overload
9
+ from types import FunctionType
10
+ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload
10
11
 
11
12
  from typing_extensions import TypeVarTuple, Unpack
12
13
 
13
- from .conversion import converter
14
- from .egraph import Expr, Unit, function, method
15
- from .runtime import RuntimeFunction
14
+ from .conversion import converter, get_type_args
15
+ from .egraph import Expr, Unit, function, get_current_ruleset, method
16
+ from .functionalize import functionalize
17
+ from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
18
+ from .thunk import Thunk
16
19
 
17
20
  if TYPE_CHECKING:
18
21
  from collections.abc import Callable
19
22
 
23
+
20
24
  __all__ = [
21
25
  "i64",
22
26
  "i64Like",
@@ -80,7 +84,7 @@ class Bool(Expr, egg_sort="bool", builtin=True):
80
84
  converter(bool, Bool, Bool)
81
85
 
82
86
  # The types which can be converted into an i64
83
- i64Like = Union["i64", int] # noqa: N816
87
+ i64Like: TypeAlias = Union["i64", int] # noqa: N816, PYI042
84
88
 
85
89
 
86
90
  class i64(Expr, builtin=True): # noqa: N801
@@ -182,7 +186,7 @@ converter(int, i64, i64)
182
186
  def count_matches(s: StringLike, pattern: StringLike) -> i64: ...
183
187
 
184
188
 
185
- f64Like = Union["f64", float] # noqa: N816
189
+ f64Like: TypeAlias = Union["f64", float] # noqa: N816, PYI042
186
190
 
187
191
 
188
192
  class f64(Expr, builtin=True): # noqa: N801
@@ -404,6 +408,12 @@ class Vec(Expr, Generic[T], builtin=True):
404
408
  @method(egg_fn="rebuild")
405
409
  def rebuild(self) -> Vec[T]: ...
406
410
 
411
+ @method(egg_fn="vec-remove")
412
+ def remove(self, index: i64Like) -> Vec[T]: ...
413
+
414
+ @method(egg_fn="vec-set")
415
+ def set(self, index: i64Like, value: T) -> Vec[T]: ...
416
+
407
417
 
408
418
  class PyObject(Expr, builtin=True):
409
419
  def __init__(self, value: object) -> None: ...
@@ -501,3 +511,36 @@ class UnstableFn(Expr, Generic[T, Unpack[TS]], builtin=True):
501
511
 
502
512
  converter(RuntimeFunction, UnstableFn, UnstableFn)
503
513
  converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
514
+
515
+
516
+ def _convert_function(a: FunctionType) -> UnstableFn:
517
+ """
518
+ Converts a function type to an unstable function
519
+
520
+ Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals
521
+ which are runtime expressions with `var`s in them and add them as args to the function
522
+ """
523
+ # Update annotations of a to be the type we are trying to convert to
524
+ return_tp, *arg_tps = get_type_args()
525
+ a.__annotations__ = {
526
+ "return": return_tp,
527
+ # The first varnames should always be the arg names
528
+ **dict(zip(a.__code__.co_varnames, arg_tps, strict=False)),
529
+ }
530
+ # Modify name to make it unique
531
+ # a.__name__ = f"{a.__name__} {hash(a.__code__)}"
532
+ transformed_fn = functionalize(a, value_to_annotation)
533
+ assert isinstance(transformed_fn, partial)
534
+ return UnstableFn(
535
+ function(ruleset=get_current_ruleset(), use_body_as_name=True)(transformed_fn.func), *transformed_fn.args
536
+ )
537
+
538
+
539
+ def value_to_annotation(a: object) -> type | None:
540
+ # only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
541
+ if not isinstance(a, RuntimeExpr):
542
+ return None
543
+ return cast(type, RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
544
+
545
+
546
+ converter(FunctionType, UnstableFn, _convert_function)
egglog/conversion.py CHANGED
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from contextlib import contextmanager
4
+ from contextvars import ContextVar
3
5
  from dataclasses import dataclass
4
6
  from typing import TYPE_CHECKING, NewType, TypeVar, cast
5
7
 
@@ -9,9 +11,8 @@ from .runtime import *
9
11
  from .thunk import *
10
12
 
11
13
  if TYPE_CHECKING:
12
- from collections.abc import Callable
14
+ from collections.abc import Callable, Generator
13
15
 
14
- from .declarations import HasDeclerations
15
16
  from .egraph import Expr
16
17
 
17
18
  __all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"]
@@ -84,7 +85,7 @@ def convert(source: object, target: type[V]) -> V:
84
85
  Convert a source object to a target type.
85
86
  """
86
87
  assert isinstance(target, RuntimeClass)
87
- return cast(V, resolve_literal(target.__egg_tp__, source))
88
+ return cast(V, resolve_literal(target.__egg_tp__, source, target.__egg_decls_thunk__))
88
89
 
89
90
 
90
91
  def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
@@ -92,7 +93,7 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
92
93
  Convert a source object to the same type as the target.
93
94
  """
94
95
  tp = target.__egg_typed_expr__.tp
95
- return resolve_literal(tp.to_var(), source)
96
+ return resolve_literal(tp.to_var(), source, Thunk.value(target.__egg_decls__))
96
97
 
97
98
 
98
99
  def process_tp(tp: type | RuntimeClass) -> TypeName | type:
@@ -140,7 +141,28 @@ def identity(x: object) -> object:
140
141
  return x
141
142
 
142
143
 
143
- def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
144
+ TYPE_ARGS = ContextVar[tuple[RuntimeClass, ...]]("TYPE_ARGS")
145
+
146
+
147
+ def get_type_args() -> tuple[RuntimeClass, ...]:
148
+ """
149
+ Get the type args for the type being converted.
150
+ """
151
+ return TYPE_ARGS.get()
152
+
153
+
154
+ @contextmanager
155
+ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declarations]) -> Generator[None, None, None]:
156
+ token = TYPE_ARGS.set(tuple(RuntimeClass(decls, a.to_var()) for a in args))
157
+ try:
158
+ yield
159
+ finally:
160
+ TYPE_ARGS.reset(token)
161
+
162
+
163
+ def resolve_literal(
164
+ tp: TypeOrVarRef, arg: object, decls: Callable[[], Declarations] = CONVERSIONS_DECLS
165
+ ) -> RuntimeExpr:
144
166
  arg_type = _get_tp(arg)
145
167
 
146
168
  # If we have any type variables, don't bother trying to resolve the literal, just return the arg
@@ -148,7 +170,7 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
148
170
  tp_just = tp.to_just()
149
171
  except NotImplementedError:
150
172
  # If this is a var, it has to be a runtime expression
151
- assert isinstance(arg, RuntimeExpr)
173
+ assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
152
174
  return arg
153
175
  tp_name = TypeName(tp_just.name)
154
176
  if arg_type == tp_name:
@@ -158,13 +180,14 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
158
180
  # Try all parent types as well, if we are converting from a Python type
159
181
  for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
160
182
  try:
161
- fn = CONVERSIONS[(cast(TypeName | type, arg_type_instance), tp_name)][1]
183
+ fn = CONVERSIONS[(arg_type_instance, tp_name)][1]
162
184
  except KeyError:
163
185
  continue
164
186
  break
165
187
  else:
166
188
  raise ConvertError(f"Cannot convert {arg_type} to {tp_name}")
167
- return fn(arg)
189
+ with with_type_args(tp_just.args, decls):
190
+ return fn(arg)
168
191
 
169
192
 
170
193
  def _get_tp(x: object) -> TypeName | type: