torch-einops-kit 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18)
  1. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/PKG-INFO +40 -17
  2. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/README.md +24 -15
  3. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/pyproject.toml +24 -2
  4. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/__init__.py +1 -5
  5. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_types.py +2 -5
  6. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/save_load.py +12 -21
  7. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/scaleValues.py +47 -4
  8. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_cat_and_stack.py +0 -0
  9. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_dimensions.py +0 -0
  10. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_helpers.py +0 -0
  11. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_masking.py +0 -0
  12. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_padding.py +0 -0
  13. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_semiotics.py +0 -0
  14. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/_slicing.py +0 -0
  15. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/device.py +0 -0
  16. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/einops.py +0 -0
  17. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/py.typed +0 -0
  18. {torch_einops_kit-0.1.0 → torch_einops_kit-0.1.2}/src/torch_einops_kit/utils.py +0 -0
--- torch_einops_kit-0.1.0/PKG-INFO
+++ torch_einops_kit-0.1.2/PKG-INFO
@@ -1,18 +1,32 @@
  Metadata-Version: 2.4
  Name: torch-einops-kit
- Version: 0.1.0
+ Version: 0.1.2
  Summary: Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`. A superset of `lucidrains/torch-einops-utils` with similar utilities from other lucidrains packages. Adds strict typing, extensive tests, and comprehensive docstrings.
+ Keywords: artificial intelligence,einops,machine learning,pytorch,torch
  Author: Phil Wang, Hunter Hogan
  Author-email: Phil Wang <lucidrains@gmail.com>, Hunter Hogan <HunterHogan@pm.me>
  License-Expression: CC-BY-NC-4.0
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Environment :: Console
+ Classifier: Environment :: GPU :: NVIDIA CUDA
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Natural Language :: English
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Utilities
+ Classifier: Typing :: Typed
  Requires-Dist: einops>=0.8.2
- Requires-Dist: packaging>=26.0
  Requires-Dist: torch>=2.10.0
  Requires-Dist: typing-extensions>=4.0.0
  Maintainer: Hunter Hogan
  Maintainer-email: Hunter Hogan <HunterHogan@pm.me>
  Requires-Python: >=3.10
+ Project-URL: Documentation, https://context7.com/hunterhogan/torch_einops_kit
  Project-URL: Donate, https://www.patreon.com/integrated
+ Project-URL: Download, https://pypi.org/project/torch_einops_kit
  Project-URL: Homepage, https://github.com/hunterhogan/torch_einops_kit
  Project-URL: Issues, https://github.com/hunterhogan/torch_einops_kit/issues
  Project-URL: Repository, https://github.com/hunterhogan/torch_einops_kit.git
@@ -22,8 +36,8 @@ Description-Content-Type: text/markdown
 
  Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`.
 
- [![pip install torch_einops_kit](https://img.shields.io/badge/pip%20install-torch_einops_kit-gray.svg?colorB=3b434b)](https://pypi.org/project/torch_einops_kit/)
- [![uv add torch_einops_kit](https://img.shields.io/badge/uv%20add-torch_einops_kit-gray.svg?colorB=3b434b)](https://pypi.org/project/torch_einops_kit/)
+ [![pip install torch-einops-kit](https://img.shields.io/badge/pip_install-torch--einops--kit-gray.svg?labelColor=blue)](https://pypi.org/project/torch-einops-kit/)
+ [![uv add torch-einops-kit](https://img.shields.io/badge/uv_add-torch--einops--kit-gray.svg?labelColor=blue)](https://pypi.org/project/torch-einops-kit/)
 
  This repository is a superset of [`lucidrains/torch-einops-utils`](https://github.com/lucidrains/torch-einops-utils). The upstream repository is a compact collection of small utilities that show up repeatedly in lucidrains model repositories. `torch_einops_kit` keeps that role. The main difference is emphasis. This fork adds roughly 6000 lines of tests, typing, and docstrings so the utility layer is easier to trust, easier to search, and easier to apply correctly.
 
@@ -38,7 +52,7 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused
  - Project name: `torch_einops_kit`.
  - Import path: `torch_einops_kit`.
  - Python requirement: `>=3.10`.
- - Runtime dependencies: `torch`, `einops`, `packaging`, and `typing-extensions`.
+ - Runtime dependencies: `torch`, `einops`, and `typing-extensions`.
  - Root package exports: helper functions, slicing helpers, rank-alignment helpers, mask helpers, safe concatenation helpers, padding helpers, normalization helpers, and PyTree / `einops` helpers.
  - Submodules with dedicated imports: `torch_einops_kit.device`, `torch_einops_kit.einops`, `torch_einops_kit.save_load`, and `torch_einops_kit.scaleValues`.
  - Typing status: the package ships a `py.typed` marker and the repository uses strict type checking.
@@ -46,10 +60,16 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused
 
  ## Installation
 
- Install from this repository with `uv`:
+ ### `uv`
 
  ```bash
- uv add git+https://github.com/hunterhogan/torch_einops_kit.git
+ uv add torch_einops_kit
+ ```
+
+ ### `pip`
+
+ ```bash
+ pip install torch_einops_kit
  ```
 
  ## Import map
@@ -61,9 +81,7 @@ from torch_einops_kit import (
  align_dims_left,
  and_masks,
  broadcast_cat,
- l2norm,
  lens_to_mask,
- masked_mean,
  maybe,
  once,
  or_masks,
@@ -108,6 +126,17 @@ from torch_einops_kit.save_load import (
  )
  ```
 
+ Import checkpoint decorators from `torch_einops_kit.scaleValues`:
+
+ ```python
+ from torch_einops_kit.scaleValues import (
+ exclusive_cumsum,
+ l2norm,
+ RMSNorm,
+ masked_mean,
+ )
+ ```
+
  ## Quick examples
 
  ### Batch variable-length tensors and build a mask
@@ -265,13 +294,6 @@ These functions add numeric padding values along an existing tensor dimension.
 
  When `pad_lens=True` and `return_lens=True`, the second tensor contains padding widths rather than original lengths.
 
- ### Normalization and masked reduction helpers
-
- | Name | Contract |
- | ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
- | `l2norm(t)` | Normalizes each vector in `t` to unit length along the last dimension by dividing by its L2 norm. Delegates to `torch.nn.functional.normalize` with `p=2` and `dim=-1`. |
- | `masked_mean(t, mask=None, dim=None, eps=1e-5)` | Computes a masked mean. When `mask is None`, the function falls back to `t.mean(...)`. When no masked position is selected and `dim is None`, the function returns zero by summing over the empty selection. When `mask.ndim < t.ndim`, the function right-pads mask rank before broadcasting. |
-
  ### PyTree helpers
 
  | Name | Contract |
@@ -281,10 +303,11 @@ When `pad_lens=True` and `return_lens=True`, the second tensor contains padding
 
  ## `scaleValues` submodule reference
 
- The `torch_einops_kit.scaleValues` submodule contains vector normalization, masked mean computation, and the `RMSNorm` layer. `l2norm` and `masked_mean` are also re-exported from the package root.
+ The `torch_einops_kit.scaleValues` submodule contains exclusive prefix sums, vector normalization, masked mean computation, and the `RMSNorm` layer.
 
  | Name | Contract |
  | ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+ | `exclusive_cumsum(t, dim=-1)` | Computes the exclusive prefix sum of `t` along `dim`. Each output position receives the sum of all elements that strictly precede it along `dim`. The element at index zero is always zero. |
  | `l2norm(t)` | Normalizes each vector in `t` to unit length along the last dimension. Delegates to `torch.nn.functional.normalize` with `p=2` and `dim=-1`. |
  | `masked_mean(t, mask=None, dim=None, eps=1e-5)` | Computes a masked mean. When `mask is None`, the function falls back to `t.mean(...)`. When no masked position is selected and `dim is None`, the function returns zero by summing over the empty selection. When `mask.ndim < t.ndim`, the function right-pads mask rank before broadcasting. |
  | `RMSNorm(dim)` | `torch.nn.Module` that normalizes the last feature axis to unit length, multiplies by `√dim`, and applies a learned per-feature `gamma` parameter. Use as a pre-normalization layer before attention, feedforward, or linear projection sublayers in transformer-style modules. |
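The `masked_mean` row above packs three cases into one sentence. The sketch below restates them on plain Python lists so each branch is visible without `torch`. `masked_mean_flat` and `masked_mean_rows` are hypothetical names, and the row version assumes that `eps` clamps the count of selected positions from below, which the table does not state.

```python
from typing import List, Optional, Sequence


def masked_mean_flat(values: Sequence[float], mask: Optional[Sequence[bool]] = None) -> float:
    # Analogue of masked_mean(t, mask, dim=None) for a 1-D list.
    if mask is None:
        # No mask: plain mean over every element, like t.mean().
        return sum(values) / len(values)
    selected = [v for v, keep in zip(values, mask) if keep]
    if not selected:
        # Nothing selected: summing the empty selection gives 0.0 instead of NaN.
        return float(sum(selected))
    return sum(selected) / len(selected)


def masked_mean_rows(
    rows: Sequence[Sequence[float]],
    mask_rows: Sequence[Sequence[bool]],
    eps: float = 1e-5,
) -> List[float]:
    # Analogue of masked_mean(t, mask, dim=-1) for a list of rows.
    # Assumption: eps clamps the number of selected positions from below.
    means: List[float] = []
    for values, mask in zip(rows, mask_rows):
        total = sum(v for v, keep in zip(values, mask) if keep)
        count = sum(1 for keep in mask if keep)
        means.append(total / max(count, eps))
    return means


print(masked_mean_flat([1.0, 2.0, 3.0, 4.0], [True, False, True, False]))  # 2.0
print(masked_mean_rows([[1.0, 3.0], [5.0, 7.0]], [[True, True], [False, False]]))  # [2.0, 0.0]
```

A fully masked row yields `0.0` rather than a division-by-zero error because the clamped denominator is positive and the numerator is an empty sum.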
--- torch_einops_kit-0.1.0/README.md
+++ torch_einops_kit-0.1.2/README.md
@@ -2,8 +2,8 @@
 
  Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`.
 
- [![pip install torch_einops_kit](https://img.shields.io/badge/pip%20install-torch_einops_kit-gray.svg?colorB=3b434b)](https://pypi.org/project/torch_einops_kit/)
- [![uv add torch_einops_kit](https://img.shields.io/badge/uv%20add-torch_einops_kit-gray.svg?colorB=3b434b)](https://pypi.org/project/torch_einops_kit/)
+ [![pip install torch-einops-kit](https://img.shields.io/badge/pip_install-torch--einops--kit-gray.svg?labelColor=blue)](https://pypi.org/project/torch-einops-kit/)
+ [![uv add torch-einops-kit](https://img.shields.io/badge/uv_add-torch--einops--kit-gray.svg?labelColor=blue)](https://pypi.org/project/torch-einops-kit/)
 
  This repository is a superset of [`lucidrains/torch-einops-utils`](https://github.com/lucidrains/torch-einops-utils). The upstream repository is a compact collection of small utilities that show up repeatedly in lucidrains model repositories. `torch_einops_kit` keeps that role. The main difference is emphasis. This fork adds roughly 6000 lines of tests, typing, and docstrings so the utility layer is easier to trust, easier to search, and easier to apply correctly.
 
@@ -18,7 +18,7 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused
  - Project name: `torch_einops_kit`.
  - Import path: `torch_einops_kit`.
  - Python requirement: `>=3.10`.
- - Runtime dependencies: `torch`, `einops`, `packaging`, and `typing-extensions`.
+ - Runtime dependencies: `torch`, `einops`, and `typing-extensions`.
  - Root package exports: helper functions, slicing helpers, rank-alignment helpers, mask helpers, safe concatenation helpers, padding helpers, normalization helpers, and PyTree / `einops` helpers.
  - Submodules with dedicated imports: `torch_einops_kit.device`, `torch_einops_kit.einops`, `torch_einops_kit.save_load`, and `torch_einops_kit.scaleValues`.
  - Typing status: the package ships a `py.typed` marker and the repository uses strict type checking.
@@ -26,10 +26,16 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused
 
  ## Installation
 
- Install from this repository with `uv`:
+ ### `uv`
 
  ```bash
- uv add git+https://github.com/hunterhogan/torch_einops_kit.git
+ uv add torch_einops_kit
+ ```
+
+ ### `pip`
+
+ ```bash
+ pip install torch_einops_kit
  ```
 
  ## Import map
@@ -41,9 +47,7 @@ from torch_einops_kit import (
  align_dims_left,
  and_masks,
  broadcast_cat,
- l2norm,
  lens_to_mask,
- masked_mean,
  maybe,
  once,
  or_masks,
@@ -88,6 +92,17 @@ from torch_einops_kit.save_load import (
  )
  ```
 
+ Import checkpoint decorators from `torch_einops_kit.scaleValues`:
+
+ ```python
+ from torch_einops_kit.scaleValues import (
+ exclusive_cumsum,
+ l2norm,
+ RMSNorm,
+ masked_mean,
+ )
+ ```
+
  ## Quick examples
 
  ### Batch variable-length tensors and build a mask
@@ -245,13 +260,6 @@ These functions add numeric padding values along an existing tensor dimension.
 
  When `pad_lens=True` and `return_lens=True`, the second tensor contains padding widths rather than original lengths.
 
- ### Normalization and masked reduction helpers
-
- | Name | Contract |
- | ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
- | `l2norm(t)` | Normalizes each vector in `t` to unit length along the last dimension by dividing by its L2 norm. Delegates to `torch.nn.functional.normalize` with `p=2` and `dim=-1`. |
- | `masked_mean(t, mask=None, dim=None, eps=1e-5)` | Computes a masked mean. When `mask is None`, the function falls back to `t.mean(...)`. When no masked position is selected and `dim is None`, the function returns zero by summing over the empty selection. When `mask.ndim < t.ndim`, the function right-pads mask rank before broadcasting. |
-
  ### PyTree helpers
 
  | Name | Contract |
@@ -261,10 +269,11 @@ When `pad_lens=True` and `return_lens=True`, the second tensor contains padding
 
  ## `scaleValues` submodule reference
 
- The `torch_einops_kit.scaleValues` submodule contains vector normalization, masked mean computation, and the `RMSNorm` layer. `l2norm` and `masked_mean` are also re-exported from the package root.
+ The `torch_einops_kit.scaleValues` submodule contains exclusive prefix sums, vector normalization, masked mean computation, and the `RMSNorm` layer.
 
  | Name | Contract |
  | ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+ | `exclusive_cumsum(t, dim=-1)` | Computes the exclusive prefix sum of `t` along `dim`. Each output position receives the sum of all elements that strictly precede it along `dim`. The element at index zero is always zero. |
  | `l2norm(t)` | Normalizes each vector in `t` to unit length along the last dimension. Delegates to `torch.nn.functional.normalize` with `p=2` and `dim=-1`. |
  | `masked_mean(t, mask=None, dim=None, eps=1e-5)` | Computes a masked mean. When `mask is None`, the function falls back to `t.mean(...)`. When no masked position is selected and `dim is None`, the function returns zero by summing over the empty selection. When `mask.ndim < t.ndim`, the function right-pads mask rank before broadcasting. |
  | `RMSNorm(dim)` | `torch.nn.Module` that normalizes the last feature axis to unit length, multiplies by `√dim`, and applies a learned per-feature `gamma` parameter. Use as a pre-normalization layer before attention, feedforward, or linear projection sublayers in transformer-style modules. |
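The `l2norm` and `RMSNorm` rows describe scaling a vector to unit length and then multiplying by `√dim`. Because the L2 norm of a `d`-element vector divided by `√d` is its root mean square, that product equals dividing the vector by its RMS. A minimal sketch on plain lists, assuming an explicit `gamma` argument in place of the module's learned parameter and the `1e-12` floor that `torch.nn.functional.normalize` uses by default; both function names are hypothetical.

```python
import math
from typing import List, Sequence


def l2norm_vec(x: Sequence[float], eps: float = 1e-12) -> List[float]:
    # Divide by the Euclidean norm, floored at eps to avoid dividing by zero.
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]


def rms_norm_vec(x: Sequence[float], gamma: Sequence[float]) -> List[float]:
    # Unit-length vector, rescaled by sqrt(dim), then a per-feature gain.
    scale = math.sqrt(len(x))
    return [u * scale * g for u, g in zip(l2norm_vec(x), gamma)]


x = [3.0, 4.0]
print(l2norm_vec(x))  # [0.6, 0.8]
y = rms_norm_vec(x, gamma=[1.0, 1.0])
# With gamma equal to 1 everywhere, the output's root mean square is 1 (up to rounding).
print(math.sqrt(sum(v * v for v in y) / len(y)))
```

Expressing the operation as `l2norm` followed by a `√dim` rescale reuses the existing unit-length helper instead of computing a separate mean of squares.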
--- torch_einops_kit-0.1.0/pyproject.toml
+++ torch_einops_kit-0.1.2/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "torch-einops-kit"
- version = "0.1.0"
+ version = "0.1.2"
  description = "Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`. A superset of `lucidrains/torch-einops-utils` with similar utilities from other lucidrains packages. Adds strict typing, extensive tests, and comprehensive docstrings."
  readme = "README.md"
  requires-python = ">=3.10"
@@ -10,15 +10,37 @@ authors = [
  { name = "Hunter Hogan", email = "HunterHogan@pm.me" }
  ]
  maintainers = [{ name = "Hunter Hogan", email = "HunterHogan@pm.me" }]
+ keywords = [
+ "artificial intelligence",
+ "einops",
+ "machine learning",
+ "pytorch",
+ "torch",
+ ]
+ classifiers = [
+ "Development Status :: 4 - Beta",
+ "Environment :: Console",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Science/Research",
+ "Natural Language :: English",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.14",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Utilities",
+ "Typing :: Typed",
+ ]
  dependencies = [
  "einops>=0.8.2",
- "packaging>=26.0",
  "torch>=2.10.0",
  "typing-extensions>=4.0.0",
  ]
 
  [project.urls]
+ Documentation = "https://context7.com/hunterhogan/torch_einops_kit"
  Donate = "https://www.patreon.com/integrated"
+ Download = "https://pypi.org/project/torch_einops_kit"
  Homepage = "https://github.com/hunterhogan/torch_einops_kit"
  Issues = "https://github.com/hunterhogan/torch_einops_kit/issues"
  Repository = "https://github.com/hunterhogan/torch_einops_kit.git"
--- torch_einops_kit-0.1.0/src/torch_einops_kit/__init__.py
+++ torch_einops_kit-0.1.2/src/torch_einops_kit/__init__.py
@@ -92,10 +92,6 @@ pad_sequence_and_cat
 
  Utilities
  ---------
- l2norm
- Normalize `Tensor` vectors to unit length along the last dimension.
- masked_mean
- Compute a mean over positions selected by a boolean mask.
  tree_flatten_with_inverse
  Flatten a PyTree and return an inverse reconstruction function.
  tree_map_tensor
@@ -110,7 +106,7 @@ einops
  save_load
  Decorate `torch.nn.Module` subclasses with checkpoint save, load, and reconstruction helpers.
  scaleValues
- Normalize feature vectors and compute masked means.
+ Compute exclusive prefix sums, normalize feature vectors, and compute masked means.
  """
  # isort: split
  from torch_einops_kit._semiotics import decreasing as decreasing, zeroIndexed as zeroIndexed
--- torch_einops_kit-0.1.0/src/torch_einops_kit/_types.py
+++ torch_einops_kit-0.1.2/src/torch_einops_kit/_types.py
@@ -192,8 +192,8 @@ class DehydratedCheckpoint(TypedDict):
  `Module` instances with `DehydratedTorchNNModule` [9] reconstruction records.
  version : str | None
  An optional version string written at save time. When both the stored version and the
- `save_load` `version` argument are set and differ under `packaging.version.parse` [10], the
- generated load method prints a notice but still restores model state.
+ `save_load` `version` argument are different, the generated load method prints a notice but
+ still restores model state.
 
  See Also
  --------
@@ -220,9 +220,6 @@ class DehydratedCheckpoint(TypedDict):
  [8] torch_einops_kit.save_load.dehydrate_config
 
  [9] torch_einops_kit.DehydratedTorchNNModule
-
- [10] packaging.version - packaging documentation
- https://packaging.pypa.io/en/stable/version.html
  """
  model: dict[str, Tensor]
  config: bytes
--- torch_einops_kit-0.1.0/src/torch_einops_kit/save_load.py
+++ torch_einops_kit-0.1.2/src/torch_einops_kit/save_load.py
@@ -20,7 +20,6 @@ from __future__ import annotations
 
  from collections.abc import Callable
  from functools import wraps
- from packaging import version as packaging_version
  from pathlib import Path
  from torch.nn import Module
  from torch_einops_kit import (
@@ -299,13 +298,12 @@ def save_load(
  -----
  The generated save method writes a dictionary that is compatible with `torch.save` [5], and the
  generated load paths read the dictionary with `torch.load` [6] on CPU before restoring state.
- When `version` and the stored checkpoint version both exist and differ under
- `packaging.version.parse` [7], the generated load method prints a notice but still restores the
- stored model state.
+ When `version` and the stored checkpoint version are different, the generated load method prints
+ a notice but still restores the stored model state.
 
  Examples
  --------
- From `tests.test_save_load.test_init_and_load` [8]:
+ From `tests.test_save_load.test_init_and_load` [7]:
 
  ```python
  from pathlib import Path
@@ -327,7 +325,7 @@ def save_load(
  restored_model = SimpleNet.init_and_load(str(path))
  ```
 
- From `tests.test_save_load_extended` [9]:
+ From `tests.test_save_load_extended` [8]:
 
  ```python
  import torch
@@ -367,11 +365,9 @@ def save_load(
  https://pytorch.org/docs/stable/generated/torch.save.html
  [6] torch.load - PyTorch documentation
  https://pytorch.org/docs/stable/generated/torch.load.html
- [7] packaging.version - packaging documentation
- https://packaging.pypa.io/en/stable/version.html
- [8] tests.test_save_load.test_init_and_load
+ [7] tests.test_save_load.test_init_and_load
 
- [9] tests.test_save_load_extended.test_save_load_supports_custom_method_names_and_config_storage
+ [8] tests.test_save_load_extended.test_save_load_supports_custom_method_names_and_config_storage
  """
  def _save_load(klass: type[TorchNNModule]) -> type[TorchNNModule]:
  if not issubclass(klass, Module):
@@ -435,10 +431,8 @@ def save_load(
  """Restore model state from a checkpoint file.
 
  You can use this method to load parameter values into an already-constructed decorated
- module instance. The method reads the checkpoint via `torch.load` [1] on CPU, emits a
- `UserWarning` when the stored version and the decoration-time version both exist and
- differ under `packaging.version.parse` [2], and then restores parameter values via
- `load_state_dict`.
+ module instance. The method reads the checkpoint via `torch.load` [1] on CPU, then
+ restores parameter values via `load_state_dict`.
 
  Parameters
  ----------
@@ -460,16 +454,13 @@ def save_load(
  Warns
  -----
  UserWarning
- Emitted when the checkpoint's stored version and the decoration-time version both
- exist and differ under `packaging.version.parse` [2].
+ Emitted when the checkpoint's stored version and the decoration-time version are
+ different [2].
 
  References
  ----------
  [1] torch.load - PyTorch documentation
  https://pytorch.org/docs/stable/generated/torch.load.html
-
- [2] packaging.version - packaging documentation
- https://packaging.pypa.io/en/stable/version.html
  """
  path = Path(path)
  if not path.exists():
@@ -478,7 +469,7 @@ def save_load(
 
  pkg: DehydratedCheckpoint = torch.load(str(path), map_location = 'cpu')
 
- if exists(version) and exists(pkg['version']) and packaging_version.parse(version) != packaging_version.parse(pkg['version']):
+ if exists(version) and exists(pkg['version']) and (version != pkg['version']):
  message: str = f'I received a checkpoint saved at version `{pkg["version"]}`, but the current package version is `{version}`.'
  warnings.warn(message, UserWarning, stacklevel=2)
 
@@ -537,7 +528,7 @@ def save_load(
  # set decorated init as well as save, load, and init_and_load
 
  klass.__init__ = __init__ # ty:ignore[invalid-assignment]
- # TODO figure out how to use something like `wraps` to get the signature and docstring are public.
+ # TODO figure out how to use something like `wraps` so the signature and docstring are public.
  setattr(klass, save_method_name, _save)
  setattr(klass, load_method_name, _load)
  setattr(klass, init_and_load_classmethod_name, _init_and_load_from)
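Dropping `packaging` changes what counts as a version mismatch. Under PEP 440, which `packaging.version.parse` implements, `"1.0"` and `"1.0.0"` compare equal, so the 0.1.0 check stayed silent for that pair; the 0.1.2 condition compares raw strings with `!=` and warns. The sketch below reproduces the new condition outside the decorator, assuming the package's `exists(v)` means `v is not None`; `warn_on_version_mismatch` is a hypothetical name, not part of the package.

```python
import warnings
from typing import Optional


def warn_on_version_mismatch(current: Optional[str], stored: Optional[str]) -> bool:
    # Same shape as the 0.1.2 condition: both versions present and not string-equal.
    if current is not None and stored is not None and current != stored:
        message = (
            f"I received a checkpoint saved at version `{stored}`, "
            f"but the current package version is `{current}`."
        )
        warnings.warn(message, UserWarning, stacklevel=2)
        return True
    return False


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_on_version_mismatch("1.0", "1.0.0")  # warns: the strings differ
    warn_on_version_mismatch("1.0", "1.0")    # silent: identical strings
    warn_on_version_mismatch(None, "1.0")     # silent: no current version
print(len(caught))  # 1
```

Normalizing both strings with a PEP 440 parser before comparing would restore the 0.1.0 behavior, at the cost of the `packaging` dependency that this release removes.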
--- torch_einops_kit-0.1.0/src/torch_einops_kit/scaleValues.py
+++ torch_einops_kit-0.1.2/src/torch_einops_kit/scaleValues.py
@@ -1,12 +1,14 @@
- """Provide vector normalization, masked mean computation, and learned RMS normalization.
+ """Provide exclusive prefix sums, vector normalization, masked mean computation, and learned RMS normalization.
 
- You can use this module to normalize feature vectors to unit length, compute masked mean
- reductions over selected tensor positions, and apply learned root-mean-square normalization
- to transformer and neural network feature channels.
+ You can use this module to compute exclusive prefix sums of tensors, normalize feature vectors
+ to unit length, compute masked mean reductions over selected tensor positions, and apply learned
+ root-mean-square normalization to transformer and neural network feature channels.
 
  Contents
  --------
  Functions
+ exclusive_cumsum
+ Compute the exclusive prefix sum of a `Tensor` along a dimension.
  l2norm
  Normalize `Tensor` vectors to unit length along the last dimension.
  masked_mean
@@ -24,6 +26,47 @@ from torch_einops_kit import exists, pad_right_ndim
  import torch
  import torch.nn.functional as F
 
+ def exclusive_cumsum(t: Tensor, dim: int = -1) -> Tensor:
+ """Compute the exclusive prefix sum of `Tensor` `t` along its dimension `dim`.
+
+ You can use `exclusive_cumsum` to produce a shifted cumulative sum where each output
+ position accumulates only the elements that strictly precede it along `dim`. The element
+ at index zero is always zero. This operation is useful for converting absolute segment
+ lengths into starting offsets.
+
+ Parameters
+ ----------
+ t : Tensor
+ Input `Tensor` to reduce.
+ dim : int = -1
+ Dimension along which to compute the exclusive prefix sum. Negative values index
+ from the last axis.
+
+ Returns
+ -------
+ exclusivePrefixSum : Tensor
+ `Tensor` with the same shape as `t` where `exclusivePrefixSum[..., i, ...]` equals
+ the sum of all elements of `t` with index strictly less than `i` along `dim`.
+
+ Mathematical Basis
+ ------------------
+ Let t = [t₀, t₁, ..., tₙ₋₁] be the elements of `t` along `dim`. The exclusive prefix
+ sum S is defined element-wise as:
+
+ S[i] = Σⱼ₌₀^(i−1) tⱼ, with S[0] = 0.
+
+ Equivalently, S equals the inclusive cumulative sum minus `t`:
+
+ S = cumsum(t, dim) − t.
+
+ References
+ ----------
+ [1] torch.Tensor.cumsum - PyTorch documentation
+ https://pytorch.org/docs/stable/generated/torch.Tensor.cumsum.html
+ """
+ return t.cumsum(dim = dim) - t
+
+
  def l2norm(t: Tensor) -> Tensor:
  """Normalize `Tensor` vectors to unit length.
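`exclusive_cumsum` is described as turning segment lengths into start offsets, and its docstring gives two equivalent definitions: a shifted running sum and `cumsum(t) − t`. The dependency-free sketch below implements both on Python lists and uses the offsets to slice a flat sequence. The helper names are hypothetical and stand in for the tensor version.

```python
from itertools import accumulate
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def exclusive_cumsum_shifted(values: Sequence[int]) -> List[int]:
    # S[0] = 0 and S[i] = values[0] + ... + values[i - 1]:
    # prepend 0 to the running sum, then drop the grand total.
    return list(accumulate(values, initial=0))[:-1]


def exclusive_cumsum_subtract(values: Sequence[int]) -> List[int]:
    # The form used in the diff: inclusive running sum minus the input.
    return [running - v for running, v in zip(accumulate(values), values)]


def split_by_lengths(flat: Sequence[T], lengths: Sequence[int]) -> List[List[T]]:
    # Start offsets come from the exclusive prefix sum of the lengths.
    offsets = exclusive_cumsum_shifted(lengths)
    return [list(flat[start:start + n]) for start, n in zip(offsets, lengths)]


lengths = [3, 1, 4]
print(exclusive_cumsum_shifted(lengths))      # [0, 3, 4]
print(exclusive_cumsum_subtract(lengths))     # [0, 3, 4]
print(split_by_lengths("abcdefgh", lengths))  # [['a', 'b', 'c'], ['d'], ['e', 'f', 'g', 'h']]
```

For integer lengths the two forms agree exactly. For floating-point tensors, subtracting `t` after accumulating can differ from the shifted sum by rounding error.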