typingkit 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typingkit/__init__.py +11 -0
- typingkit/_typed/__init__.py +11 -0
- typingkit/_typed/_debug.py +37 -0
- typingkit/_typed/context.py +203 -0
- typingkit/_typed/dimexpr.py +154 -0
- typingkit/_typed/factory.py +71 -0
- typingkit/_typed/generics.py +50 -0
- typingkit/_typed/helpers.py +248 -0
- typingkit/_typed/list.py +206 -0
- typingkit/_typed/ndarray.py +513 -0
- typingkit/py.typed +0 -0
- typingkit-0.2.2.dist-info/METADATA +92 -0
- typingkit-0.2.2.dist-info/RECORD +14 -0
- typingkit-0.2.2.dist-info/WHEEL +4 -0
typingkit/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Debugging utilities
|
|
3
|
+
=======
|
|
4
|
+
"""
|
|
5
|
+
# src/typingkit/_typed/_debug.py
|
|
6
|
+
|
|
7
|
+
from typing import Any, TypeVar, get_args, get_origin
|
|
8
|
+
|
|
9
|
+
from rich.tree import Tree
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def diagnostic(obj: Any, pfx: str | None = None) -> Tree:
|
|
13
|
+
_pfx = f"[bold cyan]{pfx}[/] " if pfx else ""
|
|
14
|
+
tree = Tree(
|
|
15
|
+
f"{_pfx}[yellow]obj[/]=[green]{obj!r}[/], "
|
|
16
|
+
f"[yellow]type[/]=[magenta]{type(obj).__name__}[/]"
|
|
17
|
+
)
|
|
18
|
+
match obj:
|
|
19
|
+
case tuple():
|
|
20
|
+
for x in obj: # type: ignore
|
|
21
|
+
tree.add(diagnostic(x))
|
|
22
|
+
|
|
23
|
+
case TypeVar():
|
|
24
|
+
tree.add(diagnostic(obj.__bound__, "__bound__:"))
|
|
25
|
+
tree.add(diagnostic(obj.__constraints__, "__constraints__:"))
|
|
26
|
+
tree.add(diagnostic(obj.__default__, "__default__:"))
|
|
27
|
+
tree.add(diagnostic(obj.__covariant__, "__covariant__:"))
|
|
28
|
+
tree.add(diagnostic(obj.__contravariant__, "__contravariant__:"))
|
|
29
|
+
|
|
30
|
+
# GenericAlias | Literal | UnionType
|
|
31
|
+
case _ if (origin := get_origin(obj)) is not None:
|
|
32
|
+
tree.add(diagnostic(origin, "origin:"))
|
|
33
|
+
tree.add(diagnostic(get_args(obj), "args:"))
|
|
34
|
+
|
|
35
|
+
case _:
|
|
36
|
+
pass
|
|
37
|
+
return tree
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Context binding
|
|
3
|
+
=======
|
|
4
|
+
Manages TypeVar binding contexts for shape validation.
|
|
5
|
+
"""
|
|
6
|
+
# src/typingkit/_typed/context.py
|
|
7
|
+
|
|
8
|
+
# pyright: reportPrivateUsage = false
|
|
9
|
+
|
|
10
|
+
import inspect
|
|
11
|
+
from contextvars import ContextVar
|
|
12
|
+
from functools import wraps
|
|
13
|
+
from typing import (
|
|
14
|
+
Any,
|
|
15
|
+
Callable,
|
|
16
|
+
Concatenate,
|
|
17
|
+
ParamSpec,
|
|
18
|
+
TypeVar,
|
|
19
|
+
get_args,
|
|
20
|
+
get_origin,
|
|
21
|
+
get_type_hints,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from typingkit._typed.ndarray import (
|
|
25
|
+
DimensionError,
|
|
26
|
+
_validate_shape,
|
|
27
|
+
_validate_shape_against_contexts,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Context variables
|
|
31
|
+
|
|
32
|
+
_class_typevar_context = ContextVar[dict[int, dict[TypeVar, int]]](
|
|
33
|
+
"_class_typevar_context", default=dict[int, dict[TypeVar, int]]()
|
|
34
|
+
)
|
|
35
|
+
_method_typevar_context = ContextVar[dict[TypeVar, int]](
|
|
36
|
+
"_method_typevar_context", default=dict[TypeVar, int]()
|
|
37
|
+
)
|
|
38
|
+
_active_class_context = ContextVar[dict[TypeVar, int]](
|
|
39
|
+
"_active_class_context", default=dict[TypeVar, int]()
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
T = TypeVar("T")
|
|
43
|
+
P = ParamSpec("P") # ParamSpec for function parameters
|
|
44
|
+
R = TypeVar("R") # TypeVar for return type
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _extract_shape_dims(annotation: Any) -> tuple[Any, ...] | None:
|
|
48
|
+
"""Return shape dimension spec tuple[...] from TypedNDArray annotation."""
|
|
49
|
+
|
|
50
|
+
origin = get_origin(annotation)
|
|
51
|
+
if origin is None:
|
|
52
|
+
return None
|
|
53
|
+
|
|
54
|
+
args = get_args(annotation)
|
|
55
|
+
if not args:
|
|
56
|
+
return None
|
|
57
|
+
|
|
58
|
+
shape_spec = args[0]
|
|
59
|
+
if get_origin(shape_spec) is not tuple:
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
return get_args(shape_spec)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _validate_and_bind(
    *,
    shape_dims: tuple[Any, ...],
    actual_shape: tuple[int, ...],
    owner_cls: type,
    func_name: str,
    param_name: str | None,
    class_context: dict[TypeVar, int],
    method_context: dict[TypeVar, int],
) -> None:
    """
    Check ``actual_shape`` against ``shape_dims`` and record TypeVar sizes.

    Steps:
      1. Structural validation via ``_validate_shape`` (Literal, int, rank,
         repeated TypeVar).
      2. For every TypeVar entry that has a corresponding axis in
         ``actual_shape``: bind it on first sight, otherwise require the
         observed size to equal the existing binding. TypeVars that are
         parameters of ``owner_cls`` use ``class_context``; all others use
         ``method_context``. Both dicts are updated in place.

    ``param_name`` is None when a return value is being validated; it is only
    used to build the error message.

    Raises:
        DimensionError: when an already-bound TypeVar sees a different size.
            Failures of step 1 are raised by ``_validate_shape``.
    """
    _validate_shape(shape_dims, actual_shape)

    rank = len(actual_shape)
    for axis, dim in enumerate(shape_dims):
        # Only TypeVars carry bindings; entries with no matching axis are
        # skipped (this guard just avoids an IndexError here).
        if not isinstance(dim, TypeVar) or axis >= rank:
            continue

        observed = actual_shape[axis]
        class_level = _is_class_level_typevar(dim, owner_cls)
        bindings = class_context if class_level else method_context

        if dim not in bindings:
            # First occurrence: the observed size becomes the binding.
            bindings[dim] = observed
            continue

        bound = bindings[dim]
        if observed != bound:
            scope = "class" if class_level else "method"
            where = (
                "return value" if param_name is None else f"parameter `{param_name}`"
            )
            raise DimensionError(
                f"In {func_name}(...), {where} "
                f"dimension {axis} [{dim}] "
                f"expected {bound} ({scope}-level binding), "
                f"got {observed}"
            )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _is_class_level_typevar(typevar: TypeVar, owner_cls: type) -> bool:
|
|
114
|
+
"""Check if a TypeVar is bound at class level vs method level."""
|
|
115
|
+
cls_params = getattr(owner_cls, "__parameters__", ())
|
|
116
|
+
return typevar in cls_params
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _get_instance_class_context(instance: Any) -> dict[TypeVar, int]:
    """
    Return the mutable class-level TypeVar bindings for ``instance``.

    Bindings are stored in ``_class_typevar_context`` keyed by ``id(instance)``.
    On first use an empty dict is registered. The registry currently held by
    the ContextVar is never mutated in place; a copy containing the new entry
    is set instead.

    NOTE(review): nothing in this module removes entries, and CPython may reuse
    an ``id()`` once an object is garbage-collected. A new instance could
    therefore inherit stale bindings, and the registry grows without bound.
    Confirm whether this is handled elsewhere.
    """
    key = id(instance)
    registry = _class_typevar_context.get()
    if key in registry:
        return registry[key]

    # Copy-on-write: publish a new registry rather than mutating the shared one.
    registry = {**registry, key: {}}
    _class_typevar_context.set(registry)
    return registry[key]
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def enforce_shapes(
    func: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R]:
    """
    Decorator to automatically validate TypeVar shape bindings.

    - Class-level TypeVars (from Generic[T]) are bound per-instance, persist across calls
    - Method-level TypeVars are validated per-call only, local to single invocation
    - Validates both parameter and return types

    Only arguments (and a return value) that have a ``shape`` attribute and
    whose annotation carries a ``tuple[...]`` shape spec are checked; ``self``
    is never checked. The type hints and signature of ``func`` are resolved on
    the first call and reused for later calls.

    The wrapper propagates ``DimensionError`` from ``_validate_and_bind`` on an
    inconsistent TypeVar binding, plus whatever ``_validate_shape`` /
    ``_validate_shape_against_contexts`` raise.
    """
    # Resolved lazily on the first call instead of at decoration time:
    # annotations may forward-reference the enclosing class, which does not
    # exist yet while the class body (and hence this decorator) is running.
    introspected: tuple[dict[str, Any], inspect.Signature] | None = None

    @wraps(func)
    def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
        nonlocal introspected
        if introspected is None:
            introspected = (
                get_type_hints(func, include_extras=True),
                inspect.signature(func),
            )
        hints, sig = introspected

        bound_args = sig.bind(self, *args, **kwargs)
        bound_args.apply_defaults()

        owner_cls = self.__class__
        class_context = _get_instance_class_context(self)
        method_context = dict[TypeVar, int]()

        # Validate arguments, binding TypeVars as they are first seen
        for param_name, param_value in bound_args.arguments.items():
            if param_name == "self" or param_name not in hints:
                continue
            if not hasattr(param_value, "shape"):
                continue

            shape_dims = _extract_shape_dims(hints[param_name])
            if shape_dims is None:
                continue

            _validate_and_bind(
                shape_dims=shape_dims,
                actual_shape=param_value.shape,
                owner_cls=owner_cls,
                func_name=func.__name__,
                param_name=param_name,
                class_context=class_context,
                method_context=method_context,
            )

        # Execute function with active contexts (always restored afterwards)
        method_token = _method_typevar_context.set(method_context)
        class_token = _active_class_context.set(class_context)
        try:
            result = func(self, *args, **kwargs)
        finally:
            _method_typevar_context.reset(method_token)
            _active_class_context.reset(class_token)

        # Validate return
        if "return" in hints and result is not None and hasattr(result, "shape"):
            shape_dims = _extract_shape_dims(hints["return"])
            actual_shape = getattr(result, "shape")
            if shape_dims is not None:
                _validate_and_bind(
                    shape_dims=shape_dims,
                    actual_shape=actual_shape,
                    owner_cls=owner_cls,
                    func_name=func.__name__,
                    param_name=None,
                    class_context=class_context,
                    method_context=method_context,
                )

                # context-aware validation (cross-call)
                _validate_shape_against_contexts(shape_dims, actual_shape)

        return result

    return wrapper
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DimExpr
|
|
3
|
+
=======
|
|
4
|
+
"""
|
|
5
|
+
# src/typingkit/_typed/dimexpr.py
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Generic,
|
|
11
|
+
Literal,
|
|
12
|
+
NoReturn,
|
|
13
|
+
TypeAlias,
|
|
14
|
+
TypeAliasType,
|
|
15
|
+
TypeVar,
|
|
16
|
+
get_args,
|
|
17
|
+
get_origin,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Operand placeholders for dimension expressions; both must be int-like.
Arg1, Arg2 = TypeVar("Arg1", bound=int), TypeVar("Arg2", bound=int)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
## Base
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DimExpr(int):
    """
    Base class for type-level dimension expressions.

    Subclasses are used only as subscripted type arguments (e.g.
    ``Add[N, Literal[1]]``) and evaluated by ``_resolve_dim``, which calls
    ``expr`` with already-resolved operands. They can never be instantiated.
    """

    def __new__(cls, *_: Any) -> NoReturn:
        # Expressions exist only at the type level.
        raise TypeError("Shape expressions cannot be instantiated")

    @classmethod
    def expr(cls, operands: tuple[Any, ...], /) -> int:
        """Compute the dimension from resolved operands; subclasses override."""
        raise NotImplementedError
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class UnaryOp(Generic[Arg1], DimExpr):
    """Base for dimension expressions taking one operand, e.g. ``Neg[A]``."""

    @classmethod
    def expr(cls, operands: tuple[Arg1], /) -> int:
        """Evaluate from the single resolved operand; subclasses override."""
        raise NotImplementedError
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class BinaryOp(Generic[Arg1, Arg2], DimExpr):
    """Base for dimension expressions taking two operands, e.g. ``Add[A, B]``."""

    @classmethod
    def expr(cls, operands: tuple[Arg1, Arg2], /) -> int:
        """Evaluate from the two resolved operands; subclasses override."""
        raise NotImplementedError
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
## Core operations
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Neg(UnaryOp[Arg1]):
    """``Neg[A]`` evaluates to ``-A``."""

    @classmethod
    def expr(cls, operands: tuple[Arg1], /) -> int:
        operand = operands[0]
        return -operand
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Add(BinaryOp[Arg1, Arg2]):
    """``Add[A, B]`` evaluates to ``A + B``."""

    @classmethod
    def expr(cls, operands: tuple[Arg1, Arg2], /) -> int:
        lhs, rhs = operands[0], operands[1]
        return lhs + rhs
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class Mul(BinaryOp[Arg1, Arg2]):
    """``Mul[A, B]`` evaluates to ``A * B``."""

    @classmethod
    def expr(cls, operands: tuple[Arg1, Arg2], /) -> int:
        lhs, rhs = operands[0], operands[1]
        return lhs * rhs
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class Pow(BinaryOp[Arg1, Arg2]):
    """``Pow[A, B]`` evaluates to ``A ** B``."""

    @classmethod
    def expr(cls, operands: tuple[Arg1, Arg2], /) -> int:
        base, exponent = operands[0], operands[1]
        return base**exponent
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
## Additional operations
|
|
76
|
+
|
|
77
|
+
# Through TypeAliases
|
|
78
|
+
Sub: TypeAlias = Add[Arg1, Neg[Arg2]]
|
|
79
|
+
PlusOne: TypeAlias = Add[Arg1, Literal[1]]
|
|
80
|
+
|
|
81
|
+
# Through PEP-695 style TypeAliasTypes
|
|
82
|
+
type MinusOne[Arg1: int] = Sub[Arg1, Literal[1]]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Through Subclassing
class Squared(Mul[Arg1, Arg1]):
    """
    ``Squared[A]`` evaluates to ``A * A``.

    It deliberately defines no ``expr`` of its own, so ``_resolve_dim``
    evaluates it by substituting ``A`` into the ``Mul[Arg1, Arg1]`` base.
    """
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Custom operations
# An operation can also define `expr` directly with custom logic, e.g.:
class Cubed(UnaryOp[Arg1]):
    """``Cubed[A]`` evaluates to ``A ** 3`` via its own ``expr``."""

    @classmethod
    def expr(cls, operands: tuple[Arg1], /) -> int:
        side = operands[0]
        return side**3
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class Log(BinaryOp[Arg1, Arg2]):
    """
    Integer logarithm: ``Log[X, B]`` evaluates to ``k`` such that ``B ** k == X``.

    Raises:
        TypeError: if the logarithm is not an integer.
        Errors from ``math.log`` (e.g. non-positive ``X``, base 1) propagate.
    """

    @classmethod
    def expr(cls, args: tuple[Arg1, Arg2], /) -> int:
        value, base = args[0], args[1]
        x = math.log(value, base)
        xi = round(x)
        if isinstance(value, int) and isinstance(base, int):
            # Verify with exact integer arithmetic: a float logarithm cannot
            # distinguish e.g. 2**60 + 1 from 2**60, and its absolute error
            # grows with the size of the result.
            if xi >= 0 and base**xi == value:
                return xi
        elif math.isclose(x, xi, rel_tol=0, abs_tol=1e-12):
            # Non-integer operands: fall back to a tolerance check.
            return xi
        raise TypeError("Invalid dimension. Not an integer.")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# Integer logarithms in fixed bases
Log2: TypeAlias = Log[Arg1, Literal[2]]
Log10: TypeAlias = Log[Arg1, Literal[10]]
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
## Evaluation
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _resolve_dim(tp: Any) -> Any:
    """
    Evaluate a dimension spec to a concrete ``int`` where possible.

    Returns:
      - ``tp`` unchanged for ``Any``, ``int``, ``...`` and TypeVars;
      - plain ``int`` values unchanged. These occur inside partially evaluated
        expressions such as ``Add[2, N]``, which this function returns when
        some operands are still unresolved;
      - the first argument of ``Literal[...]``;
      - for a DimExpr alias, its evaluated value, or a re-parameterised alias
        of the same DimExpr when some operands remain unresolved.

    Raises:
        TypeError: if ``tp`` is not a recognised dimension spec.
    """
    # type[Any] / type[int] / EllipsisType / TypeVar
    if tp is Any or tp is int or tp is Ellipsis or isinstance(tp, TypeVar):
        return tp

    # Already-evaluated size (e.g. an operand of a partially evaluated DimExpr)
    if isinstance(tp, int):
        return tp

    origin = get_origin(tp)

    # Literal[N]
    if origin is Literal:
        # [TODO] How should we handle multiple args case?
        return get_args(tp)[0]

    # TypeAliasType
    if isinstance(origin, TypeAliasType):
        return _resolve_dim(origin.__value__[get_args(tp)])

    # DimExpr. The isinstance guard matters because issubclass() raises on
    # non-class origins such as typing.Union.
    if isinstance(origin, type) and issubclass(origin, DimExpr):
        args = get_args(tp)

        # Case 1: class defines its own expr
        if "expr" in origin.__dict__:
            values = tuple(_resolve_dim(arg) for arg in args)
            if not all(isinstance(v, int) for v in values):
                return origin[values]  # pyright: ignore[reportInvalidTypeArguments]
            return origin.expr(values)

        # Case 2: subclassing
        for base in getattr(origin, "__orig_bases__", ()):
            base_origin = get_origin(base)
            if isinstance(base_origin, type) and issubclass(base_origin, DimExpr):
                parameters = getattr(origin, "__parameters__", ())
                param_map = dict(zip(parameters, args))

                # Substitute through the base alias itself, so TypeVars nested
                # inside its arguments (e.g. Mul[A, Neg[A]]) are replaced too.
                base_params = getattr(base, "__parameters__", ())
                concrete = (
                    base[tuple(param_map.get(p, p) for p in base_params)]
                    if base_params
                    else base
                )
                return _resolve_dim(concrete)

    raise TypeError(f"Cannot evaluate {tp}")
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shape factory for TypedNDArray
|
|
3
|
+
=======
|
|
4
|
+
"""
|
|
5
|
+
# src/typingkit/_typed/factory.py
|
|
6
|
+
|
|
7
|
+
# pyright: reportPrivateUsage = false
|
|
8
|
+
|
|
9
|
+
from types import GenericAlias
|
|
10
|
+
from typing import Any, Generic, TypeVar
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
|
|
15
|
+
from typingkit._typed.ndarray import TypedNDArray, _AnyShape
|
|
16
|
+
|
|
17
|
+
# Scalar type of a builder, and of the builder returned by `.dtype(...)`
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
_NewScalarT = TypeVar("_NewScalarT", bound=np.generic)
# Shape tuple type carried by a builder
_ShapeT = TypeVar("_ShapeT", bound=_AnyShape)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class _ShapeBuilder(Generic[_ShapeT, _ScalarT]):
    """
    Builder pairing a shape spec with a dtype spec.

    ``Shaped[...]`` returns one of these. ``.dtype(...)`` and the dtype
    properties return a new builder with the same shape, and calling a builder
    constructs a ``TypedNDArray`` parameterised by the shape spec.
    """

    __slots__ = ("_shape_spec", "_dtype_spec")

    def __init__(self, shape_spec: GenericAlias, dtype_spec: GenericAlias):
        # ``tuple[...]`` alias describing the expected dimensions
        self._shape_spec = shape_spec
        # ``np.dtype[...]`` alias; its single argument is ``Any`` when unset
        self._dtype_spec = dtype_spec

    def dtype(self, dtype: type[_NewScalarT]) -> "_ShapeBuilder[_ShapeT, _NewScalarT]":
        """Return a new builder with the same shape and scalar type ``dtype``."""
        new_dtype_spec = GenericAlias(np.dtype, (dtype,))
        return _ShapeBuilder(self._shape_spec, new_dtype_spec)

    def __call__(
        self, object: npt.ArrayLike
    ) -> TypedNDArray[_ShapeT, np.dtype[_ScalarT]]:
        """Build a ``TypedNDArray`` from ``object`` using this shape (and dtype, if set)."""
        scalar_type = self._dtype_spec.__args__[0]
        if scalar_type is Any:
            # No dtype was chosen: pass None rather than Any.
            scalar_type = None
        return TypedNDArray[self._shape_spec](object, scalar_type)  # type: ignore

    def __repr__(self) -> str:
        dims = self._shape_spec.__args__
        scalar = self._dtype_spec.__args__[0]
        return f"TypedNDArrayFactory[{dims}, {scalar}]"

    ## Convenience helpers

    @property
    def float64(self) -> "_ShapeBuilder[_ShapeT, np.float64]":
        """Shortcut for ``.dtype(np.float64)``."""
        return self.dtype(np.float64)

    @property
    def float32(self) -> "_ShapeBuilder[_ShapeT, np.float32]":
        """Shortcut for ``.dtype(np.float32)``."""
        return self.dtype(np.float32)

    @property
    def int64(self) -> "_ShapeBuilder[_ShapeT, np.int64]":
        """Shortcut for ``.dtype(np.int64)``."""
        return self.dtype(np.int64)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class _ShapeFactory:
    """
    Entry point for shape-typed array constructors.

    ``Shaped[3, 4]`` (or ``Shaped((3, 4))``) returns a ``_ShapeBuilder`` whose
    shape spec is ``tuple[3, 4]`` and whose dtype spec is ``np.dtype[Any]``.
    """

    def __getitem__(self, dims: _ShapeT) -> _ShapeBuilder[_ShapeT, Any]:
        # A single subscript (Shaped[3]) arrives as a bare item, not a tuple.
        dim_tuple = dims if isinstance(dims, tuple) else (dims,)  # pyright: ignore[reportUnnecessaryIsInstance]
        return _ShapeBuilder(
            GenericAlias(tuple, dim_tuple),
            GenericAlias(np.dtype, (Any,)),
        )

    # Prefer [...] rather than (...), but this is also provided
    def __call__(self, dims: _ShapeT) -> _ShapeBuilder[_ShapeT, Any]:
        return self[dims]


# Module-level singleton used as ``Shaped[...]``.
Shaped = _ShapeFactory()
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Generics
|
|
3
|
+
=======
|
|
4
|
+
"""
|
|
5
|
+
# src/typingkit/_typed/generics.py
|
|
6
|
+
|
|
7
|
+
from types import GenericAlias
|
|
8
|
+
from typing import Any, Generic, Self, TypeVarTuple, Unpack, cast, get_args, get_origin
|
|
9
|
+
|
|
10
|
+
Ts = TypeVarTuple("Ts", default=Unpack[tuple[Any, ...]])


class RuntimeGeneric(Generic[Unpack[Ts]]):
    """
    Generic base whose subscriptions remain callable with their type arguments.

    ``Cls[args]`` returns a ``_RuntimeGenericAlias``. Calling that alias invokes
    ``Cls.__pre_new__(alias, *args, **kwargs)``, so the class can see its type
    arguments at construction time. The default ``__pre_new__`` ignores the
    alias and simply calls ``cls(*args, **kwargs)``.
    """

    @classmethod
    def __class_getitem__(cls, item: Any, /) -> GenericAlias:
        # [HACK] Misuses __class_getitem__
        # See https://docs.python.org/3/reference/datamodel.html#the-purpose-of-class-getitem

        try:
            ga = cast(GenericAlias, super().__class_getitem__(item))  # type: ignore[misc]
        except Exception:
            # Fallback if the superclass does not implement `__class_getitem__`
            # or rejects `item`. `Exception` (not a bare except) so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            ga = GenericAlias(cls, item)
        return _RuntimeGenericAlias.from_generic_alias(ga)

    @classmethod
    def __pre_new__(cls, alias: GenericAlias, *args: Any, **kwargs: Any) -> Self:
        """Construction hook receiving the subscripted alias; default ignores it."""
        return cls(*args, **kwargs)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class _RuntimeGenericAlias(GenericAlias):
|
|
32
|
+
"""
|
|
33
|
+
Deferred RuntimeGeneric constructor.
|
|
34
|
+
Enables progressive type specialisation, behaving like a type-level curry.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
@classmethod
|
|
38
|
+
def from_generic_alias(cls, alias: GenericAlias) -> Self:
|
|
39
|
+
origin = get_origin(alias)
|
|
40
|
+
typeargs = get_args(alias)
|
|
41
|
+
return cls(origin, typeargs)
|
|
42
|
+
|
|
43
|
+
def __getitem__(self, typeargs: Any) -> Self:
|
|
44
|
+
ga = super().__getitem__(typeargs)
|
|
45
|
+
return type(self).from_generic_alias(ga)
|
|
46
|
+
|
|
47
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
48
|
+
origin: type[RuntimeGeneric] = get_origin(self)
|
|
49
|
+
obj = origin.__pre_new__(self, *args, **kwargs)
|
|
50
|
+
return obj
|