torch-einops-utils 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-einops-utils
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: Personal utility functions
5
5
  Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
6
6
  Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "torch-einops-utils"
3
- version = "0.1.2"
3
+ version = "0.1.4"
4
4
  description = "Personal utility functions"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,6 +1,6 @@
1
1
  import torch
2
2
  from torch import nn
3
- from torch_einops_utils.nn import Sequential, Lambda, Identity
3
+ from torch_einops_utils.nn import Sequential, Lambda, Identity, Residual
4
4
 
5
5
  def test_sequential():
6
6
  # Test that it filters out None
@@ -26,3 +26,18 @@ def test_identity():
26
26
  ident = Identity()
27
27
  x = torch.tensor([1., 2., 3.])
28
28
  assert torch.allclose(ident(x), x)
29
+
30
def test_residual():
    x = torch.tensor([1., 2., 3.])

    # single-tensor output: fn(x) = 2x, so the residual result is 2x + x = 3x
    def double(t):
        return t * 2

    residual = Residual(double)
    assert torch.allclose(residual(x), torch.tensor([3., 6., 9.]))

    # structured output: only the first leaf receives the residual,
    # the remaining leaves (including the one nested in the dict) are untouched
    def fn_tuple(t):
        return t * 2, t * 3, dict(a = t * 4)

    residual_tuple = Residual(fn_tuple)
    first, second, third = residual_tuple(x)

    expected_first = torch.tensor([3., 6., 9.])   # 2x + x
    expected_second = torch.tensor([3., 6., 9.])  # 3x, passed through
    expected_a = torch.tensor([4., 8., 12.])      # 4x, passed through

    assert torch.allclose(first, expected_first)
    assert torch.allclose(second, expected_second)
    assert torch.allclose(third['a'], expected_a)
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
  from typing import Callable
3
3
 
4
4
  from torch.nn import Module, Sequential as PyTorchSequential
5
- from torch_einops_utils.torch_einops_utils import exists
5
+ from torch_einops_utils.torch_einops_utils import exists, tree_flatten_with_inverse
6
6
 
7
7
  def Sequential(*modules):
8
8
  return PyTorchSequential(*filter(exists, modules))
@@ -18,3 +18,14 @@ class Lambda(Module):
18
18
 
19
19
  def forward(self, t, *args, **kwargs):
20
20
  return self.fn(t, *args, **kwargs)
21
+
22
class Residual(Module):
    """Wrap `fn` with a residual (skip) connection around its first output leaf.

    `fn` is called as `fn(x, *args, **kwargs)`. Its output may be a single
    tensor or a nested structure (e.g. a tuple containing a dict). The output
    is split into leaves by `tree_flatten_with_inverse`, `x` is added to the
    first leaf only, and the original structure is rebuilt with the inverse
    that helper returns. Leaf order is whatever `tree_flatten_with_inverse`
    produces. An output with no leaves raises `ValueError` on unpacking.
    """

    def __init__(self, fn: Callable):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        leaves, unflatten = tree_flatten_with_inverse(self.fn(x, *args, **kwargs))

        # the skip connection is applied to the first leaf; the rest pass through as-is
        head, *tail = leaves
        return unflatten((head + x, *tail))