torch-einops-utils 0.0.24__tar.gz → 0.0.25__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/PKG-INFO +1 -1
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/pyproject.toml +1 -1
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/tests/test_device.py +9 -0
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/tests/test_utils.py +3 -2
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/torch_einops_utils/__init__.py +2 -1
- torch_einops_utils-0.0.25/torch_einops_utils/device.py +38 -0
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/torch_einops_utils/torch_einops_utils.py +11 -0
- torch_einops_utils-0.0.24/torch_einops_utils/device.py +0 -20
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/.github/workflows/python-publish.yml +0 -0
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/.github/workflows/test.yml +0 -0
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/.gitignore +0 -0
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/LICENSE +0 -0
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/README.md +0 -0
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/tests/test_save_load.py +0 -0
- {torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/torch_einops_utils/save_load.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torch-einops-utils
|
|
3
|
-
Version: 0.0.24
|
|
3
|
+
Version: 0.0.25
|
|
4
4
|
Summary: Personal utility functions
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
|
|
@@ -6,3 +6,12 @@ def test_module_device():
|
|
|
6
6
|
|
|
7
7
|
assert module_device(nn.Linear(3, 3)) == torch.device('cpu')
|
|
8
8
|
assert module_device(nn.Identity()) is None
|
|
9
|
+
|
|
10
|
+
def test_move_input_to_device():
    """Check that `move_inputs_to_device` hands the wrapped function its inputs on the target device."""
    from torch_einops_utils.device import move_inputs_to_device

    def fn(t, *, other = None):
        return t, other

    # cpu -> cpu: tensors must pass through with their values untouched
    decorated = move_inputs_to_device(torch.device('cpu'))(fn)

    x, y = torch.randn(3), torch.randn(2)
    out, other = decorated(x, other = y)

    assert decorated.__name__ == 'fn'  # functools.wraps keeps the wrapped function's metadata
    assert torch.equal(out, x)
    assert torch.equal(other, y)

    # the 'meta' device is available on every build, so this verifies that
    # positional and keyword inputs are actually moved even without an accelerator
    to_meta = move_inputs_to_device(torch.device('meta'))(fn)
    out, other = to_meta(x, other = y)

    assert out.device.type == 'meta'
    assert other.device.type == 'meta'
|
|
@@ -17,6 +17,7 @@ from torch_einops_utils.torch_einops_utils import (
|
|
|
17
17
|
pad_left_at_dim_to,
|
|
18
18
|
pad_right_at_dim_to,
|
|
19
19
|
pad_sequence,
|
|
20
|
+
pad_sequence_and_cat,
|
|
20
21
|
lens_to_mask,
|
|
21
22
|
and_masks,
|
|
22
23
|
or_masks,
|
|
@@ -138,8 +139,8 @@ def test_pad_sequence_uneven_images():
|
|
|
138
139
|
assert len(padded_height) == 3
|
|
139
140
|
assert all([t.shape[1] == 17 for t in padded_height])
|
|
140
141
|
|
|
141
|
-
stacked =
|
|
142
|
-
assert stacked.shape == (
|
|
142
|
+
stacked = pad_sequence_and_cat(padded_height, dim_cat = 0)
|
|
143
|
+
assert stacked.shape == (9, 17, 18)
|
|
143
144
|
|
|
144
145
|
def test_and_masks():
|
|
145
146
|
assert not exists(and_masks([None]))
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from itertools import chain
|
|
2
|
+
from functools import wraps
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch.nn import Module
|
|
6
|
+
|
|
7
|
+
from torch_einops_utils.torch_einops_utils import tree_map_tensor
|
|
8
|
+
|
|
9
|
+
# helpers
|
|
10
|
+
|
|
11
|
+
def exists(v):
    """Return True unless `v` is None (falsy values such as 0 or '' still exist)."""
    if v is None:
        return False
    return True
|
|
13
|
+
|
|
14
|
+
# infer the device for a module
|
|
15
|
+
|
|
16
|
+
def module_device(m: Module):
    """Return the device of the first parameter of `m`, falling back to its first buffer.

    Parameters are inspected before buffers. Returns None when the module
    owns neither parameters nor buffers (e.g. `nn.Identity()`).
    """
    for tensor in chain(m.parameters(), m.buffers()):
        # only the first tensor found decides the reported device
        return tensor.device

    return None
|
|
24
|
+
|
|
25
|
+
# decorator factory: move the inputs of a function onto a device before calling it
|
|
26
|
+
|
|
27
|
+
def move_inputs_to_device(device):
    """Build a decorator that moves a function's inputs onto `device` before each call.

    `tree_map_tensor` applies `t.to(device)` across the nested
    `(args, kwargs)` structure (presumably to tensor leaves only; see its
    definition). The wrapped function's return value is passed back unchanged.
    """

    def to_device(t):
        return t.to(device)

    def decorator(fn):

        @wraps(fn)
        def inner(*args, **kwargs):
            moved_args, moved_kwargs = tree_map_tensor(to_device, (args, kwargs))
            return fn(*moved_args, **moved_kwargs)

        return inner

    return decorator
|
{torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/torch_einops_utils/torch_einops_utils.py
RENAMED
|
@@ -271,6 +271,17 @@ def pad_sequence(
|
|
|
271
271
|
|
|
272
272
|
return output, lens
|
|
273
273
|
|
|
274
|
+
def pad_sequence_and_cat(
    tensors,
    *,
    dim_cat = 0,
    **kwargs
):
    """Pad `tensors` with `pad_sequence`, then concatenate the padded tensors along `dim_cat`.

    Args:
        tensors: passed straight through to `pad_sequence`.
        dim_cat: dimension along which the padded tensors are concatenated.
        **kwargs: forwarded to `pad_sequence`. Must not contain `return_stacked`,
            because this function always sets it to False.

    Raises:
        TypeError: if `return_stacked` is supplied in `kwargs`.
    """
    # explicit raise instead of `assert` so the check survives `python -O`;
    # TypeError mirrors what Python raises for a duplicated keyword argument
    if 'return_stacked' in kwargs:
        raise TypeError('pad_sequence_and_cat sets return_stacked itself and does not accept it as a keyword argument')

    # NOTE(review): assumes pad_sequence returns a sequence of tensors when
    # return_stacked = False, and that no forwarded kwarg makes it return a
    # tuple such as (output, lens) instead; confirm against pad_sequence
    padded = pad_sequence(tensors, return_stacked = False, **kwargs)
    return cat(padded, dim = dim_cat)
|
|
284
|
+
|
|
274
285
|
# tree flatten with inverse
|
|
275
286
|
|
|
276
287
|
def tree_map_tensor(fn, tree):
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
from itertools import chain
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch.nn import Module
|
|
5
|
-
|
|
6
|
-
# helpers
|
|
7
|
-
|
|
8
|
-
def exists(v):
    """Return True unless `v` is None (falsy values such as 0 or '' still exist)."""
    if v is None:
        return False
    return True
|
|
10
|
-
|
|
11
|
-
# infer the device for a module
|
|
12
|
-
|
|
13
|
-
def module_device(m: Module):
    """Return the device of the first parameter of `m`, falling back to its first buffer.

    Parameters are inspected before buffers. Returns None when the module
    owns neither parameters nor buffers.
    """
    for tensor in chain(m.parameters(), m.buffers()):
        # only the first tensor found decides the reported device
        return tensor.device

    return None
|
{torch_einops_utils-0.0.24 → torch_einops_utils-0.0.25}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|