effectful 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +27 -46
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +321 -169
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +104 -84
- effectful/ops/syntax.py +1276 -167
- effectful/ops/types.py +141 -35
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/METADATA +65 -57
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -259
- effectful/internals/base_impl.py +0 -259
- effectful-0.0.1.dist-info/RECORD +0 -19
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,900 @@
|
|
1
|
+
"""Type unification and inference utilities for Python's generic type system.
|
2
|
+
|
3
|
+
This module implements a unification algorithm for type inference over a subset of
|
4
|
+
Python's generic types. Unification is a fundamental operation in type systems that
|
5
|
+
finds substitutions for type variables to make two types equivalent.
|
6
|
+
|
7
|
+
The module provides four main operations:
|
8
|
+
|
9
|
+
1. **unify(typ, subtyp, subs={})**: The core unification algorithm that attempts to
|
10
|
+
find a substitution mapping for type variables that makes a pattern type equal to
|
11
|
+
a concrete type. It handles TypeVars, generic types (List[T], Dict[K,V]), unions,
|
12
|
+
callables, and function signatures with inspect.Signature/BoundArguments.
|
13
|
+
|
14
|
+
2. **substitute(typ, subs)**: Applies a substitution mapping to a type expression,
|
15
|
+
replacing all TypeVars with their mapped concrete types. This is used to
|
16
|
+
instantiate generic types after unification.
|
17
|
+
|
18
|
+
3. **freetypevars(typ)**: Extracts all free (unbound) type variables from a type
|
19
|
+
expression. Useful for analyzing generic types and ensuring all TypeVars are
|
20
|
+
properly bound.
|
21
|
+
|
22
|
+
4. **nested_type(value)**: Infers the type of a runtime value, handling nested
|
23
|
+
collections by recursively determining element types. For example, [1, 2, 3]
|
24
|
+
becomes collections.abc.MutableSequence[int], and {"key": [1, 2]} becomes collections.abc.MutableMapping[str, collections.abc.MutableSequence[int]].
|
25
|
+
|
26
|
+
The unification algorithm uses a single-dispatch pattern to handle different type
|
27
|
+
combinations:
|
28
|
+
- TypeVar unification binds variables to concrete types
|
29
|
+
- Generic type unification matches origins and recursively unifies type arguments
|
30
|
+
- Structural unification handles sequences and mappings by element
|
31
|
+
- Union types attempt unification with any matching branch
|
32
|
+
- Function signatures unify parameter types with bound arguments
|
33
|
+
|
34
|
+
Example usage:
|
35
|
+
>>> from effectful.internals.unification import unify, substitute, freetypevars
|
36
|
+
>>> import typing
|
37
|
+
>>> T = typing.TypeVar('T')
|
38
|
+
>>> K = typing.TypeVar('K')
|
39
|
+
>>> V = typing.TypeVar('V')
|
40
|
+
|
41
|
+
>>> # Find substitution that makes list[T] equal to list[int]
|
42
|
+
>>> subs = unify(list[T], list[int])
|
43
|
+
>>> subs
|
44
|
+
{~T: <class 'int'>}
|
45
|
+
|
46
|
+
>>> # Apply substitution to instantiate a generic type
|
47
|
+
>>> substitute(dict[K, list[V]], {K: str, V: int})
|
48
|
+
dict[str, list[int]]
|
49
|
+
|
50
|
+
>>> # Find all type variables in a type expression
|
51
|
+
>>> freetypevars(dict[str, list[V]])
|
52
|
+
{~V}
|
53
|
+
|
54
|
+
This module is primarily used internally by effectful for type inference in its
|
55
|
+
effect system, allowing it to track and propagate type information through
|
56
|
+
effect handlers and operations.
|
57
|
+
"""
|
58
|
+
|
59
|
+
import abc
|
60
|
+
import builtins
|
61
|
+
import collections
|
62
|
+
import collections.abc
|
63
|
+
import functools
|
64
|
+
import inspect
|
65
|
+
import numbers
|
66
|
+
import operator
|
67
|
+
import types
|
68
|
+
import typing
|
69
|
+
|
70
|
+
try:
|
71
|
+
from typing import _collect_type_parameters as _freetypevars # type: ignore
|
72
|
+
except ImportError:
|
73
|
+
from typing import _collect_parameters as _freetypevars # type: ignore
|
74
|
+
|
75
|
+
import effectful.ops.types
|
76
|
+
|
77
|
+
# Aliases describing the kinds of objects this module treats as type expressions.
# Static checkers see the narrow, public definitions; at runtime the aliases are
# widened with typing's private alias classes so that isinstance() checks against
# them also accept those objects (presumably typing.List[int], typing.Union[...]
# and friends -- confirm against the supported Python versions).
if typing.TYPE_CHECKING:
    TypeConstant = type | abc.ABCMeta | types.EllipsisType | None
    GenericAlias = types.GenericAlias
    UnionType = types.UnionType
else:
    # type(None) and type(typing.Any) are used here because these unions are
    # consumed by isinstance()/singledispatch, which need classes.
    TypeConstant = (
        type | abc.ABCMeta | types.EllipsisType | type(None) | type(typing.Any)
    )
    GenericAlias = types.GenericAlias | typing._GenericAlias
    UnionType = types.UnionType | typing._UnionGenericAlias

