egglog 7.2.0__cp311-none-win_amd64.whl → 8.0.1__cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Binary file
egglog/bindings.pyi CHANGED
@@ -5,8 +5,6 @@ from typing import TypeAlias
5
5
 
6
6
  from typing_extensions import final
7
7
 
8
- HIGH_COST: int
9
-
10
8
  @final
11
9
  class SerializedEGraph:
12
10
  def inline_leaves(self) -> None: ...
@@ -14,12 +12,14 @@ class SerializedEGraph:
14
12
  def to_dot(self) -> str: ...
15
13
  def to_json(self) -> str: ...
16
14
  def map_ops(self, map: dict[str, str]) -> None: ...
15
+ def split_classes(self, egraph: EGraph, ops: set[str]) -> None: ...
17
16
 
18
17
  @final
19
18
  class PyObjectSort:
20
19
  def __init__(self) -> None: ...
21
20
  def store(self, __o: object, /) -> _Expr: ...
22
21
 
22
+ def parse_program(__input: str, /, filename: str | None = None) -> list[_Command]: ...
23
23
  @final
24
24
  class EGraph:
25
25
  def __init__(
@@ -28,11 +28,9 @@ class EGraph:
28
28
  *,
29
29
  fact_directory: str | Path | None = None,
30
30
  seminaive: bool = True,
31
- terms_encoding: bool = False,
32
31
  record: bool = False,
33
32
  ) -> None: ...
34
33
  def commands(self) -> str | None: ...
35
- def parse_program(self, __input: str, /) -> list[_Command]: ...
36
34
  def run_program(self, *commands: _Command) -> list[str]: ...
37
35
  def extract_report(self) -> _ExtractReport | None: ...
38
36
  def run_report(self) -> RunReport | None: ...
@@ -43,7 +41,6 @@ class EGraph:
43
41
  max_functions: int | None = None,
44
42
  max_calls_per_function: int | None = None,
45
43
  include_temporary_functions: bool = False,
46
- split_primitive_outputs: bool = False,
47
44
  ) -> SerializedEGraph: ...
48
45
  def eval_py_object(self, __expr: _Expr) -> object: ...
49
46
  def eval_i64(self, __expr: _Expr) -> int: ...
@@ -56,6 +53,25 @@ class EGraph:
56
53
  class EggSmolError(Exception):
57
54
  context: str
58
55
 
56
+ ##
57
+ # Spans
58
+ ##
59
+
60
+ @final
61
+ class SrcFile:
62
+ def __init__(self, name: str, contents: str | None = None) -> None: ...
63
+ name: str
64
+ contents: str | None
65
+
66
+ @final
67
+ class Span:
68
+ def __init__(self, file: SrcFile, start: int, end: int) -> None: ...
69
+ file: SrcFile
70
+ start: int
71
+ end: int
72
+
73
+ DUMMY_SPAN: Span = ...
74
+
59
75
  ##
60
76
  # Literals
61
77
  ##
@@ -92,17 +108,20 @@ _Literal: TypeAlias = Int | F64 | String | Bool | Unit
92
108
 
93
109
  @final
94
110
  class Lit:
95
- def __init__(self, value: _Literal) -> None: ...
111
+ def __init__(self, span: Span, value: _Literal) -> None: ...
112
+ span: Span
96
113
  value: _Literal
97
114
 
98
115
  @final
99
116
  class Var:
100
- def __init__(self, name: str) -> None: ...
117
+ def __init__(self, span: Span, name: str) -> None: ...
118
+ span: Span
101
119
  name: str
102
120
 
103
121
  @final
104
122
  class Call:
105
- def __init__(self, name: str, args: list[_Expr]) -> None: ...
123
+ def __init__(self, span: Span, name: str, args: list[_Expr]) -> None: ...
124
+ span: Span
106
125
  name: str
107
126
  args: list[_Expr]
108
127
 
@@ -142,7 +161,8 @@ class TermDag:
142
161
 
143
162
  @final
144
163
  class Eq:
145
- def __init__(self, exprs: list[_Expr]) -> None: ...
164
+ def __init__(self, span: Span, exprs: list[_Expr]) -> None: ...
165
+ span: Span
146
166
  exprs: list[_Expr]
147
167
 
148
168
  @final
@@ -172,43 +192,50 @@ _Change: TypeAlias = Delete | Subsume
172
192
 
173
193
  @final
174
194
  class Let:
175
- def __init__(self, lhs: str, rhs: _Expr) -> None: ...
195
+ def __init__(self, span: Span, lhs: str, rhs: _Expr) -> None: ...
196
+ span: Span
176
197
  lhs: str
177
198
  rhs: _Expr
178
199
 
179
200
  @final
180
201
  class Set:
181
- def __init__(self, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
202
+ def __init__(self, span: Span, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
203
+ span: Span
182
204
  lhs: str
183
205
  args: list[_Expr]
184
206
  rhs: _Expr
185
207
 
186
208
  @final
187
209
  class Change:
210
+ span: Span
188
211
  change: _Change
189
212
  sym: str
190
213
  args: list[_Expr]
191
- def __init__(self, change: _Change, sym: str, args: list[_Expr]) -> None: ...
214
+ def __init__(self, span: Span, change: _Change, sym: str, args: list[_Expr]) -> None: ...
192
215
 
193
216
  @final
194
217
  class Union:
195
- def __init__(self, lhs: _Expr, rhs: _Expr) -> None: ...
218
+ def __init__(self, span: Span, lhs: _Expr, rhs: _Expr) -> None: ...
219
+ span: Span
196
220
  lhs: _Expr
197
221
  rhs: _Expr
198
222
 
199
223
  @final
200
224
  class Panic:
201
- def __init__(self, msg: str) -> None: ...
225
+ def __init__(self, span: Span, msg: str) -> None: ...
226
+ span: Span
202
227
  msg: str
203
228
 
204
229
  @final
205
230
  class Expr_: # noqa: N801
206
- def __init__(self, expr: _Expr) -> None: ...
231
+ def __init__(self, span: Span, expr: _Expr) -> None: ...
232
+ span: Span
207
233
  expr: _Expr
208
234
 
209
235
  @final
210
236
  class Extract:
211
- def __init__(self, expr: _Expr, variants: _Expr) -> None: ...
237
+ def __init__(self, span: Span, expr: _Expr, variants: _Expr) -> None: ...
238
+ span: Span
212
239
  expr: _Expr
213
240
  variants: _Expr
214
241
 
@@ -220,6 +247,7 @@ _Action: TypeAlias = Let | Set | Change | Union | Panic | Expr_ | Extract
220
247
 
221
248
  @final
222
249
  class FunctionDecl:
250
+ span: Span
223
251
  name: str
224
252
  schema: Schema
225
253
  default: _Expr | None
@@ -231,6 +259,7 @@ class FunctionDecl:
231
259
 
232
260
  def __init__(
233
261
  self,
262
+ span: Span,
234
263
  name: str,
235
264
  schema: Schema,
236
265
  default: _Expr | None = None,
@@ -243,7 +272,8 @@ class FunctionDecl:
243
272
 
244
273
  @final
245
274
  class Variant:
246
- def __init__(self, name: str, types: list[str], cost: int | None = None) -> None: ...
275
+ def __init__(self, span: Span, name: str, types: list[str], cost: int | None = None) -> None: ...
276
+ span: Span
247
277
  name: str
248
278
  types: list[str]
249
279
  cost: int | None
@@ -256,17 +286,19 @@ class Schema:
256
286
 
257
287
  @final
258
288
  class Rule:
289
+ span: Span
259
290
  head: list[_Action]
260
291
  body: list[_Fact]
261
- def __init__(self, head: list[_Action], body: list[_Fact]) -> None: ...
292
+ def __init__(self, span: Span, head: list[_Action], body: list[_Fact]) -> None: ...
262
293
 
263
294
  @final
264
295
  class Rewrite:
296
+ span: Span
265
297
  lhs: _Expr
266
298
  rhs: _Expr
267
299
  conditions: list[_Fact]
268
300
 
269
- def __init__(self, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... # noqa: B006
301
+ def __init__(self, span: Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... # noqa: B006
270
302
 
271
303
  @final
272
304
  class RunConfig:
@@ -322,27 +354,48 @@ _ExtractReport: TypeAlias = Variants | Best
322
354
 
323
355
  @final
324
356
  class Saturate:
357
+ span: Span
325
358
  schedule: _Schedule
326
- def __init__(self, schedule: _Schedule) -> None: ...
359
+ def __init__(self, span: Span, schedule: _Schedule) -> None: ...
327
360
 
328
361
  @final
329
362
  class Repeat:
363
+ span: Span
330
364
  length: int
331
365
  schedule: _Schedule
332
- def __init__(self, length: int, schedule: _Schedule) -> None: ...
366
+ def __init__(self, span: Span, length: int, schedule: _Schedule) -> None: ...
333
367
 
334
368
  @final
335
369
  class Run:
370
+ span: Span
336
371
  config: RunConfig
337
- def __init__(self, config: RunConfig) -> None: ...
372
+ def __init__(self, span: Span, config: RunConfig) -> None: ...
338
373
 
339
374
  @final
340
375
  class Sequence:
376
+ span: Span
341
377
  schedules: list[_Schedule]
342
- def __init__(self, schedules: list[_Schedule]) -> None: ...
378
+ def __init__(self, span: Span, schedules: list[_Schedule]) -> None: ...
343
379
 
344
380
  _Schedule: TypeAlias = Saturate | Repeat | Run | Sequence
345
381
 
382
+ ##
383
+ # Subdatatypes
384
+ ##
385
+
386
+ @final
387
+ class SubVariants:
388
+ def __init__(self, variants: list[Variant]) -> None: ...
389
+ variants: list[Variant]
390
+
391
+ @final
392
+ class NewSort:
393
+ def __init__(self, name: str, args: list[_Expr]) -> None: ...
394
+ name: str
395
+ args: list[_Expr]
396
+
397
+ _Subdatatypes: TypeAlias = SubVariants | NewSort
398
+
346
399
  ##
347
400
  # Commands
348
401
  ##
@@ -355,21 +408,23 @@ class SetOption:
355
408
 
356
409
  @final
357
410
  class Datatype:
411
+ span: Span
358
412
  name: str
359
413
  variants: list[Variant]
360
- def __init__(self, name: str, variants: list[Variant]) -> None: ...
414
+ def __init__(self, span: Span, name: str, variants: list[Variant]) -> None: ...
361
415
 
362
416
  @final
363
- class Declare:
364
- name: str
365
- sort: str
366
- def __init__(self, name: str, sort: str) -> None: ...
417
+ class Datatypes:
418
+ span: Span
419
+ datatypes: list[tuple[Span, str, _Subdatatypes]]
420
+ def __init__(self, span: Span, datatypes: list[tuple[Span, str, _Subdatatypes]]) -> None: ...
367
421
 
368
422
  @final
369
423
  class Sort:
424
+ span: Span
370
425
  name: str
371
426
  presort_and_args: tuple[str, list[_Expr]] | None
372
- def __init__(self, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
427
+ def __init__(self, span: Span, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
373
428
 
374
429
  @final
375
430
  class Function:
@@ -415,49 +470,50 @@ class RunSchedule:
415
470
 
416
471
  @final
417
472
  class Simplify:
473
+ span: Span
418
474
  expr: _Expr
419
475
  schedule: _Schedule
420
- def __init__(self, expr: _Expr, schedule: _Schedule) -> None: ...
421
-
422
- @final
423
- class Calc:
424
- identifiers: list[IdentSort]
425
- exprs: list[_Expr]
426
- def __init__(self, identifiers: list[IdentSort], exprs: list[_Expr]) -> None: ...
476
+ def __init__(self, span: Span, expr: _Expr, schedule: _Schedule) -> None: ...
427
477
 
428
478
  @final
429
479
  class QueryExtract:
480
+ span: Span
430
481
  variants: int
431
482
  expr: _Expr
432
- def __init__(self, variants: int, expr: _Expr) -> None: ...
483
+ def __init__(self, span: Span, variants: int, expr: _Expr) -> None: ...
433
484
 
434
485
  @final
435
486
  class Check:
487
+ span: Span
436
488
  facts: list[_Fact]
437
- def __init__(self, facts: list[_Fact]) -> None: ...
489
+ def __init__(self, span: Span, facts: list[_Fact]) -> None: ...
438
490
 
439
491
  @final
440
492
  class PrintFunction:
493
+ span: Span
441
494
  name: str
442
495
  length: int
443
- def __init__(self, name: str, length: int) -> None: ...
496
+ def __init__(self, span: Span, name: str, length: int) -> None: ...
444
497
 
445
498
  @final
446
499
  class PrintSize:
500
+ span: Span
447
501
  name: str | None
448
- def __init__(self, name: str | None) -> None: ...
502
+ def __init__(self, span: Span, name: str | None) -> None: ...
449
503
 
450
504
  @final
451
505
  class Output:
506
+ span: Span
452
507
  file: str
453
508
  exprs: list[_Expr]
454
- def __init__(self, file: str, exprs: list[_Expr]) -> None: ...
509
+ def __init__(self, span: Span, file: str, exprs: list[_Expr]) -> None: ...
455
510
 
456
511
  @final
457
512
  class Input:
513
+ span: Span
458
514
  name: str
459
515
  file: str
460
- def __init__(self, name: str, file: str) -> None: ...
516
+ def __init__(self, span: Span, name: str, file: str) -> None: ...
461
517
 
462
518
  @final
463
519
  class Push:
@@ -466,29 +522,29 @@ class Push:
466
522
 
467
523
  @final
468
524
  class Pop:
525
+ span: Span
469
526
  length: int
470
- def __init__(self, length: int) -> None: ...
527
+ def __init__(self, span: Span, length: int) -> None: ...
471
528
 
472
529
  @final
473
530
  class Fail:
531
+ span: Span
474
532
  command: _Command
475
- def __init__(self, command: _Command) -> None: ...
533
+ def __init__(self, span: Span, command: _Command) -> None: ...
476
534
 
477
535
  @final
478
536
  class Include:
537
+ span: Span
479
538
  path: str
480
- def __init__(self, path: str) -> None: ...
481
-
482
- @final
483
- class CheckProof:
484
- def __init__(self) -> None: ...
539
+ def __init__(self, span: Span, path: str) -> None: ...
485
540
 
486
541
  @final
487
542
  class Relation:
543
+ span: Span
488
544
  constructor: str
489
545
  inputs: list[str]
490
546
 
491
- def __init__(self, constructor: str, inputs: list[str]) -> None: ...
547
+ def __init__(self, span: Span, constructor: str, inputs: list[str]) -> None: ...
492
548
 
493
549
  @final
494
550
  class PrintOverallStatistics:
@@ -503,7 +559,7 @@ class UnstableCombinedRuleset:
503
559
  _Command: TypeAlias = (
504
560
  SetOption
505
561
  | Datatype
506
- | Declare
562
+ | Datatypes
507
563
  | Sort
508
564
  | Function
509
565
  | AddRuleset
@@ -512,7 +568,6 @@ _Command: TypeAlias = (
512
568
  | BiRewriteCommand
513
569
  | ActionCommand
514
570
  | RunSchedule
515
- | Calc
516
571
  | Simplify
517
572
  | QueryExtract
518
573
  | Check
@@ -524,7 +579,6 @@ _Command: TypeAlias = (
524
579
  | Pop
525
580
  | Fail
526
581
  | Include
527
- | CheckProof
528
582
  | Relation
529
583
  | PrintOverallStatistics
530
584
  | UnstableCombinedRuleset
egglog/builtins.py CHANGED
@@ -6,17 +6,21 @@ Builtin sorts and function to egg.
6
6
  from __future__ import annotations
7
7
 
8
8
  from functools import partial
9
- from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, overload
9
+ from types import FunctionType
10
+ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload
10
11
 
11
12
  from typing_extensions import TypeVarTuple, Unpack
12
13
 
13
- from .conversion import converter
14
- from .egraph import Expr, Unit, function, method
15
- from .runtime import RuntimeFunction
14
+ from .conversion import converter, get_type_args
15
+ from .egraph import Expr, Unit, function, get_current_ruleset, method
16
+ from .functionalize import functionalize
17
+ from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
18
+ from .thunk import Thunk
16
19
 
17
20
  if TYPE_CHECKING:
18
21
  from collections.abc import Callable
19
22
 
23
+
20
24
  __all__ = [
21
25
  "i64",
22
26
  "i64Like",
@@ -80,7 +84,7 @@ class Bool(Expr, egg_sort="bool", builtin=True):
80
84
  converter(bool, Bool, Bool)
81
85
 
82
86
  # The types which can be convertered into an i64
83
- i64Like = Union["i64", int] # noqa: N816
87
+ i64Like: TypeAlias = Union["i64", int] # noqa: N816, PYI042
84
88
 
85
89
 
86
90
  class i64(Expr, builtin=True): # noqa: N801
@@ -182,7 +186,7 @@ converter(int, i64, i64)
182
186
  def count_matches(s: StringLike, pattern: StringLike) -> i64: ...
183
187
 
184
188
 
185
- f64Like = Union["f64", float] # noqa: N816
189
+ f64Like: TypeAlias = Union["f64", float] # noqa: N816, PYI042
186
190
 
187
191
 
188
192
  class f64(Expr, builtin=True): # noqa: N801
@@ -404,6 +408,12 @@ class Vec(Expr, Generic[T], builtin=True):
404
408
  @method(egg_fn="rebuild")
405
409
  def rebuild(self) -> Vec[T]: ...
406
410
 
411
+ @method(egg_fn="vec-remove")
412
+ def remove(self, index: i64Like) -> Vec[T]: ...
413
+
414
+ @method(egg_fn="vec-set")
415
+ def set(self, index: i64Like, value: T) -> Vec[T]: ...
416
+
407
417
 
408
418
  class PyObject(Expr, builtin=True):
409
419
  def __init__(self, value: object) -> None: ...
@@ -501,3 +511,36 @@ class UnstableFn(Expr, Generic[T, Unpack[TS]], builtin=True):
501
511
 
502
512
  converter(RuntimeFunction, UnstableFn, UnstableFn)
503
513
  converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
514
+
515
+
516
+ def _convert_function(a: FunctionType) -> UnstableFn:
517
+ """
518
+ Converts a function type to an unstable function
519
+
520
+ Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals
521
+ which are runtime expressions with `var`s in them and add them as args to the function
522
+ """
523
+ # Update annotations of a to be the type we are trying to convert to
524
+ return_tp, *arg_tps = get_type_args()
525
+ a.__annotations__ = {
526
+ "return": return_tp,
527
+ # The first varnames should always be the arg names
528
+ **dict(zip(a.__code__.co_varnames, arg_tps, strict=False)),
529
+ }
530
+ # Modify name to make it unique
531
+ # a.__name__ = f"{a.__name__} {hash(a.__code__)}"
532
+ transformed_fn = functionalize(a, value_to_annotation)
533
+ assert isinstance(transformed_fn, partial)
534
+ return UnstableFn(
535
+ function(ruleset=get_current_ruleset(), use_body_as_name=True)(transformed_fn.func), *transformed_fn.args
536
+ )
537
+
538
+
539
+ def value_to_annotation(a: object) -> type | None:
540
+ # only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
541
+ if not isinstance(a, RuntimeExpr):
542
+ return None
543
+ return cast(type, RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
544
+
545
+
546
+ converter(FunctionType, UnstableFn, _convert_function)
egglog/conversion.py CHANGED
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from contextlib import contextmanager
4
+ from contextvars import ContextVar
3
5
  from dataclasses import dataclass
4
6
  from typing import TYPE_CHECKING, NewType, TypeVar, cast
5
7
 
@@ -9,9 +11,8 @@ from .runtime import *
9
11
  from .thunk import *
10
12
 
11
13
  if TYPE_CHECKING:
12
- from collections.abc import Callable
14
+ from collections.abc import Callable, Generator
13
15
 
14
- from .declarations import HasDeclerations
15
16
  from .egraph import Expr
16
17
 
17
18
  __all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"]
@@ -84,7 +85,7 @@ def convert(source: object, target: type[V]) -> V:
84
85
  Convert a source object to a target type.
85
86
  """
86
87
  assert isinstance(target, RuntimeClass)
87
- return cast(V, resolve_literal(target.__egg_tp__, source))
88
+ return cast(V, resolve_literal(target.__egg_tp__, source, target.__egg_decls_thunk__))
88
89
 
89
90
 
90
91
  def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
@@ -92,7 +93,7 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
92
93
  Convert a source object to the same type as the target.
93
94
  """
94
95
  tp = target.__egg_typed_expr__.tp
95
- return resolve_literal(tp.to_var(), source)
96
+ return resolve_literal(tp.to_var(), source, Thunk.value(target.__egg_decls__))
96
97
 
97
98
 
98
99
  def process_tp(tp: type | RuntimeClass) -> TypeName | type:
@@ -140,7 +141,28 @@ def identity(x: object) -> object:
140
141
  return x
141
142
 
142
143
 
143
- def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
144
+ TYPE_ARGS = ContextVar[tuple[RuntimeClass, ...]]("TYPE_ARGS")
145
+
146
+
147
+ def get_type_args() -> tuple[RuntimeClass, ...]:
148
+ """
149
+ Get the type args for the type being converted.
150
+ """
151
+ return TYPE_ARGS.get()
152
+
153
+
154
+ @contextmanager
155
+ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declarations]) -> Generator[None, None, None]:
156
+ token = TYPE_ARGS.set(tuple(RuntimeClass(decls, a.to_var()) for a in args))
157
+ try:
158
+ yield
159
+ finally:
160
+ TYPE_ARGS.reset(token)
161
+
162
+
163
+ def resolve_literal(
164
+ tp: TypeOrVarRef, arg: object, decls: Callable[[], Declarations] = CONVERSIONS_DECLS
165
+ ) -> RuntimeExpr:
144
166
  arg_type = _get_tp(arg)
145
167
 
146
168
  # If we have any type variables, dont bother trying to resolve the literal, just return the arg
@@ -148,7 +170,7 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
148
170
  tp_just = tp.to_just()
149
171
  except NotImplementedError:
150
172
  # If this is a var, it has to be a runtime expession
151
- assert isinstance(arg, RuntimeExpr)
173
+ assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
152
174
  return arg
153
175
  tp_name = TypeName(tp_just.name)
154
176
  if arg_type == tp_name:
@@ -158,13 +180,14 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
158
180
  # Try all parent types as well, if we are converting from a Python type
159
181
  for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
160
182
  try:
161
- fn = CONVERSIONS[(cast(TypeName | type, arg_type_instance), tp_name)][1]
183
+ fn = CONVERSIONS[(arg_type_instance, tp_name)][1]
162
184
  except KeyError:
163
185
  continue
164
186
  break
165
187
  else:
166
188
  raise ConvertError(f"Cannot convert {arg_type} to {tp_name}")
167
- return fn(arg)
189
+ with with_type_args(tp_just.args, decls):
190
+ return fn(arg)
168
191
 
169
192
 
170
193
  def _get_tp(x: object) -> TypeName | type:
@@ -172,6 +195,6 @@ def _get_tp(x: object) -> TypeName | type:
172
195
  return TypeName(x.__egg_typed_expr__.tp.name)
173
196
  tp = type(x)
174
197
  # If this value has a custom metaclass, let's use that as our index instead of the type
175
- if type(tp) != type:
198
+ if type(tp) is not type:
176
199
  return type(tp)
177
200
  return tp