torch-einops-utils 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -38,7 +38,8 @@ from torch_einops_utils.torch_einops_utils import (
38
38
  pad_right_at_dim,
39
39
  pad_left_at_dim_to,
40
40
  pad_right_at_dim_to,
41
- pad_sequence
41
+ pad_sequence,
42
+ pad_sequence_and_cat
42
43
  )
43
44
 
44
45
  from torch_einops_utils.torch_einops_utils import (
@@ -1,8 +1,11 @@
1
1
  from itertools import chain
2
+ from functools import wraps
2
3
 
3
4
  import torch
4
5
  from torch.nn import Module
5
6
 
7
+ from torch_einops_utils.torch_einops_utils import tree_map_tensor
8
+
6
9
  # helpers
7
10
 
8
11
  def exists(v):
@@ -18,3 +21,18 @@ def module_device(m: Module):
18
21
  return None
19
22
 
20
23
  return first_param_or_buffer.device
24
+
25
+ # moving all inputs into a function onto a device
26
+
27
def move_inputs_to_device(device):
    """Return a decorator that moves a function's tensor inputs onto `device`.

    Before each call, the positional and keyword arguments are traversed
    together with `tree_map_tensor`. Every element that helper visits is
    replaced with `element.to(device)`. The wrapped function is then invoked
    with the converted arguments, and its metadata is preserved via
    `functools.wraps`.

    NOTE(review): which leaves count as tensors (and how nested containers
    are walked) is decided by `tree_map_tensor`, whose body is not shown
    here. Presumably only `torch.Tensor` leaves are converted; confirm.
    """

    def to_device(tensor):
        return tensor.to(device)

    def decorator(fn):

        @wraps(fn)
        def wrapped(*args, **kwargs):
            moved_args, moved_kwargs = tree_map_tensor(to_device, (args, kwargs))
            return fn(*moved_args, **moved_kwargs)

        return wrapped

    return decorator
@@ -271,6 +271,17 @@ def pad_sequence(
271
271
 
272
272
  return output, lens
273
273
 
274
def pad_sequence_and_cat(
    tensors,
    *,
    dim_cat = 0,
    **kwargs
):
    """Pad `tensors` with `pad_sequence` and concatenate the result along `dim_cat`.

    Padding is delegated to `pad_sequence` with `return_stacked = False`.
    The value it returns is passed directly to `cat(..., dim = dim_cat)`.
    All other keyword arguments are forwarded to `pad_sequence` unchanged.
    Callers may not pass `return_stacked`, because this function fixes it.

    NOTE(review): if any forwarded kwarg makes `pad_sequence` return a
    tuple such as `(output, lens)`, that tuple would reach `cat` as-is.
    Confirm that callers never enable such an option here.
    """

    # stacking is controlled by this function, so an override is rejected
    assert 'return_stacked' not in kwargs

    return cat(
        pad_sequence(tensors, return_stacked = False, **kwargs),
        dim = dim_cat
    )
284
+
274
285
  # tree flatten with inverse
275
286
 
276
287
  def tree_map_tensor(fn, tree):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-einops-utils
3
- Version: 0.0.24
3
+ Version: 0.0.25
4
4
  Summary: Personal utility functions
5
5
  Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
6
6
  Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
@@ -0,0 +1,8 @@
1
+ torch_einops_utils/__init__.py,sha256=RBqs2yPtP51KVi73b-D75S7rIbQJpUOh1IoE3G9j-9E,1008
2
+ torch_einops_utils/device.py,sha256=r4hNXVluxMqk0awCtEMmzbBR5-uJ0KBxPIKGPzVsSwU,794
3
+ torch_einops_utils/save_load.py,sha256=K-i7nmLyXBHdAfBLN3rGQzI3NVf6RRwF_GcrKQnfQsc,2669
4
+ torch_einops_utils/torch_einops_utils.py,sha256=GG1QMs-bH1y3twJ9eu7brmO4fE8tMrmXN_SdfkLD2ZI,6728
5
+ torch_einops_utils-0.0.25.dist-info/METADATA,sha256=FqPrFqbyuWx18fr_TiAsVZj3p_b_NwN6YWaGXs4oIuo,2139
6
+ torch_einops_utils-0.0.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ torch_einops_utils-0.0.25.dist-info/licenses/LICENSE,sha256=e6AOF7Z8EFdK3IdcL0x0fLw4cY7Q0d0kNR0o0TmBewM,1066
8
+ torch_einops_utils-0.0.25.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- torch_einops_utils/__init__.py,sha256=l7S9xJOlbMfwofQD3RI2Z4QxT3H713XSPc1zN_kF4Ww,982
2
- torch_einops_utils/device.py,sha256=2BFn_RqyNqsThlvjh_Un4bNSkYovUHlNzKF0scq6gaY,366
3
- torch_einops_utils/save_load.py,sha256=K-i7nmLyXBHdAfBLN3rGQzI3NVf6RRwF_GcrKQnfQsc,2669
4
- torch_einops_utils/torch_einops_utils.py,sha256=raKFiWpRgGE7Vf7tTqkokU9525gr54lJMpQy_jnUTVI,6498
5
- torch_einops_utils-0.0.24.dist-info/METADATA,sha256=7WdwsdDrjmI1YuCpyGUkDN0wh1N4X3vyhPbF2mWdK6I,2139
6
- torch_einops_utils-0.0.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- torch_einops_utils-0.0.24.dist-info/licenses/LICENSE,sha256=e6AOF7Z8EFdK3IdcL0x0fLw4cY7Q0d0kNR0o0TmBewM,1066
8
- torch_einops_utils-0.0.24.dist-info/RECORD,,