# Anything that can be bound by a substitution.
TypeVariable = typing.TypeVar | typing.TypeVarTuple | typing.ParamSpec
# A type constructor applied to arguments, or a union of types.
TypeApplication = GenericAlias | UnionType
# A single type expression.
TypeExpression = TypeVariable | TypeConstant | TypeApplication
# A single type expression or a sequence of them (e.g. an argument list).
TypeExpressions = TypeExpression | collections.abc.Sequence[TypeExpression]

# Mapping from type variables to the type expression(s) they are bound to.
Substitutions = collections.abc.Mapping[TypeVariable, TypeExpressions]
|
94
|
+
|
95
|
+
|
96
|
+
@typing.overload
def unify(
    typ: inspect.Signature,
    subtyp: inspect.BoundArguments,
    subs: Substitutions = {},
) -> Substitutions: ...


@typing.overload
def unify(
    typ: TypeExpressions,
    subtyp: TypeExpressions,
    subs: Substitutions = {},
) -> Substitutions: ...


def unify(typ, subtyp, subs: Substitutions = {}) -> Substitutions:
    """
    Unify a pattern type with a concrete type, returning a substitution map.

    This function attempts to find a substitution of type variables that makes
    the pattern type (typ) compatible with the concrete type (subtyp). It
    returns the extended substitution mapping, or raises TypeError if
    unification is not possible.

    Both arguments are canonicalized first; the remaining cases are tried in
    this order:
    - identical or equal types (no new bindings)
    - TypeVar unification (binding type variables to concrete types)
    - element-wise unification of sequences of type expressions
    - union types
    - generic aliases (matching origins and recursively unifying args)
    - plain classes, where subtyp must be a subclass of typ
    - ``typing.Any`` and ``...``, which unify with anything

    Args:
        typ: The pattern type that may contain TypeVars to be unified, or an
            inspect.Signature to be unified with bound arguments
        subtyp: The concrete type to unify with the pattern, or an
            inspect.BoundArguments when typ is a Signature
        subs: Existing substitution mappings to be extended (not modified)

    Returns:
        A substitution mapping that includes all previous substitutions
        plus any new TypeVar bindings discovered during unification.

    Raises:
        TypeError: If unification is not possible (incompatible types or
            conflicting TypeVar bindings)

    Examples:
        >>> import typing
        >>> T = typing.TypeVar('T')
        >>> K = typing.TypeVar('K')
        >>> V = typing.TypeVar('V')

        >>> # Simple TypeVar unification
        >>> unify(T, int, {})
        {~T: <class 'int'>}

        >>> # Generic type unification
        >>> unify(list[T], list[int], {})
        {~T: <class 'int'>}

        >>> # Exact type matching
        >>> unify(int, int, {})
        {}

        >>> # Failed unification - incompatible types
        >>> unify(list[T], dict[str, int], {})  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        TypeError: Cannot unify ...

        >>> # Failed unification - conflicting TypeVar binding
        >>> unify(T, str, {T: int})  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        TypeError: Cannot unify ...
    """
    # NOTE: the shared ``{}`` default is never mutated in this function;
    # it is only returned or passed on to the helpers.
    if isinstance(typ, inspect.Signature):
        return _unify_signature(typ, subtyp, subs)

    # Canonicalize each argument exactly once and restart on the canonical
    # forms if either of them changed.
    canonical_typ = canonicalize(typ)
    canonical_subtyp = canonicalize(subtyp)
    if typ != canonical_typ or subtyp != canonical_subtyp:
        return unify(canonical_typ, canonical_subtyp, subs)

    if typ is subtyp or typ == subtyp:
        return subs
    elif isinstance(typ, TypeVariable) or isinstance(subtyp, TypeVariable):
        return _unify_typevar(typ, subtyp, subs)
    elif isinstance(typ, collections.abc.Sequence) or isinstance(
        subtyp, collections.abc.Sequence
    ):
        return _unify_sequence(typ, subtyp, subs)
    elif isinstance(typ, UnionType) or isinstance(subtyp, UnionType):
        return _unify_union(typ, subtyp, subs)
    elif isinstance(typ, GenericAlias) or isinstance(subtyp, GenericAlias):
        return _unify_generic(typ, subtyp, subs)
    elif isinstance(typ, type) and isinstance(subtyp, type) and issubclass(subtyp, typ):
        return subs
    elif typ in (typing.Any, ...) or subtyp in (typing.Any, ...):
        return subs
    else:
        raise TypeError(f"Cannot unify type {typ} with {subtyp} given {subs}. ")
|
194
|
+
|
195
|
+
|
196
|
+
@typing.overload
def _unify_typevar(
    typ: TypeVariable, subtyp: TypeExpression, subs: Substitutions
) -> Substitutions: ...


@typing.overload
def _unify_typevar(
    typ: TypeExpression, subtyp: TypeVariable, subs: Substitutions
) -> Substitutions: ...


def _unify_typevar(typ, subtyp, subs: Substitutions) -> Substitutions:
    """
    Unify two type expressions when at least one of them is a type variable.

    Args:
        typ: The pattern side; may or may not be a type variable.
        subtyp: The concrete side; may or may not be a type variable.
        subs: Substitutions accumulated so far (never mutated).

    Returns:
        ``subs`` itself, or a new mapping extended with the new binding.

    Raises:
        TypeError: If neither side is a type variable, if only ``subtyp`` is
            a variable and it declares a bound, or if the recursive
            unification against an existing binding fails.
    """
    pattern_is_var = isinstance(typ, TypeVariable)
    concrete_is_var = isinstance(subtyp, TypeVariable)

    if pattern_is_var and concrete_is_var:
        # Two variables: equal ones need no binding, otherwise bind typ to subtyp.
        if typ == subtyp:
            return subs
        return {typ: subtyp, **subs}

    if pattern_is_var:
        # Check subtyp against any existing binding for typ. ``**subs`` is
        # unpacked last, so an existing binding is kept in the new mapping.
        return unify(subs.get(typ, subtyp), subtyp, {typ: subtyp, **subs})

    if concrete_is_var and getattr(subtyp, "__bound__", None) is None:
        # Mirror image of the case above; only allowed for variables
        # that do not declare a bound.
        return unify(typ, subs.get(subtyp, typ), {subtyp: typ, **subs})

    raise TypeError(f"Cannot unify type variable {typ} with {subtyp} given {subs}.")
