torch-einops-utils 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_einops_utils/__init__.py +2 -1
- torch_einops_utils/device.py +18 -0
- torch_einops_utils/torch_einops_utils.py +11 -0
- {torch_einops_utils-0.0.24.dist-info → torch_einops_utils-0.0.25.dist-info}/METADATA +1 -1
- torch_einops_utils-0.0.25.dist-info/RECORD +8 -0
- torch_einops_utils-0.0.24.dist-info/RECORD +0 -8
- {torch_einops_utils-0.0.24.dist-info → torch_einops_utils-0.0.25.dist-info}/WHEEL +0 -0
- {torch_einops_utils-0.0.24.dist-info → torch_einops_utils-0.0.25.dist-info}/licenses/LICENSE +0 -0
torch_einops_utils/__init__.py
CHANGED
torch_einops_utils/device.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
from itertools import chain
|
|
2
|
+
from functools import wraps
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
from torch.nn import Module
|
|
5
6
|
|
|
7
|
+
from torch_einops_utils.torch_einops_utils import tree_map_tensor
|
|
8
|
+
|
|
6
9
|
# helpers
|
|
7
10
|
|
|
8
11
|
def exists(v):
|
|
@@ -18,3 +21,18 @@ def module_device(m: Module):
|
|
|
18
21
|
return None
|
|
19
22
|
|
|
20
23
|
return first_param_or_buffer.device
|
|
24
|
+
|
|
25
|
+
# moving all inputs into a function onto a device
|
|
26
|
+
|
|
27
|
+
def move_inputs_to_device(device):
|
|
28
|
+
|
|
29
|
+
def decorator(fn):
|
|
30
|
+
@wraps(fn)
|
|
31
|
+
def inner(*args, **kwargs):
|
|
32
|
+
args, kwargs = tree_map_tensor(lambda t: t.to(device), (args, kwargs))
|
|
33
|
+
|
|
34
|
+
return fn(*args, **kwargs)
|
|
35
|
+
|
|
36
|
+
return inner
|
|
37
|
+
|
|
38
|
+
return decorator
|
|
@@ -271,6 +271,17 @@ def pad_sequence(
|
|
|
271
271
|
|
|
272
272
|
return output, lens
|
|
273
273
|
|
|
274
|
+
def pad_sequence_and_cat(
|
|
275
|
+
tensors,
|
|
276
|
+
*,
|
|
277
|
+
dim_cat = 0,
|
|
278
|
+
**kwargs
|
|
279
|
+
):
|
|
280
|
+
assert 'return_stacked' not in kwargs
|
|
281
|
+
|
|
282
|
+
padded = pad_sequence(tensors, return_stacked = False, **kwargs)
|
|
283
|
+
return cat(padded, dim = dim_cat)
|
|
284
|
+
|
|
274
285
|
# tree flatten with inverse
|
|
275
286
|
|
|
276
287
|
def tree_map_tensor(fn, tree):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torch-einops-utils
|
|
3
|
-
Version: 0.0.24
|
|
3
|
+
Version: 0.0.25
|
|
4
4
|
Summary: Personal utility functions
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
torch_einops_utils/__init__.py,sha256=RBqs2yPtP51KVi73b-D75S7rIbQJpUOh1IoE3G9j-9E,1008
|
|
2
|
+
torch_einops_utils/device.py,sha256=r4hNXVluxMqk0awCtEMmzbBR5-uJ0KBxPIKGPzVsSwU,794
|
|
3
|
+
torch_einops_utils/save_load.py,sha256=K-i7nmLyXBHdAfBLN3rGQzI3NVf6RRwF_GcrKQnfQsc,2669
|
|
4
|
+
torch_einops_utils/torch_einops_utils.py,sha256=GG1QMs-bH1y3twJ9eu7brmO4fE8tMrmXN_SdfkLD2ZI,6728
|
|
5
|
+
torch_einops_utils-0.0.25.dist-info/METADATA,sha256=FqPrFqbyuWx18fr_TiAsVZj3p_b_NwN6YWaGXs4oIuo,2139
|
|
6
|
+
torch_einops_utils-0.0.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
+
torch_einops_utils-0.0.25.dist-info/licenses/LICENSE,sha256=e6AOF7Z8EFdK3IdcL0x0fLw4cY7Q0d0kNR0o0TmBewM,1066
|
|
8
|
+
torch_einops_utils-0.0.25.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
torch_einops_utils/__init__.py,sha256=l7S9xJOlbMfwofQD3RI2Z4QxT3H713XSPc1zN_kF4Ww,982
|
|
2
|
-
torch_einops_utils/device.py,sha256=2BFn_RqyNqsThlvjh_Un4bNSkYovUHlNzKF0scq6gaY,366
|
|
3
|
-
torch_einops_utils/save_load.py,sha256=K-i7nmLyXBHdAfBLN3rGQzI3NVf6RRwF_GcrKQnfQsc,2669
|
|
4
|
-
torch_einops_utils/torch_einops_utils.py,sha256=raKFiWpRgGE7Vf7tTqkokU9525gr54lJMpQy_jnUTVI,6498
|
|
5
|
-
torch_einops_utils-0.0.24.dist-info/METADATA,sha256=7WdwsdDrjmI1YuCpyGUkDN0wh1N4X3vyhPbF2mWdK6I,2139
|
|
6
|
-
torch_einops_utils-0.0.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
-
torch_einops_utils-0.0.24.dist-info/licenses/LICENSE,sha256=e6AOF7Z8EFdK3IdcL0x0fLw4cY7Q0d0kNR0o0TmBewM,1066
|
|
8
|
-
torch_einops_utils-0.0.24.dist-info/RECORD,,
|
|
File without changes
|
{torch_einops_utils-0.0.24.dist-info → torch_einops_utils-0.0.25.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|