torch-einops-kit 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/PKG-INFO +33 -13
  2. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/README.md +30 -10
  3. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/pyproject.toml +3 -3
  4. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/__init__.py +13 -4
  5. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/_cat_and_stack.py +15 -13
  6. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/_dimensions.py +19 -18
  7. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/_helpers.py +48 -43
  8. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/_masking.py +22 -28
  9. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/_padding.py +59 -34
  10. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/_semiotics.py +3 -0
  11. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/_slicing.py +16 -17
  12. torch_einops_kit-0.1.2/src/torch_einops_kit/_types.py → torch_einops_kit-0.1.4/src/torch_einops_kit/_theTypes.py +6 -2
  13. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/device.py +25 -26
  14. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/einops.py +16 -11
  15. torch_einops_kit-0.1.4/src/torch_einops_kit/nn.py +178 -0
  16. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/save_load.py +32 -34
  17. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/scaleValues.py +37 -23
  18. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/utils.py +12 -9
  19. {torch_einops_kit-0.1.2 → torch_einops_kit-0.1.4}/src/torch_einops_kit/py.typed +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torch-einops-kit
- Version: 0.1.2
+ Version: 0.1.4
  Summary: Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`. A superset of `lucidrains/torch-einops-utils` with similar utilities from other lucidrains packages. Adds strict typing, extensive tests, and comprehensive docstrings.
  Keywords: artificial intelligence,einops,machine learning,pytorch,torch
  Author: Phil Wang, Hunter Hogan
@@ -19,12 +19,12 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Classifier: Topic :: Utilities
  Classifier: Typing :: Typed
  Requires-Dist: einops>=0.8.2
- Requires-Dist: torch>=2.10.0
+ Requires-Dist: torch>=2.0.0
  Requires-Dist: typing-extensions>=4.0.0
  Maintainer: Hunter Hogan
  Maintainer-email: Hunter Hogan <HunterHogan@pm.me>
  Requires-Python: >=3.10
- Project-URL: Documentation, https://context7.com/hunterhogan/torch_einops_kit
+ Project-URL: Context7, https://context7.com/hunterhogan/torch_einops_kit
  Project-URL: Donate, https://www.patreon.com/integrated
  Project-URL: Download, https://pypi.org/project/torch_einops_kit
  Project-URL: Homepage, https://github.com/hunterhogan/torch_einops_kit
@@ -34,7 +34,7 @@ Description-Content-Type: text/markdown

  # torch_einops_kit

- Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`.
+ Typed tensor-shaping, masking, padding, device-routing, lightweight `nn.Module` adapters, and checkpoint utilities for PyTorch and `einops`.

  [![pip install torch-einops-kit](https://img.shields.io/badge/pip_install-torch--einops--kit-gray.svg?labelColor=blue)](https://pypi.org/project/torch-einops-kit/)
  [![uv add torch-einops-kit](https://img.shields.io/badge/uv_add-torch--einops--kit-gray.svg?labelColor=blue)](https://pypi.org/project/torch-einops-kit/)
@@ -43,7 +43,7 @@ This repository is a superset of [`lucidrains/torch-einops-utils`](https://githu

  `torch_einops_kit` is most useful when combined with other lucidrains repositories. Repositories such as [`dreamer4`](https://github.com/lucidrains/dreamer4), [`metacontroller`](https://github.com/lucidrains/metacontroller), [`mimic-video`](https://github.com/lucidrains/mimic-video), [`pi-zero-pytorch`](https://github.com/lucidrains/pi-zero-pytorch), [`sdft-pytorch`](https://github.com/lucidrains/sdft-pytorch), and [`locoformer`](https://github.com/lucidrains/locoformer) repeatedly need operations such as `align_dims_left`, `shape_with_replace`, `lens_to_mask`, `pad_sequence`, `safe_cat`, and `pack_with_inverse`. This package centralizes those operations in one typed import surface instead of re-implementing the same tensor utility layer in each model repository.

- If you already know `torch-einops-utils`, `torch_einops_kit` began as a typed substitute for that package and has since grown into a superset. In addition to everything from `torch-einops-utils`, this repository centralizes small utility functions that appear repeatedly in other lucidrains model repositories but were never collected in one place, such as `l2norm`, `once`, `pack_one`, and `unpack_one`. The function family remains the same kind: small PyTorch and `einops` helpers for shape work, masks, padding, optional tensors, PyTree traversal, device routing, and checkpoint reconstruction. The import path is `torch_einops_kit`, not `torch_einops_utils`. The relationship is conceptual, not literal import-path compatibility.
+ If you already know `torch-einops-utils`, `torch_einops_kit` began as a typed substitute for that package and has since grown into a superset. In addition to everything from `torch-einops-utils`, this repository centralizes small utility functions that appear repeatedly in other lucidrains model repositories but were never collected in one place, such as `l2norm`, `once`, `pack_one`, and `unpack_one`. The function family remains the same kind: small PyTorch and `einops` helpers for shape work, masks, padding, optional tensors, PyTree traversal, device routing, lightweight `nn.Module` adaptation, and checkpoint reconstruction. The import path is `torch_einops_kit`, not `torch_einops_utils`. The relationship is conceptual, not literal import-path compatibility.

  Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused modules, extensive tests, and docstrings written for both humans and AI assistants. Use upstream when you want the most compact possible version of the same idea.

@@ -54,7 +54,7 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused
  - Python requirement: `>=3.10`.
  - Runtime dependencies: `torch`, `einops`, and `typing-extensions`.
  - Root package exports: helper functions, slicing helpers, rank-alignment helpers, mask helpers, safe concatenation helpers, padding helpers, normalization helpers, and PyTree / `einops` helpers.
- - Submodules with dedicated imports: `torch_einops_kit.device`, `torch_einops_kit.einops`, `torch_einops_kit.save_load`, and `torch_einops_kit.scaleValues`.
+ - Submodules with dedicated imports: `torch_einops_kit.device`, `torch_einops_kit.einops`, `torch_einops_kit.nn`, `torch_einops_kit.save_load`, and `torch_einops_kit.scaleValues`.
  - Typing status: the package ships a `py.typed` marker and the repository uses strict type checking.
  - Best fit: lucidrains-style model repositories that work with variable-length tensors, `einops` patterns, optional intermediate tensors, and nested `torch.nn.Module` graphs.

@@ -116,6 +116,16 @@ from torch_einops_kit.device import (
  )
  ```

+ Import lightweight `nn.Module` adapters from `torch_einops_kit.nn`:
+
+ ```python
+ from torch_einops_kit.nn import (
+     Identity,
+     Lambda,
+     Sequential,
+ )
+ ```
+
  Import checkpoint decorators from `torch_einops_kit.save_load`:

  ```python
@@ -126,7 +136,7 @@ from torch_einops_kit.save_load import (
  )
  ```

- Import checkpoint decorators from `torch_einops_kit.scaleValues`:
+ Import normalization and masked-reduction helpers from `torch_einops_kit.scaleValues`:

  ```python
  from torch_einops_kit.scaleValues import (
@@ -296,10 +306,10 @@ When `pad_lens=True` and `return_lens=True`, the second tensor contains padding

  ### PyTree helpers

- | Name | Contract |
- | --------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
- | `tree_map_tensor(fn, tree)` | Applies `fn` to every tensor leaf in a PyTree and leaves non-tensor leaves unchanged. |
- | `tree_flatten_with_inverse(tree)` | Returns a flat list of leaves and an inverse function that reconstructs the original PyTree shape from a replacement iterable of leaves. |
+ | Name | Contract |
+ | --------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
+ | `tree_map_tensor(fn, tree)` | Applies `fn` to every tensor leaf in a PyTree and leaves non-tensor leaves unchanged. |
+ | `tree_flatten_with_inverse(tree)` | Returns a flat list of leaves and an inverse function that reconstructs the original PyTree shape from a replacement iterable of leaves. |

  ## `scaleValues` submodule reference

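The `tree_flatten_with_inverse` contract in the table above implies a flatten, transform, rebuild workflow. A minimal usage sketch, assuming the function returns the leaf list and inverse closure exactly as the contract describes and is importable from the package root; the tree and shapes are illustrative, not taken from the package's tests:

```python
import torch
from torch_einops_kit import tree_flatten_with_inverse

# A nested PyTree mixing tensor and non-tensor leaves.
tree = {"weights": [torch.ones(2), torch.zeros(3)], "label": "demo"}

# Flatten to a list of leaves plus an inverse that rebuilds the nesting.
leaves, inverse = tree_flatten_with_inverse(tree)

# Transform the leaves, then restore the original structure around them.
doubled = [leaf * 2 if isinstance(leaf, torch.Tensor) else leaf for leaf in leaves]
rebuilt = inverse(doubled)  # same dict/list nesting as `tree`
```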
@@ -332,6 +342,16 @@ The `torch_einops_kit.device` submodule contains three utilities for device infe
  | `move_inputs_to_device(device)` | Decorator that recursively moves every tensor inside positional and keyword arguments to `device` before calling the wrapped function. Non-tensor values pass through unchanged. |
  | `move_inputs_to_module_device(fn)` | Decorator for methods whose first argument is a `torch.nn.Module`. The decorator infers the target device with `module_device(self)` and moves every tensor argument after `self` to that device. If `module_device(self)` returns `None`, the call is a no-op. |

+ ## `nn` submodule reference
+
+ The `torch_einops_kit.nn` submodule contains lightweight `torch.nn.Module` adapters and a `Sequential` constructor that ignores `None` values.
+
+ | Name | Contract |
+ | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+ | `Sequential(*modules)` | Equivalent to `nn.Sequential(*compact(modules))`. Every `None` argument is discarded before construction. |
+ | `Identity()` | `torch.nn.Module` wrapper around `identity`. Unlike `torch.nn.Identity`, `Identity.forward` accepts additional positional arguments and keyword arguments and returns the first positional argument unchanged. |
+ | `Lambda(fn)` | `torch.nn.Module` wrapper that stores `fn` on `self.fn` and delegates `forward(*args, **kwargs)` to `fn(*args, **kwargs)` without changing the argument structure. |
+
  ## `save_load` submodule reference

  The `torch_einops_kit.save_load` submodule contains the checkpoint decorator and the two advanced configuration helpers that support nested decorated modules.
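The table above pins down the adapter contracts without showing `nn.py` itself (added as a new 178-line file). A minimal sketch consistent with those contracts, assuming `compact` drops `None` values as the `Sequential` row implies; the real module likely differs in typing and docstrings:

```python
import torch
from torch import nn


def compact(items):
    """Drop None values; assumed helper mirroring the package's `compact`."""
    return [item for item in items if item is not None]


def Sequential(*modules):
    # Per the contract: every None argument is discarded before construction.
    return nn.Sequential(*compact(modules))


class Identity(nn.Module):
    # Unlike torch.nn.Identity, forward tolerates extra args and kwargs
    # and returns the first positional argument unchanged.
    def forward(self, x, *args, **kwargs):
        return x


class Lambda(nn.Module):
    # Store fn on self.fn and delegate forward without reshaping arguments.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


# None entries vanish, so optional layers compose without if/else plumbing.
model = Sequential(nn.Linear(4, 8), None, Lambda(torch.relu), Identity())
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 8])
```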
@@ -406,7 +426,7 @@ If you are an AI assistant adapting code from `torch-einops-utils`, use these tr
  This repository is not a repackaged mirror of upstream. This repository makes a different trade-off.

  - Upstream is intentionally compact.
- - This fork splits the implementation across focused modules such as `_helpers.py`, `_padding.py`, `device.py`, and `save_load.py` while still re-exporting most tensor helpers from the package root.
+ - This fork splits the implementation across focused modules such as `_helpers.py`, `_padding.py`, `device.py`, `nn.py`, and `save_load.py` while still re-exporting most tensor helpers from the package root.
  - This fork adds strict typing, a `py.typed` marker, extensive tests, and detailed docstrings.
  - This fork is best treated as a typed, documented branch of the same utility idea rather than a literal import-path-compatible drop-in replacement.

@@ -414,6 +434,7 @@ This repository is not a repackaged mirror of upstream. This repository makes a

  - `src/torch_einops_kit/` — package source.
  - `src/torch_einops_kit/device.py` — device inference and input-routing decorators.
+ - `src/torch_einops_kit/nn.py` — lightweight `torch.nn.Module` adapters and a `Sequential` constructor that ignores `None` modules.
  - `src/torch_einops_kit/save_load.py` — checkpoint save / load decorator and nested reconstruction helpers.
  - `src/torch_einops_kit/scaleValues.py` — vector normalization, masked mean, and the `RMSNorm` layer.
  - `tests/` — regression tests and usage examples for helpers, masks, padding, device routing, and checkpoint reconstruction.
@@ -435,7 +456,6 @@ pytest
  Run static analysis:

  ```bash
- pyright
  ruff check .
  ```

@@ -1,6 +1,6 @@
  # torch_einops_kit

- Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`.
+ Typed tensor-shaping, masking, padding, device-routing, lightweight `nn.Module` adapters, and checkpoint utilities for PyTorch and `einops`.

  [![pip install torch-einops-kit](https://img.shields.io/badge/pip_install-torch--einops--kit-gray.svg?labelColor=blue)](https://pypi.org/project/torch-einops-kit/)
  [![uv add torch-einops-kit](https://img.shields.io/badge/uv_add-torch--einops--kit-gray.svg?labelColor=blue)](https://pypi.org/project/torch-einops-kit/)
@@ -9,7 +9,7 @@ This repository is a superset of [`lucidrains/torch-einops-utils`](https://githu

  `torch_einops_kit` is most useful when combined with other lucidrains repositories. Repositories such as [`dreamer4`](https://github.com/lucidrains/dreamer4), [`metacontroller`](https://github.com/lucidrains/metacontroller), [`mimic-video`](https://github.com/lucidrains/mimic-video), [`pi-zero-pytorch`](https://github.com/lucidrains/pi-zero-pytorch), [`sdft-pytorch`](https://github.com/lucidrains/sdft-pytorch), and [`locoformer`](https://github.com/lucidrains/locoformer) repeatedly need operations such as `align_dims_left`, `shape_with_replace`, `lens_to_mask`, `pad_sequence`, `safe_cat`, and `pack_with_inverse`. This package centralizes those operations in one typed import surface instead of re-implementing the same tensor utility layer in each model repository.

- If you already know `torch-einops-utils`, `torch_einops_kit` began as a typed substitute for that package and has since grown into a superset. In addition to everything from `torch-einops-utils`, this repository centralizes small utility functions that appear repeatedly in other lucidrains model repositories but were never collected in one place, such as `l2norm`, `once`, `pack_one`, and `unpack_one`. The function family remains the same kind: small PyTorch and `einops` helpers for shape work, masks, padding, optional tensors, PyTree traversal, device routing, and checkpoint reconstruction. The import path is `torch_einops_kit`, not `torch_einops_utils`. The relationship is conceptual, not literal import-path compatibility.
+ If you already know `torch-einops-utils`, `torch_einops_kit` began as a typed substitute for that package and has since grown into a superset. In addition to everything from `torch-einops-utils`, this repository centralizes small utility functions that appear repeatedly in other lucidrains model repositories but were never collected in one place, such as `l2norm`, `once`, `pack_one`, and `unpack_one`. The function family remains the same kind: small PyTorch and `einops` helpers for shape work, masks, padding, optional tensors, PyTree traversal, device routing, lightweight `nn.Module` adaptation, and checkpoint reconstruction. The import path is `torch_einops_kit`, not `torch_einops_utils`. The relationship is conceptual, not literal import-path compatibility.

  Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused modules, extensive tests, and docstrings written for both humans and AI assistants. Use upstream when you want the most compact possible version of the same idea.

@@ -20,7 +20,7 @@ Use `torch_einops_kit` when you want strict typing, a `py.typed` marker, focused
  - Python requirement: `>=3.10`.
  - Runtime dependencies: `torch`, `einops`, and `typing-extensions`.
  - Root package exports: helper functions, slicing helpers, rank-alignment helpers, mask helpers, safe concatenation helpers, padding helpers, normalization helpers, and PyTree / `einops` helpers.
- - Submodules with dedicated imports: `torch_einops_kit.device`, `torch_einops_kit.einops`, `torch_einops_kit.save_load`, and `torch_einops_kit.scaleValues`.
+ - Submodules with dedicated imports: `torch_einops_kit.device`, `torch_einops_kit.einops`, `torch_einops_kit.nn`, `torch_einops_kit.save_load`, and `torch_einops_kit.scaleValues`.
  - Typing status: the package ships a `py.typed` marker and the repository uses strict type checking.
  - Best fit: lucidrains-style model repositories that work with variable-length tensors, `einops` patterns, optional intermediate tensors, and nested `torch.nn.Module` graphs.

@@ -82,6 +82,16 @@ from torch_einops_kit.device import (
  )
  ```

+ Import lightweight `nn.Module` adapters from `torch_einops_kit.nn`:
+
+ ```python
+ from torch_einops_kit.nn import (
+     Identity,
+     Lambda,
+     Sequential,
+ )
+ ```
+
  Import checkpoint decorators from `torch_einops_kit.save_load`:

  ```python
@@ -92,7 +102,7 @@ from torch_einops_kit.save_load import (
  )
  ```

- Import checkpoint decorators from `torch_einops_kit.scaleValues`:
+ Import normalization and masked-reduction helpers from `torch_einops_kit.scaleValues`:

  ```python
  from torch_einops_kit.scaleValues import (
@@ -262,10 +272,10 @@ When `pad_lens=True` and `return_lens=True`, the second tensor contains padding

  ### PyTree helpers

- | Name | Contract |
- | --------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
- | `tree_map_tensor(fn, tree)` | Applies `fn` to every tensor leaf in a PyTree and leaves non-tensor leaves unchanged. |
- | `tree_flatten_with_inverse(tree)` | Returns a flat list of leaves and an inverse function that reconstructs the original PyTree shape from a replacement iterable of leaves. |
+ | Name | Contract |
+ | --------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
+ | `tree_map_tensor(fn, tree)` | Applies `fn` to every tensor leaf in a PyTree and leaves non-tensor leaves unchanged. |
+ | `tree_flatten_with_inverse(tree)` | Returns a flat list of leaves and an inverse function that reconstructs the original PyTree shape from a replacement iterable of leaves. |

  ## `scaleValues` submodule reference

@@ -298,6 +308,16 @@ The `torch_einops_kit.device` submodule contains three utilities for device infe
  | `move_inputs_to_device(device)` | Decorator that recursively moves every tensor inside positional and keyword arguments to `device` before calling the wrapped function. Non-tensor values pass through unchanged. |
  | `move_inputs_to_module_device(fn)` | Decorator for methods whose first argument is a `torch.nn.Module`. The decorator infers the target device with `module_device(self)` and moves every tensor argument after `self` to that device. If `module_device(self)` returns `None`, the call is a no-op. |

+ ## `nn` submodule reference
+
+ The `torch_einops_kit.nn` submodule contains lightweight `torch.nn.Module` adapters and a `Sequential` constructor that ignores `None` values.
+
+ | Name | Contract |
+ | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+ | `Sequential(*modules)` | Equivalent to `nn.Sequential(*compact(modules))`. Every `None` argument is discarded before construction. |
+ | `Identity()` | `torch.nn.Module` wrapper around `identity`. Unlike `torch.nn.Identity`, `Identity.forward` accepts additional positional arguments and keyword arguments and returns the first positional argument unchanged. |
+ | `Lambda(fn)` | `torch.nn.Module` wrapper that stores `fn` on `self.fn` and delegates `forward(*args, **kwargs)` to `fn(*args, **kwargs)` without changing the argument structure. |
+
  ## `save_load` submodule reference

  The `torch_einops_kit.save_load` submodule contains the checkpoint decorator and the two advanced configuration helpers that support nested decorated modules.
@@ -372,7 +392,7 @@ If you are an AI assistant adapting code from `torch-einops-utils`, use these tr
  This repository is not a repackaged mirror of upstream. This repository makes a different trade-off.

  - Upstream is intentionally compact.
- - This fork splits the implementation across focused modules such as `_helpers.py`, `_padding.py`, `device.py`, and `save_load.py` while still re-exporting most tensor helpers from the package root.
+ - This fork splits the implementation across focused modules such as `_helpers.py`, `_padding.py`, `device.py`, `nn.py`, and `save_load.py` while still re-exporting most tensor helpers from the package root.
  - This fork adds strict typing, a `py.typed` marker, extensive tests, and detailed docstrings.
  - This fork is best treated as a typed, documented branch of the same utility idea rather than a literal import-path-compatible drop-in replacement.

@@ -380,6 +400,7 @@ This repository is not a repackaged mirror of upstream. This repository makes a

  - `src/torch_einops_kit/` — package source.
  - `src/torch_einops_kit/device.py` — device inference and input-routing decorators.
+ - `src/torch_einops_kit/nn.py` — lightweight `torch.nn.Module` adapters and a `Sequential` constructor that ignores `None` modules.
  - `src/torch_einops_kit/save_load.py` — checkpoint save / load decorator and nested reconstruction helpers.
  - `src/torch_einops_kit/scaleValues.py` — vector normalization, masked mean, and the `RMSNorm` layer.
  - `tests/` — regression tests and usage examples for helpers, masks, padding, device routing, and checkpoint reconstruction.
@@ -401,7 +422,6 @@ pytest
  Run static analysis:

  ```bash
- pyright
  ruff check .
  ```

@@ -1,6 +1,6 @@
  [project]
  name = "torch-einops-kit"
- version = "0.1.2"
+ version = "0.1.4"
  description = "Typed tensor-shaping, masking, padding, device-routing, and checkpoint utilities for PyTorch and `einops`. A superset of `lucidrains/torch-einops-utils` with similar utilities from other lucidrains packages. Adds strict typing, extensive tests, and comprehensive docstrings."
  readme = "README.md"
  requires-python = ">=3.10"
@@ -33,12 +33,12 @@ classifiers = [
  ]
  dependencies = [
      "einops>=0.8.2",
-     "torch>=2.10.0",
+     "torch>=2.0.0",
      "typing-extensions>=4.0.0",
  ]

  [project.urls]
- Documentation = "https://context7.com/hunterhogan/torch_einops_kit"
+ Context7 = "https://context7.com/hunterhogan/torch_einops_kit"
  Donate = "https://www.patreon.com/integrated"
  Download = "https://pypi.org/project/torch_einops_kit"
  Homepage = "https://github.com/hunterhogan/torch_einops_kit"
@@ -1,8 +1,8 @@
- """Access PyTorch tensor-shaping, masking, padding, and checkpoint utilities.
+ """Access PyTorch tensor-shaping, masking, padding, `nn.Module` adaptation, and checkpoint utilities.

  You can use this package for optional-value handling, tensor slicing, rank alignment, mask
- construction, safe concatenation, sequence padding, PyTree traversal, and `einops`-pattern tensor
- packing.
+ construction, safe concatenation, sequence padding, PyTree traversal, lightweight `nn.Module`
+ adapters, and `einops` pattern tensor packing.

  Helpers
  -------
@@ -103,16 +103,21 @@ device
  Determine `torch.nn.Module` devices and decorate callables to move `Tensor` arguments automatically.
  einops
  Pack and unpack `Tensor` objects with `einops` patterns and paired inverse functions.
+ nn
+ Adapt callables and optional module sequences to the `torch.nn.Module` interface.
  save_load
  Decorate `torch.nn.Module` subclasses with checkpoint save, load, and reconstruction helpers.
  scaleValues
  Compute exclusive prefix sums, normalize feature vectors, and compute masked means.
  """
+
  # isort: split
+ from __future__ import annotations
+
  from torch_einops_kit._semiotics import decreasing as decreasing, zeroIndexed as zeroIndexed

  # isort: split
- from torch_einops_kit._types import (
+ from torch_einops_kit._theTypes import (
  ConfigArgsKwargs as ConfigArgsKwargs, DehydratedCheckpoint as DehydratedCheckpoint, DehydratedTorchNNModule as DehydratedTorchNNModule,
  DimAndValue as DimAndValue, IdentityCallable as IdentityCallable, PSpec as PSpec, RVar as RVar, StrPath as StrPath,
  SupportsIntIndex as SupportsIntIndex, T_co as T_co, TorchNNModule as TorchNNModule, TVar as TVar)
@@ -151,4 +156,8 @@ from torch_einops_kit.utils import tree_flatten_with_inverse as tree_flatten_wit
  # isort: split
  # NOTE These imports are for backwards compatibility. Linters ought to tell users to import from the correct submodules.
  from torch_einops_kit.einops import pack_with_inverse # pyright: ignore[reportUnusedImport]
+ from torch_einops_kit.nn import Identity, Lambda, Sequential # pyright: ignore[reportUnusedImport]
  from torch_einops_kit.scaleValues import l2norm, masked_mean # pyright: ignore[reportUnusedImport]
+
+ # NOTE `broadcat` is the identifier used in lucidrains packages.
+ broadcat = broadcast_cat
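The new `broadcat` alias keeps lucidrains-style call sites working against this package. A short usage sketch, assuming `broadcast_cat` broadcasts shapes before concatenating as its docstring below describes; the shapes are illustrative:

```python
import torch
from torch_einops_kit import broadcat

# (2, 3, 1) and (2, 1, 4) broadcast to (2, 3, 4) before concatenation,
# so concatenating along dim=-1 yields (2, 3, 8).
a = torch.randn(2, 3, 1)
b = torch.randn(2, 1, 4)
print(broadcat((a, b), dim=-1).shape)  # torch.Size([2, 3, 8])
```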
@@ -1,9 +1,11 @@
  from __future__ import annotations

- from collections.abc import Sequence
  from torch import broadcast_tensors, cat, stack, Tensor # pyright: ignore[reportUnknownVariableType]
  from torch_einops_kit import safe
- from typing import cast
+ from typing import cast, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from collections.abc import Sequence

  def broadcast_cat(tensors: Sequence[Tensor], dim: int = -1) -> Tensor:
  """Broadcast tensor groups before concatenation.
@@ -29,10 +31,10 @@ def broadcast_cat(tensors: Sequence[Tensor], dim: int = -1) -> Tensor:
  [1] PyTorch - Context7
  https://context7.com/pytorch/pytorch
  """
- return cat(cast(list[Tensor], broadcast_tensors(*tensors)), dim)
+ return cat(cast('list[Tensor]', broadcast_tensors(*tensors)), dim)

  @safe
- def safe_cat(tensors: Sequence[Tensor], dim: int = 0) -> Tensor | None:
+ def safe_cat(tensors: tuple[Tensor, ...] | list[Tensor], dim: int = 0) -> Tensor | None:
  """Concatenate tensors from `tensors` along an existing dimension, skipping `None` values.

  You can use `safe_cat` to concatenate a mixed sequence of `Tensor` and `None` values. The `safe`
@@ -45,7 +47,7 @@ def safe_cat(tensors: Sequence[Tensor], dim: int = 0) -> Tensor | None:

  Parameters
  ----------
- tensors : Sequence[Tensor | None]
+ tensors : tuple[Tensor | None, ...] | list[Tensor | None]
  A sequence of `Tensor` or `None` values. `None` values are filtered out before concatenation.
  All non-`None` `Tensor` values must have the same shape in every dimension except `dim`.
  dim : int = 0
@@ -79,9 +81,9 @@ def safe_cat(tensors: Sequence[Tensor], dim: int = 0) -> Tensor | None:
  From sdft_pytorch [4], accumulating per-step token losses across a generation loop where
  `token_kl_div_losses` is initialized to `None` before the loop:

- ```python
+ ```python
  token_kl_div_losses = safe_cat((token_kl_div_losses, token_kl_div), dim=1)
- ```
+ ```

  References
  ----------
@@ -94,10 +96,10 @@ def safe_cat(tensors: Sequence[Tensor], dim: int = 0) -> Tensor | None:
  [4] lucidrains/sdft-pytorch
  https://github.com/lucidrains/sdft-pytorch
  """
- return cat(tensors, dim = dim) # pyright: ignore[reportUnknownVariableType, reportCallIssue, reportArgumentType] https://github.com/pytorch/pytorch/issues/179391 # ty:ignore[no-matching-overload]
+ return None if len(tensors) == 0 else cat(tensors, dim=dim)

  @safe
- def safe_stack(tensors: Sequence[Tensor], dim: int = 0) -> Tensor | None:
+ def safe_stack(tensors: tuple[Tensor, ...] | list[Tensor], dim: int = 0) -> Tensor | None:
  """Stack tensors from `tensors` along a new dimension, skipping `None` values.

  You can use `safe_stack` to stack a mixed sequence of `Tensor` and `None` values. The `safe` [1]
@@ -106,8 +108,8 @@ def safe_stack(tensors: Sequence[Tensor], dim: int = 0) -> Tensor | None:

  Parameters
  ----------
- tensors : Sequence[Tensor | None]
- A `Sequence` of `Tensor` or `None` values. `None` values are filtered out before stacking.
+ tensors : tuple[Tensor | None, ...] | list[Tensor | None]
+ A sequence of `Tensor` or `None` values. `None` values are filtered out before stacking.
  All non-`None` `Tensor` values must have the same shape.
  dim : int = 0
  The dimension along which to stack. The result has one more dimension than each input
@@ -163,10 +165,10 @@ def safe_stack(tensors: Sequence[Tensor], dim: int = 0) -> Tensor | None:
  [4] lucidrains/dreamer4
  https://github.com/lucidrains/dreamer4
  """
- return stack(tensors, dim = dim) # pyright: ignore[reportArgumentType] https://github.com/pytorch/pytorch/issues/179391 # ty:ignore[invalid-argument-type]
+ return None if len(tensors) == 0 else stack(tensors, dim=dim)

  """
- Some or all of the logic in this module may be protected by the following.
+ Some of the logic in this module may be protected by the following.

  MIT License

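The 0.1.4 change above makes both helpers return `None` for an empty sequence instead of letting `torch.cat` or `torch.stack` raise. A small usage sketch of the accumulation pattern from the `safe_cat` docstring, with illustrative shapes:

```python
import torch
from torch_einops_kit import safe_cat

# Accumulator starts as None; the `safe` wrapper filters None entries,
# so the first iteration simply returns the new chunk.
losses = None
for _step in range(3):
    step_losses = torch.randn(2, 1)  # one column of per-step losses
    losses = safe_cat((losses, step_losses), dim=1)

print(losses.shape)  # torch.Size([2, 3])
print(safe_cat(()))  # None: an empty input no longer raises
```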
@@ -1,6 +1,11 @@
- from collections.abc import Sequence
- from torch import Tensor
+ from __future__ import annotations
+
  from torch_einops_kit import exists
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from collections.abc import Sequence
+     from torch import Tensor

  def pad_right_ndim(t: Tensor, ndims: int) -> Tensor:
  """Reshape a tensor by inserting singleton dimensions at the trailing side.
@@ -45,11 +50,7 @@ def pad_right_ndim(t: Tensor, ndims: int) -> Tensor:
  """
  return pad_ndim(t, (0, ndims))

- def align_dims_left(
- tensors: Sequence[Tensor],
- *,
- ndim: int | None = None,
- ) -> tuple[Tensor, ...]:
+ def align_dims_left(tensors: Sequence[Tensor], *, ndim: int | None = None) -> tuple[Tensor, ...]:
  """Pad all tensors in a sequence with trailing singleton dimensions to a common rank.

  You can use this function to align a heterogeneous sequence of tensors to the same number of
@@ -87,33 +88,33 @@ def align_dims_left(
  Align a PPO advantage tensor `(b, n)` with a log-probability ratio tensor `(b, n, d)` for
  element-wise multiplication:

- ```python
+ ```python
  from torch_einops_kit import align_dims_left

  # metacontroller: align ratio and advantages before the PPO surrogate loss
  ratio, advantages = align_dims_left((ratio, advantages))
  surr1 = ratio * advantages
- ```
+ ```

  Align a noise schedule `(b,)` with a latent tensor `(b, n, d)` for linear interpolation:

- ```python
+ ```python
  from torch_einops_kit import align_dims_left

  # dreamer4: align time with latents before noising
  aligned_times, _ = align_dims_left((times, latents))
  noised_latents = noise.lerp(latents, aligned_times)
- ```
+ ```

  Align a 1-D time value with an action tensor before flow-matching noise interpolation:

- ```python
+ ```python
  from torch_einops_kit import align_dims_left

  # mimic_video: align time with actions for noise interpolation
  actions, left_aligned_time = align_dims_left((actions, time))
  noised = noise.lerp(actions, left_aligned_time)
- ```
+ ```

  References
  ----------
@@ -179,7 +180,7 @@ def pad_ndim(t: Tensor, ndims: tuple[int, int]) -> Tensor:
  shape: tuple[int, ...] = t.shape
  left, right = ndims
  if left < 0 or right < 0:
- message: str = f"I received `{left = }` and `{right = }`, but I need both values to be greater than or equal to `0`."
+ message: str = f'I received `{left = }` and `{right = }`, but I need both values to be greater than or equal to `0`.'
  raise ValueError(message)

  ones: tuple[int] = (1,)
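The hunk above shows `pad_ndim` validating its `(left, right)` tuple; per the `pad_right_ndim` docstring earlier in this file, the tuple counts singleton dimensions to insert on each side. A usage sketch with illustrative shapes, assuming both helpers are re-exported from the package root like the other rank-alignment helpers:

```python
import torch
from torch_einops_kit import pad_ndim, pad_right_ndim

t = torch.randn(3, 4)

# Insert one singleton dimension on the left and two on the right.
print(pad_ndim(t, (1, 2)).shape)   # torch.Size([1, 3, 4, 1, 1])

# pad_right_ndim(t, n) is pad_ndim(t, (0, n)), per the hunk above.
print(pad_right_ndim(t, 2).shape)  # torch.Size([3, 4, 1, 1])

# A negative count on either side raises ValueError, per the validation above.
```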
@@ -311,22 +312,22 @@ def pad_right_ndim_to(t: Tensor, ndims: int) -> Tensor:
  Broadcast a scalar time value against a video tensor of shape `(b, c, t, h, w)`
  for flow interpolation:

- ```python
+ ```python
  from torch_einops_kit import pad_right_ndim_to

  # dreamer4: align time '(b,)' with video '(b, c, t, h, w)'
  padded_time = pad_right_ndim_to(time[None], video.ndim)
  pred_flow = (pred_video - video) / (1.0 - padded_time)
- ```
+ ```

  Scale a flow prediction using a denominator with lower rank than the prediction:

- ```python
+ ```python
  from torch_einops_kit import pad_right_ndim_to

  # mimic_video: convert model output to flow space
  pred_flow = (pred - actions) / pad_right_ndim_to(1.0 - action_time, pred.ndim).clamp_min(eps)
- ```
+ ```

  References
  ----------