|
221
|
+
|
222
|
+
|
223
|
+
@typing.overload
def _unify_sequence(
    typ: collections.abc.Sequence, subtyp: TypeExpressions, subs: Substitutions
) -> Substitutions: ...


@typing.overload
def _unify_sequence(
    typ: TypeExpressions, subtyp: collections.abc.Sequence, subs: Substitutions
) -> Substitutions: ...


def _unify_sequence(typ, subtyp, subs: Substitutions) -> Substitutions:
    """
    Unify two sequences of type expressions position by position.

    Args:
        typ: Pattern sequence, or ``...``.
        subtyp: Concrete sequence, or ``...``.
        subs: Substitutions accumulated so far.

    Returns:
        ``subs`` unchanged when either side is ``...``; otherwise the
        substitutions threaded through the unification of each pair.

    Raises:
        TypeError: If the lengths differ or any pair fails to unify.
    """
    # An Ellipsis on either side matches any sequence without new bindings.
    if any(isinstance(side, types.EllipsisType) for side in (typ, subtyp)):
        return subs

    if len(typ) != len(subtyp):
        raise TypeError(f"Cannot unify sequence {typ} with {subtyp} given {subs}. ")

    result = subs
    for pattern_item, concrete_item in zip(typ, subtyp):
        result = unify(pattern_item, concrete_item, result)
    return result
|
243
|
+
|
244
|
+
|
245
|
+
@typing.overload
def _unify_union(
    typ: UnionType, subtyp: TypeExpression, subs: Substitutions
) -> Substitutions: ...


@typing.overload
def _unify_union(
    typ: TypeExpression, subtyp: UnionType, subs: Substitutions
) -> Substitutions: ...


def _unify_union(typ, subtyp, subs: Substitutions) -> Substitutions:
    """
    Unify two type expressions when at least one of them is a union.

    Args:
        typ: The pattern side.
        subtyp: The concrete side.
        subs: Substitutions accumulated so far.

    Returns:
        The resulting substitutions.

    Raises:
        TypeError: If a member of a ``subtyp`` union fails to unify with
            ``typ``, or if no member of a ``typ`` union unifies with
            ``subtyp``, or if the members that do unify disagree on the
            resulting substitutions.
    """
    if typ == subtyp:
        return subs

    if isinstance(subtyp, UnionType):
        # Every member of the concrete union must unify with the pattern;
        # bindings accumulate from one member to the next.
        for member in typing.get_args(subtyp):
            subs = unify(typ, member, subs)
        return subs

    if isinstance(typ, UnionType):
        # Try each member of the pattern union independently and keep the
        # substitutions of those that succeed.
        successes: list[Substitutions] = []
        for member in typing.get_args(typ):
            try:
                candidate = unify(member, subtyp, subs)
            except TypeError:  # noqa
                continue
            successes.append(candidate)
        # Accept only when at least one member matched and all matches agree.
        if successes and all(found == successes[0] for found in successes):
            return successes[0]

    raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}")
|
275
|
+
|
276
|
+
|
277
|
+
@typing.overload
def _unify_generic(
    typ: GenericAlias, subtyp: type, subs: Substitutions
) -> Substitutions: ...


@typing.overload
def _unify_generic(
    typ: type, subtyp: GenericAlias, subs: Substitutions
) -> Substitutions: ...


@typing.overload
def _unify_generic(
    typ: GenericAlias, subtyp: GenericAlias, subs: Substitutions
) -> Substitutions: ...


