typingkit 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typingkit-0.2.2/PKG-INFO +92 -0
- typingkit-0.2.2/README.md +82 -0
- typingkit-0.2.2/pyproject.toml +27 -0
- typingkit-0.2.2/src/typingkit/__init__.py +11 -0
- typingkit-0.2.2/src/typingkit/_typed/__init__.py +11 -0
- typingkit-0.2.2/src/typingkit/_typed/_debug.py +37 -0
- typingkit-0.2.2/src/typingkit/_typed/context.py +203 -0
- typingkit-0.2.2/src/typingkit/_typed/dimexpr.py +154 -0
- typingkit-0.2.2/src/typingkit/_typed/factory.py +71 -0
- typingkit-0.2.2/src/typingkit/_typed/generics.py +50 -0
- typingkit-0.2.2/src/typingkit/_typed/helpers.py +248 -0
- typingkit-0.2.2/src/typingkit/_typed/list.py +206 -0
- typingkit-0.2.2/src/typingkit/_typed/ndarray.py +513 -0
- typingkit-0.2.2/src/typingkit/py.typed +0 -0
typingkit-0.2.2/PKG-INFO
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: typingkit
|
|
3
|
+
Version: 0.2.2
|
|
4
|
+
Summary: Python strong typing suite, along with Typed NumPy: Static shape typing and runtime shape validation.
|
|
5
|
+
Author: Ashrith Sagar
|
|
6
|
+
Author-email: Ashrith Sagar <ashrith9sagar@gmail.com>
|
|
7
|
+
Requires-Dist: numpy>=2.2
|
|
8
|
+
Requires-Python: >=3.13
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
|
|
11
|
+
# typingkit
|
|
12
|
+
|
|
13
|
+
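[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)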
|
|
14
|
+
|
|
15
|
+
Python strong typing suite, along with Typed NumPy: Static shape typing and runtime shape validation.
|
|
16
|
+
|
|
17
|
+
> [!WARNING]
|
|
18
|
+
> Experimental & WIP.
|
|
19
|
+
> See [USAGE.md](USAGE.md) for more details.
|
|
20
|
+
|
|
21
|
+
## Installation
|
|
22
|
+
|
|
23
|
+
<details>
|
|
24
|
+
|
|
25
|
+
<summary>Install uv (optional, recommended)</summary>
|
|
26
|
+
|
|
27
|
+
Install [`uv`](https://docs.astral.sh/uv/) if it is not already installed.
|
|
28
|
+
Check [here](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions.
|
|
29
|
+
|
|
30
|
+
It is recommended to use `uv`, as it will automatically install the dependencies in a virtual environment.
|
|
31
|
+
If you don't want to use `uv`, skip to the next step.
|
|
32
|
+
|
|
33
|
+
**TL;DR: Just run**
|
|
34
|
+
|
|
35
|
+
```shell
|
|
36
|
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
</details>
|
|
40
|
+
|
|
41
|
+
<details>
|
|
42
|
+
|
|
43
|
+
<summary>Install the package</summary>
|
|
44
|
+
|
|
45
|
+
The dependencies are listed in the [pyproject.toml](pyproject.toml) file.
|
|
46
|
+
At present, the only required dependency is `numpy`.
|
|
47
|
+
|
|
48
|
+
Install the package from the PyPI release:
|
|
49
|
+
|
|
50
|
+
```shell
|
|
51
|
+
# Using uv
|
|
52
|
+
uv add typingkit
|
|
53
|
+
|
|
54
|
+
# Or with pip
|
|
55
|
+
pip3 install typingkit
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
To install from the latest commit:
|
|
59
|
+
|
|
60
|
+
```shell
|
|
61
|
+
uv add git+https://github.com/AshrithSagar/typingkit.git@main
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
</details>
|
|
65
|
+
|
|
66
|
+
## Usage
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
from typing import TypeVar
|
|
70
|
+
|
|
71
|
+
from typingkit._typed.ndarray import TypedNDArray
|
|
72
|
+
|
|
73
|
+
# Shape variables are just regular TypeVars
|
|
74
|
+
N = TypeVar("N", bound=int, default=int)
|
|
75
|
+
M = TypeVar("M", bound=int, default=int)
|
|
76
|
+
|
|
77
|
+
# Create aliases such as these, or use TypedNDArray directly
|
|
78
|
+
Vector = TypedNDArray[tuple[N]]
|
|
79
|
+
Matrix = TypedNDArray[tuple[M, N]]
|
|
80
|
+
|
|
81
|
+
v1 = Vector([1, 2, 3]) # Passes
|
|
82
|
+
v2 = Vector([4, 5, 6, 7]) # Also passes
|
|
83
|
+
|
|
84
|
+
v3 = TypedNDArray[tuple[int]]([[8, 9]])
|
|
85
|
+
# Fails: a 1D array is expected, but a 2D array was passed
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
See [USAGE.md](USAGE.md) for more details.
|
|
89
|
+
|
|
90
|
+
## License
|
|
91
|
+
|
|
92
|
+
This project is licensed under the [MIT License](LICENSE).
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# typingkit
|
|
2
|
+
|
|
3
|
+
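[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)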
|
|
4
|
+
|
|
5
|
+
Python strong typing suite, along with Typed NumPy: Static shape typing and runtime shape validation.
|
|
6
|
+
|
|
7
|
+
> [!WARNING]
|
|
8
|
+
> Experimental & WIP.
|
|
9
|
+
> See [USAGE.md](USAGE.md) for more details.
|
|
10
|
+
|
|
11
|
+
## Installation
|
|
12
|
+
|
|
13
|
+
<details>
|
|
14
|
+
|
|
15
|
+
<summary>Install uv (optional, recommended)</summary>
|
|
16
|
+
|
|
17
|
+
Install [`uv`](https://docs.astral.sh/uv/) if it is not already installed.
|
|
18
|
+
Check [here](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions.
|
|
19
|
+
|
|
20
|
+
It is recommended to use `uv`, as it will automatically install the dependencies in a virtual environment.
|
|
21
|
+
If you don't want to use `uv`, skip to the next step.
|
|
22
|
+
|
|
23
|
+
**TL;DR: Just run**
|
|
24
|
+
|
|
25
|
+
```shell
|
|
26
|
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
</details>
|
|
30
|
+
|
|
31
|
+
<details>
|
|
32
|
+
|
|
33
|
+
<summary>Install the package</summary>
|
|
34
|
+
|
|
35
|
+
The dependencies are listed in the [pyproject.toml](pyproject.toml) file.
|
|
36
|
+
At present, the only required dependency is `numpy`.
|
|
37
|
+
|
|
38
|
+
Install the package from the PyPI release:
|
|
39
|
+
|
|
40
|
+
```shell
|
|
41
|
+
# Using uv
|
|
42
|
+
uv add typingkit
|
|
43
|
+
|
|
44
|
+
# Or with pip
|
|
45
|
+
pip3 install typingkit
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
To install from the latest commit:
|
|
49
|
+
|
|
50
|
+
```shell
|
|
51
|
+
uv add git+https://github.com/AshrithSagar/typingkit.git@main
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
</details>
|
|
55
|
+
|
|
56
|
+
## Usage
|
|
57
|
+
|
|
58
|
+
```python
|
|
59
|
+
from typing import TypeVar
|
|
60
|
+
|
|
61
|
+
from typingkit._typed.ndarray import TypedNDArray
|
|
62
|
+
|
|
63
|
+
# Shape variables are just regular TypeVars
|
|
64
|
+
N = TypeVar("N", bound=int, default=int)
|
|
65
|
+
M = TypeVar("M", bound=int, default=int)
|
|
66
|
+
|
|
67
|
+
# Create aliases such as these, or use TypedNDArray directly
|
|
68
|
+
Vector = TypedNDArray[tuple[N]]
|
|
69
|
+
Matrix = TypedNDArray[tuple[M, N]]
|
|
70
|
+
|
|
71
|
+
v1 = Vector([1, 2, 3]) # Passes
|
|
72
|
+
v2 = Vector([4, 5, 6, 7]) # Also passes
|
|
73
|
+
|
|
74
|
+
v3 = TypedNDArray[tuple[int]]([[8, 9]])
|
|
75
|
+
# Fails: a 1D array is expected, but a 2D array was passed
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
See [USAGE.md](USAGE.md) for more details.
|
|
79
|
+
|
|
80
|
+
## License
|
|
81
|
+
|
|
82
|
+
This project is licensed under the [MIT License](LICENSE).
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "typingkit"
|
|
3
|
+
version = "0.2.2"
|
|
4
|
+
description = "Python strong typing suite, along with Typed NumPy: Static shape typing and runtime shape validation."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [{ name = "Ashrith Sagar", email = "ashrith9sagar@gmail.com" }]
|
|
7
|
+
requires-python = ">=3.13"
|
|
8
|
+
dependencies = ["numpy>=2.2"]
|
|
9
|
+
|
|
10
|
+
[build-system]
|
|
11
|
+
requires = ["uv_build>=0.10.7"]
|
|
12
|
+
build-backend = "uv_build"
|
|
13
|
+
|
|
14
|
+
[dependency-groups]
|
|
15
|
+
dev = ["rich>=14.3.3"]
|
|
16
|
+
test = ["pytest>=9.0.2", "pytest-memray>=1.8.0"]
|
|
17
|
+
lint = ["ruff>=0.15.4"]
|
|
18
|
+
typecheck = [
|
|
19
|
+
"basedpyright>=1.38.2",
|
|
20
|
+
"mypy[mypyc]>=1.19.1",
|
|
21
|
+
"pyrefly>=0.55.0",
|
|
22
|
+
"pyright>=1.1.408",
|
|
23
|
+
"ty>=0.0.20",
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
[tool.uv]
|
|
27
|
+
default-groups = "all"
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Debugging utilities
|
|
3
|
+
=======
|
|
4
|
+
"""
|
|
5
|
+
# src/typingkit/_typed/_debug.py
|
|
6
|
+
|
|
7
|
+
from typing import Any, TypeVar, get_args, get_origin
|
|
8
|
+
|
|
9
|
+
from rich.tree import Tree
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def diagnostic(obj: Any, pfx: str | None = None) -> Tree:
    """Build a ``rich`` tree describing *obj* and, recursively, its typing parts.

    - Tuples get one child node per element.
    - TypeVars get children for ``__bound__``, ``__constraints__``,
      ``__default__``, ``__covariant__`` and ``__contravariant__``.
    - Anything with a non-``None`` ``get_origin`` (GenericAlias, Literal,
      UnionType, ...) gets an ``origin:`` child and an ``args:`` child.
    - Everything else is a leaf.

    :param obj: The object to inspect (typically a type or typing construct).
    :param pfx: Optional label rendered in bold cyan before the node text.
    :return: The ``rich.tree.Tree`` node describing *obj*.
    """
    label = f"[bold cyan]{pfx}[/] " if pfx else ""
    node = Tree(
        label
        + f"[yellow]obj[/]=[green]{obj!r}[/], "
        + f"[yellow]type[/]=[magenta]{type(obj).__name__}[/]"
    )

    if isinstance(obj, tuple):
        for item in obj:  # type: ignore
            node.add(diagnostic(item))
    elif isinstance(obj, TypeVar):
        # Same order as the TypeVar constructor's keyword arguments.
        for attr in (
            "__bound__",
            "__constraints__",
            "__default__",
            "__covariant__",
            "__contravariant__",
        ):
            node.add(diagnostic(getattr(obj, attr), f"{attr}:"))
    else:
        # GenericAlias | Literal | UnionType
        origin = get_origin(obj)
        if origin is not None:
            node.add(diagnostic(origin, "origin:"))
            node.add(diagnostic(get_args(obj), "args:"))
    return node
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Context binding
|
|
3
|
+
=======
|
|
4
|
+
Manages TypeVar binding contexts for shape validation.
|
|
5
|
+
"""
|
|
6
|
+
# src/typingkit/_typed/context.py
|
|
7
|
+
|
|
8
|
+
# pyright: reportPrivateUsage = false
|
|
9
|
+
|
|
10
|
+
import inspect
|
|
11
|
+
from contextvars import ContextVar
|
|
12
|
+
from functools import wraps
|
|
13
|
+
from typing import (
|
|
14
|
+
Any,
|
|
15
|
+
Callable,
|
|
16
|
+
Concatenate,
|
|
17
|
+
ParamSpec,
|
|
18
|
+
TypeVar,
|
|
19
|
+
get_args,
|
|
20
|
+
get_origin,
|
|
21
|
+
get_type_hints,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from typingkit._typed.ndarray import (
|
|
25
|
+
DimensionError,
|
|
26
|
+
_validate_shape,
|
|
27
|
+
_validate_shape_against_contexts,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Context variables
|
|
31
|
+
|
|
32
|
+
_class_typevar_context = ContextVar[dict[int, dict[TypeVar, int]]](
|
|
33
|
+
"_class_typevar_context", default=dict[int, dict[TypeVar, int]]()
|
|
34
|
+
)
|
|
35
|
+
_method_typevar_context = ContextVar[dict[TypeVar, int]](
|
|
36
|
+
"_method_typevar_context", default=dict[TypeVar, int]()
|
|
37
|
+
)
|
|
38
|
+
_active_class_context = ContextVar[dict[TypeVar, int]](
|
|
39
|
+
"_active_class_context", default=dict[TypeVar, int]()
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
T = TypeVar("T")
|
|
43
|
+
P = ParamSpec("P") # ParamSpec for function parameters
|
|
44
|
+
R = TypeVar("R") # TypeVar for return type
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _extract_shape_dims(annotation: Any) -> tuple[Any, ...] | None:
|
|
48
|
+
"""Return shape dimension spec tuple[...] from TypedNDArray annotation."""
|
|
49
|
+
|
|
50
|
+
origin = get_origin(annotation)
|
|
51
|
+
if origin is None:
|
|
52
|
+
return None
|
|
53
|
+
|
|
54
|
+
args = get_args(annotation)
|
|
55
|
+
if not args:
|
|
56
|
+
return None
|
|
57
|
+
|
|
58
|
+
shape_spec = args[0]
|
|
59
|
+
if get_origin(shape_spec) is not tuple:
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
return get_args(shape_spec)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _validate_and_bind(
    *,
    shape_dims: tuple[Any, ...],
    actual_shape: tuple[int, ...],
    owner_cls: type,
    func_name: str,
    param_name: str | None,
    class_context: dict[TypeVar, int],
    method_context: dict[TypeVar, int],
) -> None:
    """
    Validate one value's shape against its spec and bind the spec's TypeVars.

    Performs:
    1. Structural validation (``_validate_shape``: Literal, int, rank, repeated TypeVar)
    2. TypeVar binding + consistency checks: the first time a TypeVar is seen its
       observed size is recorded in the class- or method-level context (as decided
       by ``_is_class_level_typevar``); afterwards the size must match that record.

    :param shape_dims: Per-axis dimension specs, e.g. ``(N, Literal[3])``.
    :param actual_shape: The observed ``.shape`` of the value.
    :param owner_cls: Class of the receiver, used to classify TypeVars.
    :param func_name: Name of the decorated function (for error messages).
    :param param_name: Parameter being checked, or ``None`` for the return value.
    :param class_context: Class-level TypeVar bindings; updated in place.
    :param method_context: Method-level TypeVar bindings; updated in place.
    :raises DimensionError: When a TypeVar's size disagrees with its earlier
        binding (``_validate_shape`` may also raise for structural mismatches).
    """
    _validate_shape(shape_dims, actual_shape)

    # zip() stops at the shorter sequence, so spec axes beyond the actual rank
    # are never looked at.
    for axis, (spec, size) in enumerate(zip(shape_dims, actual_shape)):
        if isinstance(spec, TypeVar):
            class_level = _is_class_level_typevar(spec, owner_cls)
            bindings = class_context if class_level else method_context
            bound = bindings.setdefault(spec, size)  # records the size on first sight
            if size != bound:
                scope = "class" if class_level else "method"
                where = (
                    "return value"
                    if param_name is None
                    else f"parameter `{param_name}`"
                )
                raise DimensionError(
                    f"In {func_name}(...), {where} "
                    f"dimension {axis} [{spec}] "
                    f"expected {bound} ({scope}-level binding), "
                    f"got {size}"
                )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _is_class_level_typevar(typevar: TypeVar, owner_cls: type) -> bool:
    """Tell whether *typevar* is class-level (as opposed to method-level) for *owner_cls*.

    A TypeVar is class-level when it is listed in ``owner_cls.__parameters__``
    (as populated for ``Generic[...]`` subclasses). A class without a
    ``__parameters__`` attribute has no class-level TypeVars.
    """
    try:
        params = getattr(owner_cls, "__parameters__")
    except AttributeError:
        return False
    return typevar in params
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# id(instance) -> weak reference to the object whose class-level bindings are
# stored under that id in ``_class_typevar_context`` (``None`` when the object
# cannot be weakly referenced). Python may hand a new object the id() of one that
# has already been garbage-collected; this table lets us detect such stale
# entries. It is shared across contexts, and entries are overwritten on id reuse
# rather than removed.
_instance_refs: dict[int, Any] = {}


def _get_instance_class_context(instance: Any) -> dict[TypeVar, int]:
    """Get or create the class-level TypeVar binding context for an instance.

    Bindings live in ``_class_typevar_context`` keyed by ``id(instance)``. An
    existing entry is reused only if it was recorded for this very object. If the
    entry was left behind by a collected object that had the same ``id()``, it is
    replaced by a fresh, empty mapping, so its stale bindings do not leak into the
    new instance.

    The outer mapping is copied before a new entry is added, so the ContextVar's
    shared default dict is never mutated.

    :param instance: The receiver (``self``) of a decorated method.
    :return: The mutable TypeVar -> dimension mapping for *instance*.
    """
    ctx = _class_typevar_context.get()
    instance_id = id(instance)

    if instance_id in ctx:
        ref = _instance_refs.get(instance_id)
        # ``None``: no weak reference available -> fall back to trusting the id.
        if ref is None or ref() is instance:
            return ctx[instance_id]

    import weakref

    ctx = ctx.copy()
    ctx[instance_id] = {}
    _class_typevar_context.set(ctx)
    try:
        _instance_refs[instance_id] = weakref.ref(instance)
    except TypeError:
        # Not weak-referenceable (e.g. __slots__ without __weakref__):
        # identification falls back to id() alone.
        _instance_refs[instance_id] = None
    return ctx[instance_id]
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def enforce_shapes(
    func: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R]:
    """
    Decorator to automatically validate TypeVar shape bindings.

    - Class-level TypeVars (from Generic[T]) are bound per-instance, persist across calls
    - Method-level TypeVars are validated per-call only, local to single invocation
    - Validates both parameter and return types

    Only values that have a ``.shape`` attribute and whose annotation carries a
    ``tuple[...]`` shape spec are checked; the argument named ``self`` is skipped.

    The type hints and signature of *func* are resolved on the first call and
    then reused. They are resolved lazily, not at decoration time, so that the
    annotations may still contain forward references at that point.

    :raises DimensionError: On a shape mismatch detected by the validators.
    """
    # Populated on the first successful call of _introspect().
    introspection: dict[str, Any] = {}

    def _introspect() -> tuple[dict[str, Any], inspect.Signature]:
        """Return ``(type hints, signature)`` of *func*, computing them only once."""
        if not introspection:
            hints = get_type_hints(func, include_extras=True)
            sig = inspect.signature(func)
            # Stored together, so a failure above leaves the cache empty and
            # the next call retries (and raises) exactly as before.
            introspection.update(hints=hints, sig=sig)
        return introspection["hints"], introspection["sig"]

    @wraps(func)
    def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
        hints, sig = _introspect()
        bound_args = sig.bind(self, *args, **kwargs)
        bound_args.apply_defaults()

        owner_cls = self.__class__
        class_context = _get_instance_class_context(self)
        method_context = dict[TypeVar, int]()

        # Validate arguments
        for param_name, param_value in bound_args.arguments.items():
            if param_name == "self":
                continue
            if param_name not in hints:
                continue
            if not hasattr(param_value, "shape"):
                continue

            shape_dims = _extract_shape_dims(hints[param_name])
            if shape_dims is None:
                continue

            _validate_and_bind(
                shape_dims=shape_dims,
                actual_shape=param_value.shape,
                owner_cls=owner_cls,
                func_name=func.__name__,
                param_name=param_name,
                class_context=class_context,
                method_context=method_context,
            )

        # Execute function with active contexts, so nested code can see the
        # bindings established by the arguments.
        method_token = _method_typevar_context.set(method_context)
        class_token = _active_class_context.set(class_context)
        try:
            result = func(self, *args, **kwargs)
        finally:
            _method_typevar_context.reset(method_token)
            _active_class_context.reset(class_token)

        # Validate return
        if "return" in hints and result is not None and hasattr(result, "shape"):
            shape_dims = _extract_shape_dims(hints["return"])
            actual_shape = getattr(result, "shape")
            if shape_dims is not None:
                _validate_and_bind(
                    shape_dims=shape_dims,
                    actual_shape=actual_shape,
                    owner_cls=owner_cls,
                    func_name=func.__name__,
                    param_name=None,
                    class_context=class_context,
                    method_context=method_context,
                )

                # context-aware validation (cross-call)
                _validate_shape_against_contexts(shape_dims, actual_shape)

        return result

    return wrapper
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DimExpr
|
|
3
|
+
=======
|
|
4
|
+
"""
|
|
5
|
+
# src/typingkit/_typed/dimexpr.py
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Generic,
|
|
11
|
+
Literal,
|
|
12
|
+
NoReturn,
|
|
13
|
+
TypeAlias,
|
|
14
|
+
TypeAliasType,
|
|
15
|
+
TypeVar,
|
|
16
|
+
get_args,
|
|
17
|
+
get_origin,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
Arg1 = TypeVar("Arg1", bound=int)
|
|
21
|
+
Arg2 = TypeVar("Arg2", bound=int)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
## Base
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DimExpr(int):
|
|
28
|
+
def __new__(cls, *args: Any) -> NoReturn:
|
|
29
|
+
raise TypeError("Shape expressions cannot be instantiated")
|
|
30
|
+
|
|
31
|
+
@classmethod
|
|
32
|
+
def expr(cls, args: tuple[Any, ...], /) -> int:
|
|
33
|
+
raise NotImplementedError
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class UnaryOp(Generic[Arg1], DimExpr):
|
|
37
|
+
@classmethod
|
|
38
|
+
def expr(cls, args: tuple[Arg1], /) -> int:
|
|
39
|
+
raise NotImplementedError
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class BinaryOp(Generic[Arg1, Arg2], DimExpr):
|
|
43
|
+
@classmethod
|
|
44
|
+
def expr(cls, args: tuple[Arg1, Arg2], /) -> int:
|
|
45
|
+
raise NotImplementedError
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
## Core operations
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Neg(UnaryOp[Arg1]):
|
|
52
|
+
@classmethod
|
|
53
|
+
def expr(cls, args: tuple[Arg1], /) -> int:
|
|
54
|
+
return -args[0]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Add(BinaryOp[Arg1, Arg2]):
|
|
58
|
+
@classmethod
|
|
59
|
+
def expr(cls, args: tuple[Arg1, Arg2], /) -> int:
|
|
60
|
+
return args[0] + args[1]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class Mul(BinaryOp[Arg1, Arg2]):
|
|
64
|
+
@classmethod
|
|
65
|
+
def expr(cls, args: tuple[Arg1, Arg2], /) -> int:
|
|
66
|
+
return args[0] * args[1]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class Pow(BinaryOp[Arg1, Arg2]):
|
|
70
|
+
@classmethod
|
|
71
|
+
def expr(cls, args: tuple[Arg1, Arg2], /) -> int:
|
|
72
|
+
return args[0] ** args[1]
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
## Additional operations
|
|
76
|
+
|
|
77
|
+
# Through TypeAliases
|
|
78
|
+
Sub: TypeAlias = Add[Arg1, Neg[Arg2]]
|
|
79
|
+
PlusOne: TypeAlias = Add[Arg1, Literal[1]]
|
|
80
|
+
|
|
81
|
+
# Through PEP-695 style TypeAliasTypes
|
|
82
|
+
type MinusOne[Arg1: int] = Sub[Arg1, Literal[1]]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Through Subclassing
|
|
86
|
+
class Squared(Mul[Arg1, Arg1]): ...
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Custom operations
|
|
90
|
+
# Can also directly define with a `expr`, with some custom logic, say
|
|
91
|
+
class Cubed(UnaryOp[Arg1]):
|
|
92
|
+
@classmethod
|
|
93
|
+
def expr(cls, args: tuple[Arg1], /) -> int:
|
|
94
|
+
return args[0] ** 3
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class Log(BinaryOp[Arg1, Arg2]):
|
|
98
|
+
@classmethod
|
|
99
|
+
def expr(cls, args: tuple[Arg1, Arg2], /) -> int:
|
|
100
|
+
x = math.log(args[0], args[1])
|
|
101
|
+
xi = round(x)
|
|
102
|
+
if math.isclose(x, xi, rel_tol=0, abs_tol=1e-12):
|
|
103
|
+
return xi
|
|
104
|
+
raise TypeError("Invalid dimension. Not an integer.")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
Log2 = Log[Arg1, Literal[2]]
|
|
108
|
+
Log10 = Log[Arg1, Literal[10]]
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
## Evaluation
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _resolve_dim(tp: Any) -> Any:
|
|
115
|
+
# type[Any] / type[int] / EllipsisType / TypeVar
|
|
116
|
+
if tp is Any or tp is int or tp is Ellipsis or isinstance(tp, TypeVar):
|
|
117
|
+
return tp
|
|
118
|
+
|
|
119
|
+
origin = get_origin(tp)
|
|
120
|
+
|
|
121
|
+
# Literal[N]
|
|
122
|
+
if origin is Literal:
|
|
123
|
+
# [TODO] How should we handle multiple args case?
|
|
124
|
+
return get_args(tp)[0]
|
|
125
|
+
|
|
126
|
+
# TypeAliasType
|
|
127
|
+
if isinstance(origin, TypeAliasType):
|
|
128
|
+
return _resolve_dim(origin.__value__[get_args(tp)])
|
|
129
|
+
|
|
130
|
+
# DimExpr
|
|
131
|
+
if origin and issubclass(origin, DimExpr):
|
|
132
|
+
args = get_args(tp)
|
|
133
|
+
|
|
134
|
+
# Case 1: class defines its own expr
|
|
135
|
+
if "expr" in origin.__dict__:
|
|
136
|
+
values = tuple(_resolve_dim(arg) for arg in args)
|
|
137
|
+
if not all(isinstance(v, int) for v in values):
|
|
138
|
+
return origin[values] # pyright: ignore[reportInvalidTypeArguments]
|
|
139
|
+
return origin.expr(values)
|
|
140
|
+
|
|
141
|
+
# Case 2: subclassing
|
|
142
|
+
if hasattr(origin, "__orig_bases__"):
|
|
143
|
+
for base in getattr(origin, "__orig_bases__"):
|
|
144
|
+
base_origin = get_origin(base)
|
|
145
|
+
if base_origin and issubclass(base_origin, DimExpr):
|
|
146
|
+
base_args = get_args(base)
|
|
147
|
+
|
|
148
|
+
parameters = getattr(origin, "__parameters__", ())
|
|
149
|
+
param_map = dict(zip(parameters, args))
|
|
150
|
+
substituted = tuple(param_map.get(a, a) for a in base_args)
|
|
151
|
+
|
|
152
|
+
return _resolve_dim(base_origin[substituted]) # pyright: ignore[reportInvalidTypeArguments]
|
|
153
|
+
|
|
154
|
+
raise TypeError(f"Cannot evaluate {tp}")
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shape factory for TypedNDArray
|
|
3
|
+
=======
|
|
4
|
+
"""
|
|
5
|
+
# src/typingkit/_typed/factory.py
|
|
6
|
+
|
|
7
|
+
# pyright: reportPrivateUsage = false
|
|
8
|
+
|
|
9
|
+
from types import GenericAlias
|
|
10
|
+
from typing import Any, Generic, TypeVar
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
|
|
15
|
+
from typingkit._typed.ndarray import TypedNDArray, _AnyShape
|
|
16
|
+
|
|
17
|
+
_ShapeT = TypeVar("_ShapeT", bound=_AnyShape)
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
_NewScalarT = TypeVar("_NewScalarT", bound=np.generic)


class _ShapeBuilder(Generic[_ShapeT, _ScalarT]):
    """Callable that constructs a ``TypedNDArray`` for a fixed shape spec and dtype.

    ``dtype(...)`` and the ``float64``/``float32``/``int64`` properties return a
    new builder with the same shape spec and a different scalar type; calling
    the builder converts its argument into the typed array.
    """

    __slots__ = ("_shape_spec", "_dtype_spec")

    def __init__(self, shape_spec: GenericAlias, dtype_spec: GenericAlias):
        # Shape spec, e.g. tuple[N, M]
        self._shape_spec = shape_spec
        # np.dtype[<scalar>]; a scalar of `Any` means "let numpy decide"
        self._dtype_spec = dtype_spec

    def dtype(self, dtype: type[_NewScalarT]) -> "_ShapeBuilder[_ShapeT, _NewScalarT]":
        """Return a builder with the same shape spec and scalar type *dtype*."""
        new_dtype_spec = GenericAlias(np.dtype, (dtype,))
        return _ShapeBuilder(self._shape_spec, new_dtype_spec)

    def __call__(
        self, object: npt.ArrayLike
    ) -> TypedNDArray[_ShapeT, np.dtype[_ScalarT]]:
        """Build ``TypedNDArray[<shape spec>]`` from *object*.

        The scalar type is passed as the second positional argument; an
        unspecified (``Any``) scalar is passed as ``None``.
        """
        scalar = self._dtype_spec.__args__[0]
        resolved = scalar if scalar is not Any else None
        return TypedNDArray[self._shape_spec](object, resolved)  # type: ignore

    def __repr__(self) -> str:
        dims = self._shape_spec.__args__
        scalar = self._dtype_spec.__args__[0]
        return f"TypedNDArrayFactory[{dims}, {scalar}]"

    ## Convenience helpers

    @property
    def float64(self) -> "_ShapeBuilder[_ShapeT, np.float64]":
        """Shorthand for ``.dtype(np.float64)``."""
        return self.dtype(np.float64)

    @property
    def float32(self) -> "_ShapeBuilder[_ShapeT, np.float32]":
        """Shorthand for ``.dtype(np.float32)``."""
        return self.dtype(np.float32)

    @property
    def int64(self) -> "_ShapeBuilder[_ShapeT, np.int64]":
        """Shorthand for ``.dtype(np.int64)``."""
        return self.dtype(np.int64)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class _ShapeFactory:
    """Entry point for shape-typed construction, e.g. ``Shaped[N, M](data)``.

    Subscripting with dimension specs yields a ``_ShapeBuilder`` whose shape spec
    is ``tuple[<dims>]`` and whose scalar type is left unspecified (``Any``).
    """

    def __getitem__(self, dims: _ShapeT) -> _ShapeBuilder[_ShapeT, Any]:
        # A single subscript (Shaped[N]) arrives as a bare object, not a 1-tuple.
        dim_tuple = dims if isinstance(dims, tuple) else (dims,)  # pyright: ignore[reportUnnecessaryIsInstance]
        return _ShapeBuilder(
            GenericAlias(tuple, dim_tuple),
            GenericAlias(np.dtype, (Any,)),
        )

    # Prefer [...] rather than (...), but this is also provided
    def __call__(self, dims: _ShapeT) -> _ShapeBuilder[_ShapeT, Any]:
        return self.__getitem__(dims)


# Module-level factory instance
Shaped = _ShapeFactory()
|