egglog 7.2.0 → 8.0.1 (egglog-7.2.0-cp310-none-win_amd64.whl → egglog-8.0.1-cp310-none-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic; see the package's registry page for more details.
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +107 -53
- egglog/builtins.py +49 -6
- egglog/conversion.py +32 -9
- egglog/declarations.py +82 -4
- egglog/egraph.py +260 -179
- egglog/egraph_state.py +149 -66
- egglog/examples/higher_order_functions.py +4 -9
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +4 -8
- egglog/exp/array_api_loopnest.py +149 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +62 -25
- egglog/exp/program_gen.py +23 -17
- egglog/functionalize.py +91 -0
- egglog/ipython_magic.py +1 -1
- egglog/pretty.py +88 -44
- egglog/runtime.py +53 -40
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35774 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/METADATA +33 -32
- egglog-8.0.1.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.1.dist-info/licenses}/LICENSE +0 -0
|
Binary file
|
egglog/bindings.pyi
CHANGED
|
@@ -5,8 +5,6 @@ from typing import TypeAlias
|
|
|
5
5
|
|
|
6
6
|
from typing_extensions import final
|
|
7
7
|
|
|
8
|
-
HIGH_COST: int
|
|
9
|
-
|
|
10
8
|
@final
|
|
11
9
|
class SerializedEGraph:
|
|
12
10
|
def inline_leaves(self) -> None: ...
|
|
@@ -14,12 +12,14 @@ class SerializedEGraph:
|
|
|
14
12
|
def to_dot(self) -> str: ...
|
|
15
13
|
def to_json(self) -> str: ...
|
|
16
14
|
def map_ops(self, map: dict[str, str]) -> None: ...
|
|
15
|
+
def split_classes(self, egraph: EGraph, ops: set[str]) -> None: ...
|
|
17
16
|
|
|
18
17
|
@final
|
|
19
18
|
class PyObjectSort:
|
|
20
19
|
def __init__(self) -> None: ...
|
|
21
20
|
def store(self, __o: object, /) -> _Expr: ...
|
|
22
21
|
|
|
22
|
+
def parse_program(__input: str, /, filename: str | None = None) -> list[_Command]: ...
|
|
23
23
|
@final
|
|
24
24
|
class EGraph:
|
|
25
25
|
def __init__(
|
|
@@ -28,11 +28,9 @@ class EGraph:
|
|
|
28
28
|
*,
|
|
29
29
|
fact_directory: str | Path | None = None,
|
|
30
30
|
seminaive: bool = True,
|
|
31
|
-
terms_encoding: bool = False,
|
|
32
31
|
record: bool = False,
|
|
33
32
|
) -> None: ...
|
|
34
33
|
def commands(self) -> str | None: ...
|
|
35
|
-
def parse_program(self, __input: str, /) -> list[_Command]: ...
|
|
36
34
|
def run_program(self, *commands: _Command) -> list[str]: ...
|
|
37
35
|
def extract_report(self) -> _ExtractReport | None: ...
|
|
38
36
|
def run_report(self) -> RunReport | None: ...
|
|
@@ -43,7 +41,6 @@ class EGraph:
|
|
|
43
41
|
max_functions: int | None = None,
|
|
44
42
|
max_calls_per_function: int | None = None,
|
|
45
43
|
include_temporary_functions: bool = False,
|
|
46
|
-
split_primitive_outputs: bool = False,
|
|
47
44
|
) -> SerializedEGraph: ...
|
|
48
45
|
def eval_py_object(self, __expr: _Expr) -> object: ...
|
|
49
46
|
def eval_i64(self, __expr: _Expr) -> int: ...
|
|
@@ -56,6 +53,25 @@ class EGraph:
|
|
|
56
53
|
class EggSmolError(Exception):
|
|
57
54
|
context: str
|
|
58
55
|
|
|
56
|
+
##
|
|
57
|
+
# Spans
|
|
58
|
+
##
|
|
59
|
+
|
|
60
|
+
@final
|
|
61
|
+
class SrcFile:
|
|
62
|
+
def __init__(self, name: str, contents: str | None = None) -> None: ...
|
|
63
|
+
name: str
|
|
64
|
+
contents: str | None
|
|
65
|
+
|
|
66
|
+
@final
|
|
67
|
+
class Span:
|
|
68
|
+
def __init__(self, file: SrcFile, start: int, end: int) -> None: ...
|
|
69
|
+
file: SrcFile
|
|
70
|
+
start: int
|
|
71
|
+
end: int
|
|
72
|
+
|
|
73
|
+
DUMMY_SPAN: Span = ...
|
|
74
|
+
|
|
59
75
|
##
|
|
60
76
|
# Literals
|
|
61
77
|
##
|
|
@@ -92,17 +108,20 @@ _Literal: TypeAlias = Int | F64 | String | Bool | Unit
|
|
|
92
108
|
|
|
93
109
|
@final
|
|
94
110
|
class Lit:
|
|
95
|
-
def __init__(self, value: _Literal) -> None: ...
|
|
111
|
+
def __init__(self, span: Span, value: _Literal) -> None: ...
|
|
112
|
+
span: Span
|
|
96
113
|
value: _Literal
|
|
97
114
|
|
|
98
115
|
@final
|
|
99
116
|
class Var:
|
|
100
|
-
def __init__(self, name: str) -> None: ...
|
|
117
|
+
def __init__(self, span: Span, name: str) -> None: ...
|
|
118
|
+
span: Span
|
|
101
119
|
name: str
|
|
102
120
|
|
|
103
121
|
@final
|
|
104
122
|
class Call:
|
|
105
|
-
def __init__(self, name: str, args: list[_Expr]) -> None: ...
|
|
123
|
+
def __init__(self, span: Span, name: str, args: list[_Expr]) -> None: ...
|
|
124
|
+
span: Span
|
|
106
125
|
name: str
|
|
107
126
|
args: list[_Expr]
|
|
108
127
|
|
|
@@ -142,7 +161,8 @@ class TermDag:
|
|
|
142
161
|
|
|
143
162
|
@final
|
|
144
163
|
class Eq:
|
|
145
|
-
def __init__(self, exprs: list[_Expr]) -> None: ...
|
|
164
|
+
def __init__(self, span: Span, exprs: list[_Expr]) -> None: ...
|
|
165
|
+
span: Span
|
|
146
166
|
exprs: list[_Expr]
|
|
147
167
|
|
|
148
168
|
@final
|
|
@@ -172,43 +192,50 @@ _Change: TypeAlias = Delete | Subsume
|
|
|
172
192
|
|
|
173
193
|
@final
|
|
174
194
|
class Let:
|
|
175
|
-
def __init__(self, lhs: str, rhs: _Expr) -> None: ...
|
|
195
|
+
def __init__(self, span: Span, lhs: str, rhs: _Expr) -> None: ...
|
|
196
|
+
span: Span
|
|
176
197
|
lhs: str
|
|
177
198
|
rhs: _Expr
|
|
178
199
|
|
|
179
200
|
@final
|
|
180
201
|
class Set:
|
|
181
|
-
def __init__(self, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
|
|
202
|
+
def __init__(self, span: Span, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
|
|
203
|
+
span: Span
|
|
182
204
|
lhs: str
|
|
183
205
|
args: list[_Expr]
|
|
184
206
|
rhs: _Expr
|
|
185
207
|
|
|
186
208
|
@final
|
|
187
209
|
class Change:
|
|
210
|
+
span: Span
|
|
188
211
|
change: _Change
|
|
189
212
|
sym: str
|
|
190
213
|
args: list[_Expr]
|
|
191
|
-
def __init__(self, change: _Change, sym: str, args: list[_Expr]) -> None: ...
|
|
214
|
+
def __init__(self, span: Span, change: _Change, sym: str, args: list[_Expr]) -> None: ...
|
|
192
215
|
|
|
193
216
|
@final
|
|
194
217
|
class Union:
|
|
195
|
-
def __init__(self, lhs: _Expr, rhs: _Expr) -> None: ...
|
|
218
|
+
def __init__(self, span: Span, lhs: _Expr, rhs: _Expr) -> None: ...
|
|
219
|
+
span: Span
|
|
196
220
|
lhs: _Expr
|
|
197
221
|
rhs: _Expr
|
|
198
222
|
|
|
199
223
|
@final
|
|
200
224
|
class Panic:
|
|
201
|
-
def __init__(self, msg: str) -> None: ...
|
|
225
|
+
def __init__(self, span: Span, msg: str) -> None: ...
|
|
226
|
+
span: Span
|
|
202
227
|
msg: str
|
|
203
228
|
|
|
204
229
|
@final
|
|
205
230
|
class Expr_: # noqa: N801
|
|
206
|
-
def __init__(self, expr: _Expr) -> None: ...
|
|
231
|
+
def __init__(self, span: Span, expr: _Expr) -> None: ...
|
|
232
|
+
span: Span
|
|
207
233
|
expr: _Expr
|
|
208
234
|
|
|
209
235
|
@final
|
|
210
236
|
class Extract:
|
|
211
|
-
def __init__(self, expr: _Expr, variants: _Expr) -> None: ...
|
|
237
|
+
def __init__(self, span: Span, expr: _Expr, variants: _Expr) -> None: ...
|
|
238
|
+
span: Span
|
|
212
239
|
expr: _Expr
|
|
213
240
|
variants: _Expr
|
|
214
241
|
|
|
@@ -220,6 +247,7 @@ _Action: TypeAlias = Let | Set | Change | Union | Panic | Expr_ | Extract
|
|
|
220
247
|
|
|
221
248
|
@final
|
|
222
249
|
class FunctionDecl:
|
|
250
|
+
span: Span
|
|
223
251
|
name: str
|
|
224
252
|
schema: Schema
|
|
225
253
|
default: _Expr | None
|
|
@@ -231,6 +259,7 @@ class FunctionDecl:
|
|
|
231
259
|
|
|
232
260
|
def __init__(
|
|
233
261
|
self,
|
|
262
|
+
span: Span,
|
|
234
263
|
name: str,
|
|
235
264
|
schema: Schema,
|
|
236
265
|
default: _Expr | None = None,
|
|
@@ -243,7 +272,8 @@ class FunctionDecl:
|
|
|
243
272
|
|
|
244
273
|
@final
|
|
245
274
|
class Variant:
|
|
246
|
-
def __init__(self, name: str, types: list[str], cost: int | None = None) -> None: ...
|
|
275
|
+
def __init__(self, span: Span, name: str, types: list[str], cost: int | None = None) -> None: ...
|
|
276
|
+
span: Span
|
|
247
277
|
name: str
|
|
248
278
|
types: list[str]
|
|
249
279
|
cost: int | None
|
|
@@ -256,17 +286,19 @@ class Schema:
|
|
|
256
286
|
|
|
257
287
|
@final
|
|
258
288
|
class Rule:
|
|
289
|
+
span: Span
|
|
259
290
|
head: list[_Action]
|
|
260
291
|
body: list[_Fact]
|
|
261
|
-
def __init__(self, head: list[_Action], body: list[_Fact]) -> None: ...
|
|
292
|
+
def __init__(self, span: Span, head: list[_Action], body: list[_Fact]) -> None: ...
|
|
262
293
|
|
|
263
294
|
@final
|
|
264
295
|
class Rewrite:
|
|
296
|
+
span: Span
|
|
265
297
|
lhs: _Expr
|
|
266
298
|
rhs: _Expr
|
|
267
299
|
conditions: list[_Fact]
|
|
268
300
|
|
|
269
|
-
def __init__(self, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... # noqa: B006
|
|
301
|
+
def __init__(self, span: Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... # noqa: B006
|
|
270
302
|
|
|
271
303
|
@final
|
|
272
304
|
class RunConfig:
|
|
@@ -322,27 +354,48 @@ _ExtractReport: TypeAlias = Variants | Best
|
|
|
322
354
|
|
|
323
355
|
@final
|
|
324
356
|
class Saturate:
|
|
357
|
+
span: Span
|
|
325
358
|
schedule: _Schedule
|
|
326
|
-
def __init__(self, schedule: _Schedule) -> None: ...
|
|
359
|
+
def __init__(self, span: Span, schedule: _Schedule) -> None: ...
|
|
327
360
|
|
|
328
361
|
@final
|
|
329
362
|
class Repeat:
|
|
363
|
+
span: Span
|
|
330
364
|
length: int
|
|
331
365
|
schedule: _Schedule
|
|
332
|
-
def __init__(self, length: int, schedule: _Schedule) -> None: ...
|
|
366
|
+
def __init__(self, span: Span, length: int, schedule: _Schedule) -> None: ...
|
|
333
367
|
|
|
334
368
|
@final
|
|
335
369
|
class Run:
|
|
370
|
+
span: Span
|
|
336
371
|
config: RunConfig
|
|
337
|
-
def __init__(self, config: RunConfig) -> None: ...
|
|
372
|
+
def __init__(self, span: Span, config: RunConfig) -> None: ...
|
|
338
373
|
|
|
339
374
|
@final
|
|
340
375
|
class Sequence:
|
|
376
|
+
span: Span
|
|
341
377
|
schedules: list[_Schedule]
|
|
342
|
-
def __init__(self, schedules: list[_Schedule]) -> None: ...
|
|
378
|
+
def __init__(self, span: Span, schedules: list[_Schedule]) -> None: ...
|
|
343
379
|
|
|
344
380
|
_Schedule: TypeAlias = Saturate | Repeat | Run | Sequence
|
|
345
381
|
|
|
382
|
+
##
|
|
383
|
+
# Subdatatypes
|
|
384
|
+
##
|
|
385
|
+
|
|
386
|
+
@final
|
|
387
|
+
class SubVariants:
|
|
388
|
+
def __init__(self, variants: list[Variant]) -> None: ...
|
|
389
|
+
variants: list[Variant]
|
|
390
|
+
|
|
391
|
+
@final
|
|
392
|
+
class NewSort:
|
|
393
|
+
def __init__(self, name: str, args: list[_Expr]) -> None: ...
|
|
394
|
+
name: str
|
|
395
|
+
args: list[_Expr]
|
|
396
|
+
|
|
397
|
+
_Subdatatypes: TypeAlias = SubVariants | NewSort
|
|
398
|
+
|
|
346
399
|
##
|
|
347
400
|
# Commands
|
|
348
401
|
##
|
|
@@ -355,21 +408,23 @@ class SetOption:
|
|
|
355
408
|
|
|
356
409
|
@final
|
|
357
410
|
class Datatype:
|
|
411
|
+
span: Span
|
|
358
412
|
name: str
|
|
359
413
|
variants: list[Variant]
|
|
360
|
-
def __init__(self, name: str, variants: list[Variant]) -> None: ...
|
|
414
|
+
def __init__(self, span: Span, name: str, variants: list[Variant]) -> None: ...
|
|
361
415
|
|
|
362
416
|
@final
|
|
363
|
-
class
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
def __init__(self,
|
|
417
|
+
class Datatypes:
|
|
418
|
+
span: Span
|
|
419
|
+
datatypes: list[tuple[Span, str, _Subdatatypes]]
|
|
420
|
+
def __init__(self, span: Span, datatypes: list[tuple[Span, str, _Subdatatypes]]) -> None: ...
|
|
367
421
|
|
|
368
422
|
@final
|
|
369
423
|
class Sort:
|
|
424
|
+
span: Span
|
|
370
425
|
name: str
|
|
371
426
|
presort_and_args: tuple[str, list[_Expr]] | None
|
|
372
|
-
def __init__(self, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
|
|
427
|
+
def __init__(self, span: Span, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
|
|
373
428
|
|
|
374
429
|
@final
|
|
375
430
|
class Function:
|
|
@@ -415,49 +470,50 @@ class RunSchedule:
|
|
|
415
470
|
|
|
416
471
|
@final
|
|
417
472
|
class Simplify:
|
|
473
|
+
span: Span
|
|
418
474
|
expr: _Expr
|
|
419
475
|
schedule: _Schedule
|
|
420
|
-
def __init__(self, expr: _Expr, schedule: _Schedule) -> None: ...
|
|
421
|
-
|
|
422
|
-
@final
|
|
423
|
-
class Calc:
|
|
424
|
-
identifiers: list[IdentSort]
|
|
425
|
-
exprs: list[_Expr]
|
|
426
|
-
def __init__(self, identifiers: list[IdentSort], exprs: list[_Expr]) -> None: ...
|
|
476
|
+
def __init__(self, span: Span, expr: _Expr, schedule: _Schedule) -> None: ...
|
|
427
477
|
|
|
428
478
|
@final
|
|
429
479
|
class QueryExtract:
|
|
480
|
+
span: Span
|
|
430
481
|
variants: int
|
|
431
482
|
expr: _Expr
|
|
432
|
-
def __init__(self, variants: int, expr: _Expr) -> None: ...
|
|
483
|
+
def __init__(self, span: Span, variants: int, expr: _Expr) -> None: ...
|
|
433
484
|
|
|
434
485
|
@final
|
|
435
486
|
class Check:
|
|
487
|
+
span: Span
|
|
436
488
|
facts: list[_Fact]
|
|
437
|
-
def __init__(self, facts: list[_Fact]) -> None: ...
|
|
489
|
+
def __init__(self, span: Span, facts: list[_Fact]) -> None: ...
|
|
438
490
|
|
|
439
491
|
@final
|
|
440
492
|
class PrintFunction:
|
|
493
|
+
span: Span
|
|
441
494
|
name: str
|
|
442
495
|
length: int
|
|
443
|
-
def __init__(self, name: str, length: int) -> None: ...
|
|
496
|
+
def __init__(self, span: Span, name: str, length: int) -> None: ...
|
|
444
497
|
|
|
445
498
|
@final
|
|
446
499
|
class PrintSize:
|
|
500
|
+
span: Span
|
|
447
501
|
name: str | None
|
|
448
|
-
def __init__(self, name: str | None) -> None: ...
|
|
502
|
+
def __init__(self, span: Span, name: str | None) -> None: ...
|
|
449
503
|
|
|
450
504
|
@final
|
|
451
505
|
class Output:
|
|
506
|
+
span: Span
|
|
452
507
|
file: str
|
|
453
508
|
exprs: list[_Expr]
|
|
454
|
-
def __init__(self, file: str, exprs: list[_Expr]) -> None: ...
|
|
509
|
+
def __init__(self, span: Span, file: str, exprs: list[_Expr]) -> None: ...
|
|
455
510
|
|
|
456
511
|
@final
|
|
457
512
|
class Input:
|
|
513
|
+
span: Span
|
|
458
514
|
name: str
|
|
459
515
|
file: str
|
|
460
|
-
def __init__(self, name: str, file: str) -> None: ...
|
|
516
|
+
def __init__(self, span: Span, name: str, file: str) -> None: ...
|
|
461
517
|
|
|
462
518
|
@final
|
|
463
519
|
class Push:
|
|
@@ -466,29 +522,29 @@ class Push:
|
|
|
466
522
|
|
|
467
523
|
@final
|
|
468
524
|
class Pop:
|
|
525
|
+
span: Span
|
|
469
526
|
length: int
|
|
470
|
-
def __init__(self, length: int) -> None: ...
|
|
527
|
+
def __init__(self, span: Span, length: int) -> None: ...
|
|
471
528
|
|
|
472
529
|
@final
|
|
473
530
|
class Fail:
|
|
531
|
+
span: Span
|
|
474
532
|
command: _Command
|
|
475
|
-
def __init__(self, command: _Command) -> None: ...
|
|
533
|
+
def __init__(self, span: Span, command: _Command) -> None: ...
|
|
476
534
|
|
|
477
535
|
@final
|
|
478
536
|
class Include:
|
|
537
|
+
span: Span
|
|
479
538
|
path: str
|
|
480
|
-
def __init__(self, path: str) -> None: ...
|
|
481
|
-
|
|
482
|
-
@final
|
|
483
|
-
class CheckProof:
|
|
484
|
-
def __init__(self) -> None: ...
|
|
539
|
+
def __init__(self, span: Span, path: str) -> None: ...
|
|
485
540
|
|
|
486
541
|
@final
|
|
487
542
|
class Relation:
|
|
543
|
+
span: Span
|
|
488
544
|
constructor: str
|
|
489
545
|
inputs: list[str]
|
|
490
546
|
|
|
491
|
-
def __init__(self, constructor: str, inputs: list[str]) -> None: ...
|
|
547
|
+
def __init__(self, span: Span, constructor: str, inputs: list[str]) -> None: ...
|
|
492
548
|
|
|
493
549
|
@final
|
|
494
550
|
class PrintOverallStatistics:
|
|
@@ -503,7 +559,7 @@ class UnstableCombinedRuleset:
|
|
|
503
559
|
_Command: TypeAlias = (
|
|
504
560
|
SetOption
|
|
505
561
|
| Datatype
|
|
506
|
-
|
|
|
562
|
+
| Datatypes
|
|
507
563
|
| Sort
|
|
508
564
|
| Function
|
|
509
565
|
| AddRuleset
|
|
@@ -512,7 +568,6 @@ _Command: TypeAlias = (
|
|
|
512
568
|
| BiRewriteCommand
|
|
513
569
|
| ActionCommand
|
|
514
570
|
| RunSchedule
|
|
515
|
-
| Calc
|
|
516
571
|
| Simplify
|
|
517
572
|
| QueryExtract
|
|
518
573
|
| Check
|
|
@@ -524,7 +579,6 @@ _Command: TypeAlias = (
|
|
|
524
579
|
| Pop
|
|
525
580
|
| Fail
|
|
526
581
|
| Include
|
|
527
|
-
| CheckProof
|
|
528
582
|
| Relation
|
|
529
583
|
| PrintOverallStatistics
|
|
530
584
|
| UnstableCombinedRuleset
|
egglog/builtins.py
CHANGED
|
@@ -6,17 +6,21 @@ Builtin sorts and function to egg.
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
8
|
from functools import partial
|
|
9
|
-
from
|
|
9
|
+
from types import FunctionType
|
|
10
|
+
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload
|
|
10
11
|
|
|
11
12
|
from typing_extensions import TypeVarTuple, Unpack
|
|
12
13
|
|
|
13
|
-
from .conversion import converter
|
|
14
|
-
from .egraph import Expr, Unit, function, method
|
|
15
|
-
from .
|
|
14
|
+
from .conversion import converter, get_type_args
|
|
15
|
+
from .egraph import Expr, Unit, function, get_current_ruleset, method
|
|
16
|
+
from .functionalize import functionalize
|
|
17
|
+
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
|
|
18
|
+
from .thunk import Thunk
|
|
16
19
|
|
|
17
20
|
if TYPE_CHECKING:
|
|
18
21
|
from collections.abc import Callable
|
|
19
22
|
|
|
23
|
+
|
|
20
24
|
__all__ = [
|
|
21
25
|
"i64",
|
|
22
26
|
"i64Like",
|
|
@@ -80,7 +84,7 @@ class Bool(Expr, egg_sort="bool", builtin=True):
|
|
|
80
84
|
converter(bool, Bool, Bool)
|
|
81
85
|
|
|
82
86
|
# The types which can be convertered into an i64
|
|
83
|
-
i64Like = Union["i64", int] # noqa: N816
|
|
87
|
+
i64Like: TypeAlias = Union["i64", int] # noqa: N816, PYI042
|
|
84
88
|
|
|
85
89
|
|
|
86
90
|
class i64(Expr, builtin=True): # noqa: N801
|
|
@@ -182,7 +186,7 @@ converter(int, i64, i64)
|
|
|
182
186
|
def count_matches(s: StringLike, pattern: StringLike) -> i64: ...
|
|
183
187
|
|
|
184
188
|
|
|
185
|
-
f64Like = Union["f64", float] # noqa: N816
|
|
189
|
+
f64Like: TypeAlias = Union["f64", float] # noqa: N816, PYI042
|
|
186
190
|
|
|
187
191
|
|
|
188
192
|
class f64(Expr, builtin=True): # noqa: N801
|
|
@@ -404,6 +408,12 @@ class Vec(Expr, Generic[T], builtin=True):
|
|
|
404
408
|
@method(egg_fn="rebuild")
|
|
405
409
|
def rebuild(self) -> Vec[T]: ...
|
|
406
410
|
|
|
411
|
+
@method(egg_fn="vec-remove")
|
|
412
|
+
def remove(self, index: i64Like) -> Vec[T]: ...
|
|
413
|
+
|
|
414
|
+
@method(egg_fn="vec-set")
|
|
415
|
+
def set(self, index: i64Like, value: T) -> Vec[T]: ...
|
|
416
|
+
|
|
407
417
|
|
|
408
418
|
class PyObject(Expr, builtin=True):
|
|
409
419
|
def __init__(self, value: object) -> None: ...
|
|
@@ -501,3 +511,36 @@ class UnstableFn(Expr, Generic[T, Unpack[TS]], builtin=True):
|
|
|
501
511
|
|
|
502
512
|
converter(RuntimeFunction, UnstableFn, UnstableFn)
|
|
503
513
|
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def _convert_function(a: FunctionType) -> UnstableFn:
|
|
517
|
+
"""
|
|
518
|
+
Converts a function type to an unstable function
|
|
519
|
+
|
|
520
|
+
Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals
|
|
521
|
+
which are runtime expressions with `var`s in them and add them as args to the function
|
|
522
|
+
"""
|
|
523
|
+
# Update annotations of a to be the type we are trying to convert to
|
|
524
|
+
return_tp, *arg_tps = get_type_args()
|
|
525
|
+
a.__annotations__ = {
|
|
526
|
+
"return": return_tp,
|
|
527
|
+
# The first varnames should always be the arg names
|
|
528
|
+
**dict(zip(a.__code__.co_varnames, arg_tps, strict=False)),
|
|
529
|
+
}
|
|
530
|
+
# Modify name to make it unique
|
|
531
|
+
# a.__name__ = f"{a.__name__} {hash(a.__code__)}"
|
|
532
|
+
transformed_fn = functionalize(a, value_to_annotation)
|
|
533
|
+
assert isinstance(transformed_fn, partial)
|
|
534
|
+
return UnstableFn(
|
|
535
|
+
function(ruleset=get_current_ruleset(), use_body_as_name=True)(transformed_fn.func), *transformed_fn.args
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def value_to_annotation(a: object) -> type | None:
|
|
540
|
+
# only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
|
|
541
|
+
if not isinstance(a, RuntimeExpr):
|
|
542
|
+
return None
|
|
543
|
+
return cast(type, RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
converter(FunctionType, UnstableFn, _convert_function)
|
egglog/conversion.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from contextvars import ContextVar
|
|
3
5
|
from dataclasses import dataclass
|
|
4
6
|
from typing import TYPE_CHECKING, NewType, TypeVar, cast
|
|
5
7
|
|
|
@@ -9,9 +11,8 @@ from .runtime import *
|
|
|
9
11
|
from .thunk import *
|
|
10
12
|
|
|
11
13
|
if TYPE_CHECKING:
|
|
12
|
-
from collections.abc import Callable
|
|
14
|
+
from collections.abc import Callable, Generator
|
|
13
15
|
|
|
14
|
-
from .declarations import HasDeclerations
|
|
15
16
|
from .egraph import Expr
|
|
16
17
|
|
|
17
18
|
__all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"]
|
|
@@ -84,7 +85,7 @@ def convert(source: object, target: type[V]) -> V:
|
|
|
84
85
|
Convert a source object to a target type.
|
|
85
86
|
"""
|
|
86
87
|
assert isinstance(target, RuntimeClass)
|
|
87
|
-
return cast(V, resolve_literal(target.__egg_tp__, source))
|
|
88
|
+
return cast(V, resolve_literal(target.__egg_tp__, source, target.__egg_decls_thunk__))
|
|
88
89
|
|
|
89
90
|
|
|
90
91
|
def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
|
|
@@ -92,7 +93,7 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
|
|
|
92
93
|
Convert a source object to the same type as the target.
|
|
93
94
|
"""
|
|
94
95
|
tp = target.__egg_typed_expr__.tp
|
|
95
|
-
return resolve_literal(tp.to_var(), source)
|
|
96
|
+
return resolve_literal(tp.to_var(), source, Thunk.value(target.__egg_decls__))
|
|
96
97
|
|
|
97
98
|
|
|
98
99
|
def process_tp(tp: type | RuntimeClass) -> TypeName | type:
|
|
@@ -140,7 +141,28 @@ def identity(x: object) -> object:
|
|
|
140
141
|
return x
|
|
141
142
|
|
|
142
143
|
|
|
143
|
-
|
|
144
|
+
TYPE_ARGS = ContextVar[tuple[RuntimeClass, ...]]("TYPE_ARGS")
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def get_type_args() -> tuple[RuntimeClass, ...]:
|
|
148
|
+
"""
|
|
149
|
+
Get the type args for the type being converted.
|
|
150
|
+
"""
|
|
151
|
+
return TYPE_ARGS.get()
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@contextmanager
|
|
155
|
+
def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declarations]) -> Generator[None, None, None]:
|
|
156
|
+
token = TYPE_ARGS.set(tuple(RuntimeClass(decls, a.to_var()) for a in args))
|
|
157
|
+
try:
|
|
158
|
+
yield
|
|
159
|
+
finally:
|
|
160
|
+
TYPE_ARGS.reset(token)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def resolve_literal(
|
|
164
|
+
tp: TypeOrVarRef, arg: object, decls: Callable[[], Declarations] = CONVERSIONS_DECLS
|
|
165
|
+
) -> RuntimeExpr:
|
|
144
166
|
arg_type = _get_tp(arg)
|
|
145
167
|
|
|
146
168
|
# If we have any type variables, dont bother trying to resolve the literal, just return the arg
|
|
@@ -148,7 +170,7 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
|
|
|
148
170
|
tp_just = tp.to_just()
|
|
149
171
|
except NotImplementedError:
|
|
150
172
|
# If this is a var, it has to be a runtime expession
|
|
151
|
-
assert isinstance(arg, RuntimeExpr)
|
|
173
|
+
assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
|
|
152
174
|
return arg
|
|
153
175
|
tp_name = TypeName(tp_just.name)
|
|
154
176
|
if arg_type == tp_name:
|
|
@@ -158,13 +180,14 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
|
|
|
158
180
|
# Try all parent types as well, if we are converting from a Python type
|
|
159
181
|
for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
|
|
160
182
|
try:
|
|
161
|
-
fn = CONVERSIONS[(
|
|
183
|
+
fn = CONVERSIONS[(arg_type_instance, tp_name)][1]
|
|
162
184
|
except KeyError:
|
|
163
185
|
continue
|
|
164
186
|
break
|
|
165
187
|
else:
|
|
166
188
|
raise ConvertError(f"Cannot convert {arg_type} to {tp_name}")
|
|
167
|
-
|
|
189
|
+
with with_type_args(tp_just.args, decls):
|
|
190
|
+
return fn(arg)
|
|
168
191
|
|
|
169
192
|
|
|
170
193
|
def _get_tp(x: object) -> TypeName | type:
|
|
@@ -172,6 +195,6 @@ def _get_tp(x: object) -> TypeName | type:
|
|
|
172
195
|
return TypeName(x.__egg_typed_expr__.tp.name)
|
|
173
196
|
tp = type(x)
|
|
174
197
|
# If this value has a custom metaclass, let's use that as our index instead of the type
|
|
175
|
-
if type(tp)
|
|
198
|
+
if type(tp) is not type:
|
|
176
199
|
return type(tp)
|
|
177
200
|
return tp
|