def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions:
    """
    Unify two type expressions when at least one of them is a generic alias.

    Args:
        typ: The pattern side (a generic alias or a plain class).
        subtyp: The concrete side (a generic alias or a plain class).
        subs: Substitutions accumulated so far.

    Returns:
        The resulting substitutions.

    Raises:
        TypeError: If none of the cases below applies, or if a recursive
            unification fails.
    """
    if (
        isinstance(typ, GenericAlias)
        and isinstance(subtyp, GenericAlias)
        and issubclass(typing.get_origin(subtyp), typing.get_origin(typ))
    ):
        # Both sides are parameterized and subtyp's origin is a subclass of typ's.
        if typing.get_origin(subtyp) is tuple and typing.get_origin(typ) is not tuple:
            # A fixed-shape tuple against a non-tuple pattern: unify the
            # pattern with a variadic tuple of each element type in turn.
            for arg in typing.get_args(subtyp):
                subs = unify(typ, tuple[arg, ...], subs)  # type: ignore
            return subs
        elif typing.get_origin(subtyp) is collections.abc.Mapping and not issubclass(
            typing.get_origin(typ), collections.abc.Mapping
        ):
            # Mapping against a non-Mapping pattern: only the first type
            # arguments of the two sides are unified.
            return unify(typing.get_args(typ)[0], typing.get_args(subtyp)[0], subs)
        elif typing.get_origin(subtyp) is collections.abc.Generator and not issubclass(
            typing.get_origin(typ), collections.abc.Generator
        ):
            # Generator against a non-Generator pattern: likewise, only the
            # first type arguments are unified.
            return unify(typing.get_args(typ)[0], typing.get_args(subtyp)[0], subs)
        elif typing.get_origin(typ) == typing.get_origin(subtyp):
            # Same origin: unify the argument tuples element-wise.
            return unify(typing.get_args(typ), typing.get_args(subtyp), subs)
        elif types.get_original_bases(typing.get_origin(subtyp)):
            # Different origins: re-express subtyp through the first declared
            # base whose origin is a subclass of typ's origin, then retry.
            # If no base qualifies, control falls through to the final raise.
            for base in types.get_original_bases(typing.get_origin(subtyp)):
                if isinstance(base, type | GenericAlias) and issubclass(
                    typing.get_origin(base) or base,  # type: ignore
                    typing.get_origin(typ),
                ):
                    return unify(typ, base[typing.get_args(subtyp)], subs)  # type: ignore
    elif isinstance(typ, type) and isinstance(subtyp, GenericAlias):
        # Plain class pattern: the type arguments of subtyp are ignored.
        return unify(typ, typing.get_origin(subtyp), subs)
    elif (
        isinstance(typ, GenericAlias)
        and isinstance(subtyp, type)
        and issubclass(subtyp, typing.get_origin(typ))
    ):
        return subs  # implicit expansion to subtyp[Any]
    raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.")
|
331
|
+
|
332
|
+
|
333
|
+
def _unify_signature(
    typ: inspect.Signature, subtyp: inspect.BoundArguments, subs: Substitutions
) -> Substitutions:
    """
    Unify the parameter annotations of a signature with bound argument types.

    Args:
        typ: The signature whose annotations act as patterns.
        subtyp: Arguments bound to that same signature; each value is unified
            against the annotation of its parameter.
        subs: Substitutions accumulated so far.

    Returns:
        The substitutions after processing every annotated parameter.

    Raises:
        TypeError: If ``subtyp`` was bound to a different signature, if a
            variadic parameter holds an unsupported value, or if any
            individual unification fails.
    """
    if typ != subtyp.signature:
        raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}. ")

    variadic_kinds = {
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    }

    for name, param in typ.parameters.items():
        # Parameters without an annotation impose no constraint.
        if param.annotation is inspect.Parameter.empty:
            continue

        # Only *args / **kwargs are expected to be missing from the binding.
        if name not in subtyp.arguments:
            assert param.kind in variadic_kinds
            continue

        ptyp, psubtyp = param.annotation, subtyp.arguments[name]

        # Work out which (freshened) types must unify with the annotation.
        if param.kind is inspect.Parameter.VAR_POSITIONAL and isinstance(
            psubtyp, collections.abc.Sequence
        ):
            targets = _freshen(psubtyp)
        elif param.kind is inspect.Parameter.VAR_KEYWORD and isinstance(
            psubtyp, collections.abc.Mapping
        ):
            targets = _freshen(tuple(psubtyp.values()))
        elif param.kind not in variadic_kinds or isinstance(
            psubtyp, typing.ParamSpecArgs | typing.ParamSpecKwargs
        ):
            targets = (_freshen(psubtyp),)
        else:
            raise TypeError(f"Cannot unify {param} with {psubtyp} given {subs}")

        for target in targets:
            subs = unify(ptyp, target, subs)

    return subs
|
369
|
+
|
370
|
+
|
371
|
+
def _freshen(tp: typing.Any):
    """
    Return a copy of the given type expression with fresh type variables.

    Every free TypeVar and ParamSpec in ``tp`` is replaced by a newly created
    variable of the same name (TypeVars also keep their bound). The new
    variables are distinct objects, so they do not compare equal to the ones
    they replace. TypeVarTuples are left as they are.

    Args:
        tp: The type expression to freshen. Can be a plain type, TypeVar,
            generic alias, or union type.

    Returns:
        The result of substituting the fresh variables into ``tp``.

    Examples:
        >>> import typing
        >>> T = typing.TypeVar('T')
        >>> isinstance(_freshen(T), typing.TypeVar)
        True
        >>> _freshen(T) == T
        False
    """
    free_vars = freetypevars(tp)
    # Sanity check: every free variable must already be in canonical form.
    assert all(canonicalize(fv) is fv for fv in free_vars)

    renaming: dict = {}
    for fv in free_vars:
        if isinstance(fv, typing.TypeVar):
            renaming[fv] = typing.TypeVar(fv.__name__, bound=fv.__bound__)
        elif isinstance(fv, typing.ParamSpec):
            renaming[fv] = typing.ParamSpec(fv.__name__)
    return substitute(tp, renaming)
|
403
|
+
|
404
|
+
|
405
|
+
@functools.singledispatch
def canonicalize(typ) -> TypeExpressions:
    """
    Normalize a type expression into the canonical form used by unification.

    This is the single-dispatch fallback: it is reached only for objects
    that have no registered handler.

    Args:
        typ: The object to canonicalize.

    Raises:
        TypeError: Always, because the object's kind is not supported.
    """
    raise TypeError(f"Cannot canonicalize type {typ}.")
