torch-einops-kit 0.1.0__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/PKG-INFO +40 -17
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/README.md +24 -15
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/pyproject.toml +24 -2
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/__init__.py +1 -5
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_types.py +2 -5
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/save_load.py +12 -21
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/scaleValues.py +47 -4
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_cat_and_stack.py +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_dimensions.py +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_helpers.py +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_masking.py +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_padding.py +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_semiotics.py +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_slicing.py +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/device.py +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/einops.py +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/py.typed +0 -0
- {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/utils.py +0 -0
--- torch_einops_kit-0.1.0/PKG-INFO
+++ torch_einops_kit-0.1.2/PKG-INFO
@@ -1,18 +1,32 @@
 Metadata-Version: 2.4
 Name: torch-einops-kit
-Version: 0.1.0
+Version: 0.1.2
 Summary: Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`. A superset of `lucidrains/torch-einops-utils` with similar utilities from other lucidrains packages. Adds strict typing, extensive tests, and comprehensive docstrings.
+Keywords: artificial intelligence,einops,machine learning,pytorch,torch
 Author: Phil Wang, Hunter Hogan
 Author-email: Phil Wang <lucidrains@gmail.com>, Hunter Hogan <HunterHogan@pm.me>
 License-Expression: CC-BY-NC-4.0
+Classifier: Development Status :: 4 - Beta
+Classifier: Environment :: Console
+Classifier: Environment :: GPU :: NVIDIA CUDA
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: Natural Language :: English
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.14
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Topic :: Utilities
+Classifier: Typing :: Typed
 Requires-Dist: einops>=0.8.2
-Requires-Dist: packaging>=26.0
 Requires-Dist: torch>=2.10.0
 Requires-Dist: typing-extensions>=4.0.0
 Maintainer: Hunter Hogan
 Maintainer-email: Hunter Hogan <HunterHogan@pm.me>
 Requires-Python: >=3.10
+Project-URL: Documentation, https://context7.com/hunterhogan/torch_einops_kit
 Project-URL: Donate, https://www.patreon.com/integrated
+Project-URL: Download, https://pypi.org/project/torch_einops_kit
 Project-URL: Homepage, https://github.com/hunterhogan/torch_einops_kit
 Project-URL: Issues, https://github.com/hunterhogan/torch_einops_kit/issues
 Project-URL: Repository, https://github.com/hunterhogan/torch_einops_kit.git
@@ -22,8 +36,8 @@ Description-Content-Type: text/markdown

 Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`.

-[](https://pypi.org/project/torch-einops-kit/)
+[](https://pypi.org/project/torch-einops-kit/)

 This repository is a superset of [`lucidrains/torch-einops-utils`](https://github.com/lucidrains/torch-einops-utils). The upstream repository is a compact collection of small utilities that show up repeatedly in lucidrains model repositories. `torch_einops_kit` keeps that role. The main difference is emphasis. This fork adds roughly 6000 lines of tests, typing, and docstrings so the utility layer is easier to trust, easier to search, and easier to apply correctly.
@@ -38,7 +52,7 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused
 - Project name: `torch_einops_kit`.
 - Import path: `torch_einops_kit`.
 - Python requirement: `>=3.10`.
-- Runtime dependencies: `torch`, `einops`,
+- Runtime dependencies: `torch`, `einops`, and `typing-extensions`.
 - Root package exports: helper functions, slicing helpers, rank-alignment helpers, mask helpers, safe concatenation helpers, padding helpers, normalization helpers, and PyTree / `einops` helpers.
 - Submodules with dedicated imports: `torch_einops_kit.device`, `torch_einops_kit.einops`, `torch_einops_kit.save_load`, and `torch_einops_kit.scaleValues`.
 - Typing status: the package ships a `py.typed` marker and the repository uses strict type checking.
@@ -46,10 +60,16 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused

 ## Installation

-
+### `uv`

 ```bash
-uv add
+uv add torch_einops_kit
+```
+
+### `pip`
+
+```bash
+pip install torch_einops_kit
 ```

 ## Import map
@@ -61,9 +81,7 @@ from torch_einops_kit import (
 align_dims_left,
 and_masks,
 broadcast_cat,
-l2norm,
 lens_to_mask,
-masked_mean,
 maybe,
 once,
 or_masks,
@@ -108,6 +126,17 @@ from torch_einops_kit.save_load import (
 )
 ```

+Import checkpoint decorators from `torch_einops_kit.scaleValues`:
+
+```python
+from torch_einops_kit.scaleValues import (
+exclusive_cumsum,
+l2norm,
+RMSNorm,
+masked_mean,
+)
+```
+
 ## Quick examples

 ### Batch variable-length tensors and build a mask
@@ -265,13 +294,6 @@ These functions add numeric padding values along an existing tensor dimension.

 When `pad_lens=True` and `return_lens=True`, the second tensor contains padding widths rather than original lengths.

-### Normalization and masked reduction helpers
-
-| Name | Contract |
-| ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `l2norm(t)` | Normalizes each vector in `t` to unit length along the last dimension by dividing by its L2 norm. Delegates to `torch.nn.functional.normalize` with `p=2` and `dim=-1`. |
-| `masked_mean(t, mask=None, dim=None, eps=1e-5)` | Computes a masked mean. When `mask is None`, the function falls back to `t.mean(...)`. When no masked position is selected and `dim is None`, the function returns zero by summing over the empty selection. When `mask.ndim < t.ndim`, the function right-pads mask rank before broadcasting. |
-
 ### PyTree helpers

 | Name | Contract |
@@ -281,10 +303,11 @@ When `pad_lens=True` and `return_lens=True`, the second tensor contains padding

 ## `scaleValues` submodule reference

-The `torch_einops_kit.scaleValues` submodule contains vector normalization, masked mean computation, and the `RMSNorm` layer.
+The `torch_einops_kit.scaleValues` submodule contains exclusive prefix sums, vector normalization, masked mean computation, and the `RMSNorm` layer.

 | Name | Contract |
 | ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `exclusive_cumsum(t, dim=-1)` | Computes the exclusive prefix sum of `t` along `dim`. Each output position receives the sum of all elements that strictly precede it along `dim`. The element at index zero is always zero. |
 | `l2norm(t)` | Normalizes each vector in `t` to unit length along the last dimension. Delegates to `torch.nn.functional.normalize` with `p=2` and `dim=-1`. |
 | `masked_mean(t, mask=None, dim=None, eps=1e-5)` | Computes a masked mean. When `mask is None`, the function falls back to `t.mean(...)`. When no masked position is selected and `dim is None`, the function returns zero by summing over the empty selection. When `mask.ndim < t.ndim`, the function right-pads mask rank before broadcasting. |
 | `RMSNorm(dim)` | `torch.nn.Module` that normalizes the last feature axis to unit length, multiplies by `√dim`, and applies a learned per-feature `gamma` parameter. Use as a pre-normalization layer before attention, feedforward, or linear projection sublayers in transformer-style modules. |
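The `RMSNorm(dim)` contract in the table above (scale to unit length, multiply by `√dim`, then apply a per-feature `gamma`) is algebraically the same as dividing each feature by the root mean square of the vector, because ‖x‖₂ = √dim · rms(x). The sketch below checks that identity on plain Python lists. `rms_norm_sketch` and `divide_by_rms` are hypothetical names that are not part of `torch_einops_kit`. The sketch also ignores any epsilon guard against division by zero, such as the one `torch.nn.functional.normalize` applies.

```python
import math


def rms_norm_sketch(x: list[float], gamma: list[float]) -> list[float]:
    """Hypothetical list-based sketch of the RMSNorm contract (no epsilon guard)."""
    dim = len(x)
    # Step 1: L2-normalize the feature vector to unit length.
    norm = math.sqrt(sum(v * v for v in x))
    unit = [v / norm for v in x]
    # Step 2: rescale by sqrt(dim), then apply the per-feature gain.
    return [u * math.sqrt(dim) * g for u, g in zip(unit, gamma)]


def divide_by_rms(x: list[float]) -> list[float]:
    """Reference form: divide each feature by the root mean square of the vector."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms for v in x]


x = [3.0, 4.0, 0.0, 12.0]
gamma = [1.0] * len(x)  # a gain of ones leaves the normalized vector unchanged
print(rms_norm_sketch(x, gamma))
print(divide_by_rms(x))
```

Both calls print the same values up to floating-point rounding. A zero vector raises `ZeroDivisionError` here, which is the case an epsilon guard exists to handle.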
--- torch_einops_kit-0.1.0/README.md
+++ torch_einops_kit-0.1.2/README.md
@@ -2,8 +2,8 @@

 Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`.

-[](https://pypi.org/project/torch-einops-kit/)
+[](https://pypi.org/project/torch-einops-kit/)

 This repository is a superset of [`lucidrains/torch-einops-utils`](https://github.com/lucidrains/torch-einops-utils). The upstream repository is a compact collection of small utilities that show up repeatedly in lucidrains model repositories. `torch_einops_kit` keeps that role. The main difference is emphasis. This fork adds roughly 6000 lines of tests, typing, and docstrings so the utility layer is easier to trust, easier to search, and easier to apply correctly.
@@ -18,7 +18,7 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused
 - Project name: `torch_einops_kit`.
 - Import path: `torch_einops_kit`.
 - Python requirement: `>=3.10`.
-- Runtime dependencies: `torch`, `einops`,
+- Runtime dependencies: `torch`, `einops`, and `typing-extensions`.
 - Root package exports: helper functions, slicing helpers, rank-alignment helpers, mask helpers, safe concatenation helpers, padding helpers, normalization helpers, and PyTree / `einops` helpers.
 - Submodules with dedicated imports: `torch_einops_kit.device`, `torch_einops_kit.einops`, `torch_einops_kit.save_load`, and `torch_einops_kit.scaleValues`.
 - Typing status: the package ships a `py.typed` marker and the repository uses strict type checking.
@@ -26,10 +26,16 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused

 ## Installation

-
+### `uv`

 ```bash
-uv add
+uv add torch_einops_kit
+```
+
+### `pip`
+
+```bash
+pip install torch_einops_kit
 ```

 ## Import map
@@ -41,9 +47,7 @@ from torch_einops_kit import (
 align_dims_left,
 and_masks,
 broadcast_cat,
-l2norm,
 lens_to_mask,
-masked_mean,
 maybe,
 once,
 or_masks,
@@ -88,6 +92,17 @@ from torch_einops_kit.save_load import (
 )
 ```

+Import checkpoint decorators from `torch_einops_kit.scaleValues`:
+
+```python
+from torch_einops_kit.scaleValues import (
+exclusive_cumsum,
+l2norm,
+RMSNorm,
+masked_mean,
+)
+```
+
 ## Quick examples

 ### Batch variable-length tensors and build a mask
@@ -245,13 +260,6 @@ These functions add numeric padding values along an existing tensor dimension.

 When `pad_lens=True` and `return_lens=True`, the second tensor contains padding widths rather than original lengths.

-### Normalization and masked reduction helpers
-
-| Name | Contract |
-| ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `l2norm(t)` | Normalizes each vector in `t` to unit length along the last dimension by dividing by its L2 norm. Delegates to `torch.nn.functional.normalize` with `p=2` and `dim=-1`. |
-| `masked_mean(t, mask=None, dim=None, eps=1e-5)` | Computes a masked mean. When `mask is None`, the function falls back to `t.mean(...)`. When no masked position is selected and `dim is None`, the function returns zero by summing over the empty selection. When `mask.ndim < t.ndim`, the function right-pads mask rank before broadcasting. |
-
 ### PyTree helpers

 | Name | Contract |
@@ -261,10 +269,11 @@ When `pad_lens=True` and `return_lens=True`, the second tensor contains padding

 ## `scaleValues` submodule reference

-The `torch_einops_kit.scaleValues` submodule contains vector normalization, masked mean computation, and the `RMSNorm` layer.
+The `torch_einops_kit.scaleValues` submodule contains exclusive prefix sums, vector normalization, masked mean computation, and the `RMSNorm` layer.

 | Name | Contract |
 | ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `exclusive_cumsum(t, dim=-1)` | Computes the exclusive prefix sum of `t` along `dim`. Each output position receives the sum of all elements that strictly precede it along `dim`. The element at index zero is always zero. |
 | `l2norm(t)` | Normalizes each vector in `t` to unit length along the last dimension. Delegates to `torch.nn.functional.normalize` with `p=2` and `dim=-1`. |
 | `masked_mean(t, mask=None, dim=None, eps=1e-5)` | Computes a masked mean. When `mask is None`, the function falls back to `t.mean(...)`. When no masked position is selected and `dim is None`, the function returns zero by summing over the empty selection. When `mask.ndim < t.ndim`, the function right-pads mask rank before broadcasting. |
 | `RMSNorm(dim)` | `torch.nn.Module` that normalizes the last feature axis to unit length, multiplies by `√dim`, and applies a learned per-feature `gamma` parameter. Use as a pre-normalization layer before attention, feedforward, or linear projection sublayers in transformer-style modules. |
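The `masked_mean` contract in the table above describes three cases for a reduction over all positions. The sketch below reproduces those cases, assuming flat Python lists in place of tensors and covering only `dim=None`. `masked_mean_flat` is a hypothetical helper, and it does not model `dim`, `eps`, or the rank right-padding of `mask`.

```python
from __future__ import annotations


def masked_mean_flat(values: list[float], mask: list[bool] | None = None) -> float:
    """Hypothetical flat-list sketch of the `masked_mean` contract with `dim=None`."""
    if mask is None:
        # No mask: fall back to an ordinary mean over every element.
        return sum(values) / len(values)
    selected = [v for v, keep in zip(values, mask) if keep]
    if not selected:
        # Empty selection: summing over nothing yields zero instead of dividing by zero.
        return float(sum(selected))
    # Otherwise: average only the positions the mask keeps.
    return sum(selected) / len(selected)


padding_mask = [True, True, False, False]
print(masked_mean_flat([2.0, 4.0, 100.0, 100.0], padding_mask))  # → 3.0 (padding ignored)
print(masked_mean_flat([2.0, 4.0], [False, False]))  # → 0.0 (nothing selected)
```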
--- torch_einops_kit-0.1.0/pyproject.toml
+++ torch_einops_kit-0.1.2/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "torch-einops-kit"
-version = "0.1.0"
+version = "0.1.2"
 description = "Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`. A superset of `lucidrains/torch-einops-utils` with similar utilities from other lucidrains packages. Adds strict typing, extensive tests, and comprehensive docstrings."
 readme = "README.md"
 requires-python = ">=3.10"
@@ -10,15 +10,37 @@ authors = [
 { name = "Hunter Hogan", email = "HunterHogan@pm.me" }
 ]
 maintainers = [{ name = "Hunter Hogan", email = "HunterHogan@pm.me" }]
+keywords = [
+"artificial intelligence",
+"einops",
+"machine learning",
+"pytorch",
+"torch",
+]
+classifiers = [
+"Development Status :: 4 - Beta",
+"Environment :: Console",
+"Environment :: GPU :: NVIDIA CUDA",
+"Intended Audience :: Developers",
+"Intended Audience :: Science/Research",
+"Natural Language :: English",
+"Operating System :: OS Independent",
+"Programming Language :: Python :: 3.10",
+"Programming Language :: Python :: 3.14",
+"Topic :: Scientific/Engineering :: Artificial Intelligence",
+"Topic :: Utilities",
+"Typing :: Typed",
+]
 dependencies = [
 "einops>=0.8.2",
-"packaging>=26.0",
 "torch>=2.10.0",
 "typing-extensions>=4.0.0",
 ]

 [project.urls]
+Documentation = "https://context7.com/hunterhogan/torch_einops_kit"
 Donate = "https://www.patreon.com/integrated"
+Download = "https://pypi.org/project/torch_einops_kit"
 Homepage = "https://github.com/hunterhogan/torch_einops_kit"
 Issues = "https://github.com/hunterhogan/torch_einops_kit/issues"
 Repository = "https://github.com/hunterhogan/torch_einops_kit.git"
--- torch_einops_kit-0.1.0/src/torch_einops_kit/__init__.py
+++ torch_einops_kit-0.1.2/src/torch_einops_kit/__init__.py
@@ -92,10 +92,6 @@ pad_sequence_and_cat

 Utilities
 ---------
-l2norm
-Normalize `Tensor` vectors to unit length along the last dimension.
-masked_mean
-Compute a mean over positions selected by a boolean mask.
 tree_flatten_with_inverse
 Flatten a PyTree and return an inverse reconstruction function.
 tree_map_tensor
@@ -110,7 +106,7 @@ einops
 save_load
 Decorate `torch.nn.Module` subclasses with checkpoint save, load, and reconstruction helpers.
 scaleValues
-
+Compute exclusive prefix sums, normalize feature vectors, and compute masked means.
 """
 # isort: split
 from torch_einops_kit._semiotics import decreasing as decreasing, zeroIndexed as zeroIndexed
--- torch_einops_kit-0.1.0/src/torch_einops_kit/_types.py
+++ torch_einops_kit-0.1.2/src/torch_einops_kit/_types.py
@@ -192,8 +192,8 @@ class DehydratedCheckpoint(TypedDict):
 `Module` instances with `DehydratedTorchNNModule` [9] reconstruction records.
 version : str | None
 An optional version string written at save time. When both the stored version and the
-`save_load` `version` argument are
-
+`save_load` `version` argument are different, the generated load method prints a notice but
+still restores model state.

 See Also
 --------
@@ -220,9 +220,6 @@ class DehydratedCheckpoint(TypedDict):
 [8] torch_einops_kit.save_load.dehydrate_config

 [9] torch_einops_kit.DehydratedTorchNNModule
-
-[10] packaging.version - packaging documentation
-https://packaging.pypa.io/en/stable/version.html
 """
 model: dict[str, Tensor]
 config: bytes
--- torch_einops_kit-0.1.0/src/torch_einops_kit/save_load.py
+++ torch_einops_kit-0.1.2/src/torch_einops_kit/save_load.py
@@ -20,7 +20,6 @@ from __future__ import annotations

 from collections.abc import Callable
 from functools import wraps
-from packaging import version as packaging_version
 from pathlib import Path
 from torch.nn import Module
 from torch_einops_kit import (
@@ -299,13 +298,12 @@ def save_load(
 -----
 The generated save method writes a dictionary that is compatible with `torch.save` [5], and the
 generated load paths read the dictionary with `torch.load` [6] on CPU before restoring state.
-When `version` and the stored checkpoint version
-
-stored model state.
+When `version` and the stored checkpoint version are different, the generated load method prints
+a notice but still restores the stored model state.

 Examples
 --------
-From `tests.test_save_load.test_init_and_load` [
+From `tests.test_save_load.test_init_and_load` [7]:

 ```python
 from pathlib import Path
@@ -327,7 +325,7 @@ def save_load(
 restored_model = SimpleNet.init_and_load(str(path))
 ```

-From `tests.test_save_load_extended` [
+From `tests.test_save_load_extended` [8]:

 ```python
 import torch
@@ -367,11 +365,9 @@ def save_load(
 https://pytorch.org/docs/stable/generated/torch.save.html
 [6] torch.load - PyTorch documentation
 https://pytorch.org/docs/stable/generated/torch.load.html
-[7]
-https://packaging.pypa.io/en/stable/version.html
-[8] tests.test_save_load.test_init_and_load
+[7] tests.test_save_load.test_init_and_load

-[
+[8] tests.test_save_load_extended.test_save_load_supports_custom_method_names_and_config_storage
 """
 def _save_load(klass: type[TorchNNModule]) -> type[TorchNNModule]:
 if not issubclass(klass, Module):
@@ -435,10 +431,8 @@ def save_load(
 """Restore model state from a checkpoint file.

 You can use this method to load parameter values into an already-constructed decorated
-module instance. The method reads the checkpoint via `torch.load` [1] on CPU,
-
-differ under `packaging.version.parse` [2], and then restores parameter values via
-`load_state_dict`.
+module instance. The method reads the checkpoint via `torch.load` [1] on CPU, then
+restores parameter values via `load_state_dict`.

 Parameters
 ----------
@@ -460,16 +454,13 @@ def save_load(
 Warns
 -----
 UserWarning
-Emitted when the checkpoint's stored version and the decoration-time version
-
+Emitted when the checkpoint's stored version and the decoration-time version are
+different [2].

 References
 ----------
 [1] torch.load - PyTorch documentation
 https://pytorch.org/docs/stable/generated/torch.load.html
-
-[2] packaging.version - packaging documentation
-https://packaging.pypa.io/en/stable/version.html
 """
 path = Path(path)
 if not path.exists():
@@ -478,7 +469,7 @@ def save_load(

 pkg: DehydratedCheckpoint = torch.load(str(path), map_location = 'cpu')

-if exists(version) and exists(pkg['version']) and
+if exists(version) and exists(pkg['version']) and (version != pkg['version']):
 message: str = f'I received a checkpoint saved at version `{pkg["version"]}`, but the current package version is `{version}`.'
 warnings.warn(message, UserWarning, stacklevel=2)
@@ -537,7 +528,7 @@ def save_load(
 # set decorated init as well as save, load, and init_and_load

 klass.__init__ = __init__ # ty:ignore[invalid-assignment]
-# TODO figure out how to use something like `wraps`
+# TODO figure out how to use something like `wraps` so the signature and docstring are public.
 setattr(klass, save_method_name, _save)
 setattr(klass, load_method_name, _load)
 setattr(klass, init_and_load_classmethod_name, _init_and_load_from)
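In 0.1.2 the load path compares the decoration-time `version` with the stored checkpoint version by plain string inequality (`version != pkg['version']`). The removed import and docstring lines show that 0.1.0 used `packaging.version` for this comparison. The stand-alone sketch below reproduces the new condition. `warn_on_version_mismatch` is a hypothetical helper, and the local `exists` assumes that the package helper means `value is not None`. One consequence is worth checking: strings that PEP 440 treats as the same release, such as `"1.0"` and `"1.0.0"`, are now different, so they trigger the notice. The state is still restored.

```python
from __future__ import annotations

import warnings


def exists(value: object) -> bool:
    # Assumption: mirrors the common lucidrains helper `exists(v) -> v is not None`.
    return value is not None


def warn_on_version_mismatch(version: str | None, stored_version: str | None) -> bool:
    """Hypothetical sketch of the 0.1.2 load-time check; returns True when it warned."""
    if exists(version) and exists(stored_version) and (version != stored_version):
        message = (
            f'I received a checkpoint saved at version `{stored_version}`, '
            f'but the current package version is `{version}`.'
        )
        warnings.warn(message, UserWarning, stacklevel=2)
        return True
    return False


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_on_version_mismatch("1.0", "1.0.0")  # plain strings differ, so this warns
    warn_on_version_mismatch("1.0", "1.0")    # identical strings, no warning
    warn_on_version_mismatch(None, "1.0")     # no decoration-time version, no warning
print(len(caught))  # → 1
```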
--- torch_einops_kit-0.1.0/src/torch_einops_kit/scaleValues.py
+++ torch_einops_kit-0.1.2/src/torch_einops_kit/scaleValues.py
@@ -1,12 +1,14 @@
-"""Provide vector normalization, masked mean computation, and learned RMS normalization.
+"""Provide exclusive prefix sums, vector normalization, masked mean computation, and learned RMS normalization.

-You can use this module to
-reductions over selected tensor positions, and apply learned
-to transformer and neural network feature channels.
+You can use this module to compute exclusive prefix sums of tensors, normalize feature vectors
+to unit length, compute masked mean reductions over selected tensor positions, and apply learned
+root-mean-square normalization to transformer and neural network feature channels.

 Contents
 --------
 Functions
+exclusive_cumsum
+Compute the exclusive prefix sum of a `Tensor` along a dimension.
 l2norm
 Normalize `Tensor` vectors to unit length along the last dimension.
 masked_mean
@@ -24,6 +26,47 @@ from torch_einops_kit import exists, pad_right_ndim
 import torch
 import torch.nn.functional as F

+def exclusive_cumsum(t: Tensor, dim: int = -1) -> Tensor:
+"""Compute the exclusive prefix sum of `Tensor` `t` along its dimension `dim`.
+
+You can use `exclusive_cumsum` to produce a shifted cumulative sum where each output
+position accumulates only the elements that strictly precede it along `dim`. The element
+at index zero is always zero. This operation is useful for converting absolute segment
+lengths into starting offsets.
+
+Parameters
+----------
+t : Tensor
+Input `Tensor` to reduce.
+dim : int = -1
+Dimension along which to compute the exclusive prefix sum. Negative values index
+from the last axis.
+
+Returns
+-------
+exclusivePrefixSum : Tensor
+`Tensor` with the same shape as `t` where `exclusivePrefixSum[..., i, ...]` equals
+the sum of all elements of `t` with index strictly less than `i` along `dim`.
+
+Mathematical Basis
+------------------
+Let t = [t₀, t₁, ..., tₙ₋₁] be the elements of `t` along `dim`. The exclusive prefix
+sum S is defined element-wise as:
+
+S[i] = Σⱼ₌₀^(i−1) tⱼ, with S[0] = 0.
+
+Equivalently, S equals the inclusive cumulative sum minus `t`:
+
+S = cumsum(t, dim) − t.
+
+References
+----------
+[1] torch.Tensor.cumsum - PyTorch documentation
+https://pytorch.org/docs/stable/generated/torch.Tensor.cumsum.html
+"""
+return t.cumsum(dim = dim) - t
+
+
 def l2norm(t: Tensor) -> Tensor:
 """Normalize `Tensor` vectors to unit length.

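The new `exclusive_cumsum` docstring gives two equivalent definitions, S[i] = Σⱼ₌₀^(i−1) tⱼ and S = cumsum(t, dim) − t. It also says the operation turns segment lengths into starting offsets. The 1-D sketch below checks both definitions and that use case on Python lists. The helper names are hypothetical, and the lists stand in for a tensor taken along `dim`.

```python
from itertools import accumulate


def exclusive_cumsum_list(values: list[int]) -> list[int]:
    """Hypothetical 1-D sketch of the direct definition: S[i] = sum(values[:i]), so S[0] == 0."""
    return [sum(values[:i]) for i in range(len(values))]


def exclusive_via_inclusive(values: list[int]) -> list[int]:
    """Same result via the identity in the docstring: S = cumsum(t) - t."""
    return [c - v for c, v in zip(accumulate(values), values)]


segment_lengths = [3, 1, 4, 2]
offsets = exclusive_cumsum_list(segment_lengths)
print(offsets)  # → [0, 3, 4, 8]
assert offsets == exclusive_via_inclusive(segment_lengths)

# Each offset marks where a segment starts in the concatenated sequence.
flat = list("abcdefghij")
segments = [flat[start:start + length] for start, length in zip(offsets, segment_lengths)]
print(segments)  # → [['a', 'b', 'c'], ['d'], ['e', 'f', 'g', 'h'], ['i', 'j']]
```

With integer inputs the two forms agree exactly. With floating-point tensors, `cumsum(t) − t` can differ from a directly shifted sum because of rounding. This matters if you compare the resulting offsets for exact equality.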