torch-einops-utils 0.0.23__tar.gz → 0.0.25__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torch-einops-utils
-Version: 0.0.23
+Version: 0.0.25
 Summary: Personal utility functions
 Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
 Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
@@ -1,6 +1,6 @@
 [project]
 name = "torch-einops-utils"
-version = "0.0.23"
+version = "0.0.25"
 description = "Personal utility functions"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,17 @@
+import torch
+from torch import nn
+
+def test_module_device():
+    from torch_einops_utils.device import module_device
+
+    assert module_device(nn.Linear(3, 3)) == torch.device('cpu')
+    assert module_device(nn.Identity()) is None
+
+def test_move_input_to_device():
+    from torch_einops_utils.device import move_inputs_to_device
+
+    def fn(t):
+        return t
+
+    decorated = move_inputs_to_device(torch.device('cpu'))(fn)
+    decorated(torch.randn(3))
@@ -17,6 +17,7 @@ from torch_einops_utils.torch_einops_utils import (
     pad_left_at_dim_to,
     pad_right_at_dim_to,
     pad_sequence,
+    pad_sequence_and_cat,
     lens_to_mask,
     and_masks,
     or_masks,
@@ -138,8 +139,8 @@ def test_pad_sequence_uneven_images():
     assert len(padded_height) == 3
     assert all([t.shape[1] == 17 for t in padded_height])
 
-    stacked = pad_sequence(padded_height, dim = -1, return_stacked = True)
-    assert stacked.shape == (3, 3, 17, 18)
+    stacked = pad_sequence_and_cat(padded_height, dim_cat = 0)
+    assert stacked.shape == (9, 17, 18)
 
 def test_and_masks():
     assert not exists(and_masks([None]))
@@ -38,7 +38,8 @@ from torch_einops_utils.torch_einops_utils import (
     pad_right_at_dim,
     pad_left_at_dim_to,
     pad_right_at_dim_to,
-    pad_sequence
+    pad_sequence,
+    pad_sequence_and_cat
 )
 
 from torch_einops_utils.torch_einops_utils import (
@@ -0,0 +1,38 @@
+from itertools import chain
+from functools import wraps
+
+import torch
+from torch.nn import Module
+
+from torch_einops_utils.torch_einops_utils import tree_map_tensor
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+# infer the device for a module
+
+def module_device(m: Module):
+
+    first_param_or_buffer = next(chain(m.parameters(), m.buffers()), None)
+
+    if not exists(first_param_or_buffer):
+        return None
+
+    return first_param_or_buffer.device
+
+# moving all inputs into a function onto a device
+
+def move_inputs_to_device(device):
+
+    def decorator(fn):
+        @wraps(fn)
+        def inner(*args, **kwargs):
+            args, kwargs = tree_map_tensor(lambda t: t.to(device), (args, kwargs))
+
+            return fn(*args, **kwargs)
+
+        return inner
+
+    return decorator
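
A minimal usage sketch of the two new helpers, assuming the module layout in the hunk above (`net` and `run` are illustrative names, not part of the diff): `module_device` infers where a module's parameters or buffers live, and `move_inputs_to_device` can then route a function's tensor arguments to that device before the call.

import torch
from torch import nn

from torch_einops_utils.device import module_device, move_inputs_to_device

net = nn.Linear(4, 4)                       # parameters live on cpu by default

@move_inputs_to_device(module_device(net))  # tensor args are moved to net's device
def run(x):
    return net(x)

out = run(torch.randn(2, 4))
assert out.device == module_device(net)
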
@@ -271,6 +271,17 @@ def pad_sequence(
 
     return output, lens
 
+def pad_sequence_and_cat(
+    tensors,
+    *,
+    dim_cat = 0,
+    **kwargs
+):
+    assert 'return_stacked' not in kwargs
+
+    padded = pad_sequence(tensors, return_stacked = False, **kwargs)
+    return cat(padded, dim = dim_cat)
+
 # tree flatten with inverse
 
 def tree_map_tensor(fn, tree):
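
A minimal usage sketch of `pad_sequence_and_cat`, mirroring the updated test above: each tensor is padded to the longest size along the padded dimension (the last dimension here, assuming that is `pad_sequence`'s default), then the padded tensors are concatenated along `dim_cat` rather than stacked into a new leading dimension.

import torch
from torch_einops_utils.torch_einops_utils import pad_sequence_and_cat

# three image-like tensors that agree on all but the last dimension
images = [torch.randn(3, 17, 15), torch.randn(3, 17, 18), torch.randn(3, 17, 11)]

out = pad_sequence_and_cat(images, dim_cat = 0)  # pad last dim to 18, then cat on dim 0
assert out.shape == (9, 17, 18)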