|
411
|
+
|
412
|
+
|
413
|
+
@canonicalize.register
def _(typ: type | abc.ABCMeta):
    """
    Canonicalize a plain class.

    The checks are ordered: effectful's own Term/Operation classes first,
    then specific builtin containers (mapped to collections.abc classes),
    then classes that are already canonical, and finally user-defined
    classes, which are replaced by the canonical form of a declared base.

    Raises:
        TypeError: If none of the cases applies.
    """
    if issubclass(typ, effectful.ops.types.Term):
        return effectful.ops.types.Term
    elif issubclass(typ, effectful.ops.types.Operation):
        return effectful.ops.types.Operation
    elif typ is dict:
        return collections.abc.MutableMapping
    elif typ is list:
        return collections.abc.MutableSequence
    elif typ is set:
        return collections.abc.MutableSet
    elif typ is frozenset:
        return collections.abc.Set
    elif typ is range:
        return collections.abc.Sequence[int]
    elif typ is types.GeneratorType:
        return collections.abc.Generator
    elif typ in {types.FunctionType, types.BuiltinFunctionType, types.LambdaType}:
        return collections.abc.Callable[..., typing.Any]
    elif isinstance(typ, abc.ABCMeta) and (
        typ in collections.abc.__dict__.values() or typ in numbers.__dict__.values()
    ):
        # ABCs exported by collections.abc / numbers are kept as they are.
        return typ
    elif isinstance(typ, type) and (
        typ in builtins.__dict__.values() or typ in types.__dict__.values()
    ):
        # Remaining classes from builtins / types are kept as they are.
        return typ
    elif types.get_original_bases(typ):
        # Other classes: use the first declared base whose canonical form
        # is not ``object``; keep the class itself if there is none.
        for base in types.get_original_bases(typ):
            cbase = canonicalize(base)
            if cbase != object:
                return cbase
        return typ
    else:
        raise TypeError(f"Cannot canonicalize type {typ}.")
|
449
|
+
|
450
|
+
|
451
|
+
@canonicalize.register
def _(typ: types.EllipsisType | None):
    """Return ``...`` and ``None`` unchanged."""
    return typ
|
454
|
+
|
455
|
+
|
456
|
+
@canonicalize.register
def _(typ: typing.TypeVar):
    """
    Return a TypeVar unchanged after checking it is supported.

    Raises:
        TypeError: If the TypeVar has constraints, is covariant or
            contravariant, or carries a default.
    """
    # ``typing.NoDefault`` may be absent; in that case the sentinel is None.
    no_default = getattr(typing, "NoDefault", None)
    has_default = getattr(typ, "__default__", None) is not no_default
    unsupported = (
        typ.__constraints__
        or typ.__covariant__
        or typ.__contravariant__
        or has_default
    )
    if unsupported:
        raise TypeError(f"Cannot canonicalize typevar {typ} with nonempty attributes")
    return typ
|
466
|
+
|
467
|
+
|
468
|
+
@canonicalize.register
def _(typ: typing.ParamSpec):
    """
    Return a ParamSpec unchanged after checking it is supported.

    Raises:
        TypeError: If the ParamSpec has a bound, is covariant or
            contravariant, or carries a default.
    """
    # ``typing.NoDefault`` may be absent; in that case the sentinel is None.
    no_default = getattr(typing, "NoDefault", None)
    has_default = getattr(typ, "__default__", None) is not no_default
    unsupported = (
        typ.__bound__
        or typ.__covariant__
        or typ.__contravariant__
        or has_default
    )
    if unsupported:
        raise TypeError(f"Cannot canonicalize typevar {typ} with nonempty attributes")
    return typ
|
478
|
+
|
479
|
+
|
480
|
+
@canonicalize.register
def _(typ: typing.TypeVarTuple):
    """
    Return a TypeVarTuple unchanged after checking it has no default.

    Raises:
        TypeError: If the TypeVarTuple carries a default.
    """
    # ``typing.NoDefault`` may be absent; in that case the sentinel is None.
    no_default = getattr(typing, "NoDefault", None)
    if getattr(typ, "__default__", None) is not no_default:
        raise TypeError(f"Cannot canonicalize typevar {typ} with nonempty attributes")
    return typ
|
485
|
+
|
486
|
+
|
487
|
+
@canonicalize.register
def _(typ: UnionType):
    """Canonicalize each member of a union and join the results with ``|``."""
    first, *rest = typing.get_args(typ)
    result = canonicalize(first)
    for member in rest:
        result = result | canonicalize(member)  # type: ignore
    return result
|
493
|
+
|
494
|
+
|
495
|
+
@canonicalize.register
def _(typ: GenericAlias):
    """
    Canonicalize a parameterized generic type.

    Raises:
        TypeError: If the origin is a typing special form applied to anything
            other than exactly one argument.
    """
    origin = typing.get_origin(typ)
    args = typing.get_args(typ)

    # Variadic tuple: tuple[X, ...] becomes Sequence[canonical X].
    if origin is tuple and len(args) == 2 and args[-1] is Ellipsis:
        return collections.abc.Sequence[canonicalize(args[0])]  # type: ignore

    # Special form with a single argument: canonicalize that argument.
    if isinstance(origin, typing._SpecialForm):
        if len(args) != 1:
            raise TypeError(f"Cannot canonicalize type {typ}")
        return canonicalize(args[0])

    # Otherwise canonicalize the origin and each argument, then re-apply.
    canonical_args = tuple(canonicalize(a) for a in args)
    return canonicalize(origin)[canonical_args]  # type: ignore
|
507
|
+
|
508
|
+
|
509
|
+
@canonicalize.register
def _(typ: list | tuple):
    """Canonicalize each item, rebuilding a container of the same class."""
    container = type(typ)
    return container(map(canonicalize, typ))
|
512
|
+
|
513
|
+
|
514
|
+
@canonicalize.register
def _(typ: effectful.ops.types._InterpretationMeta):
    """Return classes created by ``_InterpretationMeta`` unchanged."""
    return typ
|
517
|
+
|
518
|
+
|
519
|
+
@canonicalize.register
def _(typ: typing._AnnotatedAlias):  # type: ignore
    """Canonicalize the first type argument of the alias; later arguments are ignored."""
    underlying = typing.get_args(typ)[0]
    return canonicalize(underlying)
|
522
|
+
|
523
|
+
|
524
|
+
@canonicalize.register
def _(typ: typing._SpecialGenericAlias):  # type: ignore
    """Canonicalize the origin of an alias that carries no type arguments."""
    assert not typing.get_args(typ), "Should not have type arguments"
    origin = typing.get_origin(typ)
    return canonicalize(origin)
|
528
|
+
|
529
|
+
|
530
|
+
@canonicalize.register
def _(typ: typing._LiteralGenericAlias):  # type: ignore
    """Canonicalize the inferred type of the alias's first argument."""
    # NOTE(review): only the first argument is inspected, so any further
    # arguments do not influence the result -- confirm this is intended.
    first_value = typing.get_args(typ)[0]
    return canonicalize(nested_type(first_value))
|
533
|
+
|
534
|
+
|
535
|
+
@canonicalize.register
def _(typ: typing.NewType):
    """Canonicalize the ``__supertype__`` that the NewType wraps."""
    supertype = typ.__supertype__
    return canonicalize(supertype)
|
538
|
+
|
539
|
+
|
540
|
+
@canonicalize.register
def _(typ: typing.TypeAliasType):
    """Canonicalize the ``__value__`` that the type alias stands for."""
    aliased = typ.__value__
    return canonicalize(aliased)
|
543
|
+
|
544
|
+
|
545
|
+
@canonicalize.register
def _(typ: typing._ConcatenateGenericAlias):  # type: ignore
    """Collapse the alias to ``...``; its arguments are not inspected."""
    return ...
|
548
|
+
|
549
|
+
|
550
|
+
@canonicalize.register
def _(typ: typing._AnyMeta):  # type: ignore
    """Return ``typing.Any`` for any object dispatched as ``typing._AnyMeta``."""
    return typing.Any
|
553
|
+
|
554
|
+
|
555
|
+
@canonicalize.register
def _(typ: typing.ParamSpecArgs | typing.ParamSpecKwargs):
    """Treat ``P.args`` / ``P.kwargs`` as ``typing.Any``."""
    return typing.Any
|
558
|
+
|
559
|
+
|
560
|
+
@canonicalize.register
def _(typ: typing._SpecialForm):
    """Treat a bare typing special form as ``typing.Any``."""
    return typing.Any
|
563
|
+
|
564
|
+
|
565
|
+
@canonicalize.register
def _(typ: typing._ProtocolMeta):
    """Treat classes created by ``typing._ProtocolMeta`` as ``typing.Any``."""
    return typing.Any
|
568
|
+
|
569
|
+
|
570
|
+
@canonicalize.register
def _(typ: typing._UnpackGenericAlias):  # type: ignore
    """
    Reject unpacked generic aliases.

    Raises:
        TypeError: Always.
    """
    raise TypeError(f"Cannot canonicalize type {typ}")
|
573
|
+
|
574
|
+
|
575
|
+
@canonicalize.register
def _(typ: typing.ForwardRef):
    """
    Canonicalize the value a forward reference has been resolved to.

    Raises:
        TypeError: If ``__forward_value__`` is still ``None``.
    """
    resolved = typ.__forward_value__
    if resolved is None:
        raise TypeError(f"Cannot canonicalize lazy ForwardRef {typ}.")
    return canonicalize(resolved)
|
581
|
+
|
582
|
+
|
583
|
+
@functools.singledispatch
def nested_type(value) -> TypeExpression:
    """
    Infer the type of a value, handling nested collections with generic parameters.

    This function is a singledispatch generic function that determines the type
    of a given value. For collections (mappings, sequences, sets), it recursively
    infers the types of contained elements to produce a properly parameterized
    generic type. For example, a list [1, 2, 3] becomes
    collections.abc.MutableSequence[int].

    The function handles:
    - Basic types and type annotations (passed through unchanged)
    - Collections with recursive type inference for elements
    - Special cases like str/bytes (treated as types, not sequences)
    - Tuples (preserving exact element types)
    - Empty collections (returning the collection's type without parameters)

    This is primarily used by canonicalize() to handle cases where values
    are provided instead of type annotations.

    Args:
        value: Any value whose type needs to be inferred. Can be a type,
            a value instance, or a collection containing other values.

    Returns:
        The inferred type, potentially with generic parameters for collections.
        This fallback implementation returns ``type(value)``.

    Raises:
        TypeError: If the value is a Term from effectful.ops.types.

    Examples:
        >>> import collections.abc
        >>> import typing
        >>> from effectful.internals.unification import nested_type

        # Basic types are returned as their type
        >>> nested_type(42)
        <class 'int'>
        >>> nested_type("hello")
        <class 'str'>
        >>> nested_type(3.14)
        <class 'float'>
        >>> nested_type(True)
        <class 'bool'>

        # Type objects pass through unchanged
        >>> nested_type(int)
        <class 'int'>
        >>> nested_type(str)
        <class 'str'>
        >>> nested_type(list)
        <class 'list'>

        # Empty collections return their base type
        >>> nested_type([])
        <class 'list'>
        >>> nested_type({})
        <class 'dict'>
        >>> nested_type(set())
        <class 'set'>

        # Lists become MutableSequence[element_type]
        >>> nested_type([1, 2, 3])
        collections.abc.MutableSequence[int]
        >>> nested_type(["a", "b", "c"])
        collections.abc.MutableSequence[str]

        # Tuples preserve exact structure
        >>> nested_type((1, "hello", 3.14))
        tuple[int, str, float]
        >>> nested_type(())
        <class 'tuple'>
        >>> nested_type((1,))
        tuple[int]

        # Sets become MutableSet[element_type]
        >>> nested_type({1, 2, 3})
        collections.abc.MutableSet[int]
        >>> nested_type({"a", "b"})
        collections.abc.MutableSet[str]

        # Dicts become MutableMapping[key_type, value_type]
        >>> nested_type({"key": "value"})
        collections.abc.MutableMapping[str, str]
        >>> nested_type({1: "one", 2: "two"})
        collections.abc.MutableMapping[int, str]

        # Strings and bytes are NOT treated as sequences
        >>> nested_type("hello")
        <class 'str'>
        >>> nested_type(b"bytes")
        <class 'bytes'>

        # Annotated functions return types derived from their annotations
        >>> def annotated_func(x: int) -> str:
        ...     return str(x)
        >>> nested_type(annotated_func)
        collections.abc.Callable[[int], str]

        # Unannotated functions/callables return their type
        >>> def f(): pass
        >>> nested_type(f)
        <class 'function'>
        >>> nested_type(lambda x: x)
        <class 'function'>

        # Generic aliases and union types pass through
        >>> nested_type(list[int])
        list[int]
        >>> nested_type(int | str)
        int | str
    """
    return type(value)
|
697
|
+
|
698
|
+
|
699
|
+
@nested_type.register
def _(value: TypeExpression):
    """Return a value that is already a type expression unchanged."""
    return value
|
702
|
+
|
703
|
+
|
704
|
+
@nested_type.register
def _(value: effectful.ops.types.Term):
    """
    Reject Terms.

    Raises:
        TypeError: Always.
    """
    raise TypeError(f"Terms should not appear in nested_type, but got {value}")
|
707
|
+
|
708
|
+
|
709
|
+
@nested_type.register
def _(value: effectful.ops.types.Operation):
    """
    Infer the type of an Operation from its callable signature.

    The value is first typed by the generic callable handler. When that
    handler produces a parameterized Callable, its parameter and return types
    are transferred onto ``Operation[...]``.

    Returns:
        ``Operation[arg_types, return_type]``, or whatever the callable
        handler returned when that result is not parameterized with exactly
        a parameter list and a return type.
    """
    callable_type = nested_type.dispatch(collections.abc.Callable)(value)
    type_args = typing.get_args(callable_type)
    if len(type_args) != 2:
        # The callable handler falls back to a bare class (e.g. when the
        # return annotation is missing); there is nothing to transfer, and
        # unpacking would raise ValueError.
        return callable_type
    arg_types, return_type = type_args
    return effectful.ops.types.Operation[arg_types, return_type]  # type: ignore
|
714
|
+
|
715
|
+
|
716
|
+
@nested_type.register
def _(value: collections.abc.Callable):
    """
    Infer a ``Callable[...]`` type from a callable's signature.

    Returns:
        ``type(value)`` when the callable has registered overloads, when its
        signature cannot be obtained, or when it has no return annotation.
        Otherwise ``Callable[..., R]`` if any parameter is unannotated,
        variadic or keyword-only, and ``Callable[[P1, ...], R]`` if every
        parameter is annotated and may be passed positionally.
    """
    if typing.get_overloads(value):
        return type(value)

    try:
        signature = inspect.signature(value)
    except ValueError:
        return type(value)

    returns = signature.return_annotation
    if returns is inspect.Signature.empty:
        return type(value)

    # Parameter kinds that cannot be described by a positional parameter list.
    inexpressible_kinds = {
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    }
    params = signature.parameters.values()
    needs_ellipsis = any(
        p.annotation is inspect.Parameter.empty or p.kind in inexpressible_kinds
        for p in params
    )
    if needs_ellipsis:
        return collections.abc.Callable[..., returns]
    return collections.abc.Callable[[p.annotation for p in params], returns]
|
743
|
+
|
744
|
+
|
745
|
+
@nested_type.register
def _(value: collections.abc.Mapping):
    """
    Infer a parameterized mapping type from a mapping's keys and values.

    Returns:
        ``Interpretation`` for a non-empty Interpretation; ``type(value)``
        for an empty mapping, or for a mapping with several entries whose
        key or value types combine into a union; otherwise the canonical
        container type parameterized with the key and value types.
    """
    if value and isinstance(value, effectful.ops.types.Interpretation):
        return effectful.ops.types.Interpretation

    container = type(value)
    size = len(value)
    if size == 0:
        return container

    # With a single entry the reduction just yields that entry's type.
    key_type = functools.reduce(operator.or_, map(nested_type, value.keys()))
    val_type = functools.reduce(operator.or_, map(nested_type, value.values()))

    # Heterogeneous contents are only rejected when there are several entries.
    heterogeneous = isinstance(key_type, UnionType) or isinstance(val_type, UnionType)
    if size > 1 and heterogeneous:
        return container
    return canonicalize(container)[key_type, val_type]  # type: ignore
|
763
|
+
|
764
|
+
|
765
|
+
@nested_type.register
def _(value: collections.abc.Collection):
    """
    Infer a parameterized collection type from a collection's elements.

    Returns:
        ``type(value)`` for an empty collection, or for a collection with
        several elements whose types combine into a union; otherwise the
        canonical container type parameterized with the element type.
    """
    container = type(value)
    size = len(value)
    if size == 0:
        return container

    # With a single element the reduction just yields that element's type.
    element_type = functools.reduce(operator.or_, map(nested_type, value))

    # Heterogeneous contents are only rejected when there are several elements.
    if size > 1 and isinstance(element_type, UnionType):
        return container
    return canonicalize(container)[element_type]  # type: ignore
|
778
|
+
|
779
|
+
|
780
|
+
@nested_type.register
def _(value: tuple):
    """
    Infer the type of a tuple.

    A non-empty instance of exactly ``tuple`` keeps its per-position element
    types. Tuple subclasses and empty tuples are handled by the implementation
    that single dispatch selects for ``collections.abc.Sequence``.
    """
    if type(value) != tuple or len(value) == 0:
        return nested_type.dispatch(collections.abc.Sequence)(value)
    element_types = tuple(nested_type(item) for item in value)
    return tuple[element_types]  # type: ignore
|
787
|
+
|
788
|
+
|
789
|
+
@nested_type.register
def _(value: str | bytes | range | None):
    """Report the runtime class of a ``str``, ``bytes``, ``range`` or ``None`` value.

    The value is not inspected element-wise; only ``type(value)`` is returned.
    """
    runtime_class = type(value)
    return runtime_class
|
792
|
+
|
793
|
+
|
794
|
+
def freetypevars(typ) -> collections.abc.Set[TypeVariable]:
    """
    Collect the free type variables occurring in a type expression.

    The expression is wrapped in a one-element tuple and handed to
    ``_freetypevars``; whatever that helper yields is gathered into a set,
    so each variable appears at most once and ordering is not preserved.

    Args:
        typ: The type expression to inspect, e.g. a plain class (``int``),
            a type variable, or a generic alias such as ``list[T]`` or
            ``dict[K, V]``.

    Returns:
        The set of type variables found in ``typ``; empty when there are none.

    Examples:
        >>> T = typing.TypeVar('T')
        >>> K = typing.TypeVar('K')
        >>> V = typing.TypeVar('V')

        >>> # TypeVar returns itself
        >>> freetypevars(T)
        {~T}

        >>> # Generic type with one TypeVar
        >>> freetypevars(list[T])
        {~T}

        >>> # Generic type with multiple TypeVars
        >>> freetypevars(dict[K, V]) == {K, V}
        True

        >>> # Nested generic types
        >>> freetypevars(list[dict[K, V]]) == {K, V}
        True

        >>> # Concrete types have no free TypeVars
        >>> freetypevars(int)
        set()

        >>> # Generic types with concrete arguments have no free TypeVars
        >>> freetypevars(list[int])
        set()

        >>> # Mixed concrete and TypeVar arguments
        >>> freetypevars(dict[str, T])
        {~T}
    """
    found = _freetypevars((typ,))
    return set(found)
|
845
|
+
|
846
|
+
|
847
|
+
def substitute(typ, subs: Substitutions) -> TypeExpressions:
    """
    Apply a substitution mapping to a type expression.

    The expression is handled by case:

    * a ``TypeVar``, ``ParamSpec`` or ``TypeVarTuple`` that is a key of
      ``subs`` is replaced by its mapped value, which is itself substituted
      again; an unmapped variable is returned unchanged;
    * a ``list`` or ``tuple`` is rebuilt (same class) from its substituted
      items;
    * any other expression with at least one free variable present in
      ``subs`` is re-subscripted with the replacements for its free
      variables (unmapped ones kept as-is), and the result is substituted
      again;
    * everything else is returned unchanged.

    Args:
        typ: The type expression to rewrite, e.g. a plain class, a type
            variable, or a generic alias such as ``list[T]``.
        subs: Mapping from type variables to the expressions that replace
            them.

    Returns:
        The type expression with mapped type variables replaced.

    Examples:
        >>> T = typing.TypeVar('T')
        >>> K = typing.TypeVar('K')
        >>> V = typing.TypeVar('V')

        >>> # Simple TypeVar substitution
        >>> substitute(T, {T: int})
        <class 'int'>

        >>> # Generic type substitution
        >>> substitute(list[T], {T: str})
        list[str]

        >>> # Nested generic substitution
        >>> substitute(dict[K, list[V]], {K: str, V: int})
        dict[str, list[int]]

        >>> # TypeVar not in mapping remains unchanged
        >>> substitute(T, {K: int})
        ~T

        >>> # Non-generic types pass through unchanged
        >>> substitute(int, {T: str})
        <class 'int'>
    """
    if isinstance(typ, (typing.TypeVar, typing.ParamSpec, typing.TypeVarTuple)):
        if typ not in subs:
            return typ
        # The mapped value may mention further mapped variables; resolve those too.
        return substitute(subs[typ], subs)

    if isinstance(typ, (list, tuple)):
        return type(typ)(substitute(item, subs) for item in typ)

    if not any(fv in subs for fv in freetypevars(typ)):
        return typ

    # Re-subscript with one argument per free variable, in the order the
    # helper yields them, then keep substituting inside the result.
    replacements = tuple(subs.get(fv, fv) for fv in _freetypevars((typ,)))
    return substitute(typ[replacements], subs)
|