hyper-connections 0.0.16__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,8 +6,8 @@ from random import randrange
6
6
 
7
7
  import torch
8
8
  from torch import nn
9
- from torch.nn import Module
10
9
  import torch.nn.functional as F
10
+ from torch.nn import Module, ModuleList
11
11
  from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
@@ -22,6 +22,9 @@ def exists(v):
22
22
  def default(v, d):
23
23
  return v if exists(v) else d
24
24
 
25
+ def divisible_by(num, den):
26
+ return (num % den) == 0
27
+
25
28
  def identity(t):
26
29
  return t
27
30
 
@@ -35,7 +38,7 @@ class HyperConnections(Module):
35
38
  num_residual_streams,
36
39
  *,
37
40
  dim,
38
- branch: Module | None = None,
41
+ branch: Module | tuple[Module, ...] | list[Module] | None = None,
39
42
  layer_index = None,
40
43
  tanh = True,
41
44
  channel_first = False,
@@ -46,7 +49,15 @@ class HyperConnections(Module):
46
49
  """
47
50
  super().__init__()
48
51
 
49
- self.branch = branch
52
+ self.branches = None
53
+
54
+ if isinstance(branch, Module):
55
+ branch = [branch]
56
+
57
+ if exists(branch):
58
+ assert divisible_by(num_branch_inputs, len(branch))
59
+
60
+ self.branches = ModuleList(branch)
50
61
 
51
62
  # activation, seemingly results were wishy washy depending on using tanh or not
52
63
 
@@ -155,13 +166,22 @@ class HyperConnections(Module):
155
166
 
156
167
  return output
157
168
 
158
- def decorate_branch(self, branch: Callable):
159
- assert not exists(self.branch), 'branch was already wrapped on init'
169
+ def decorate_branch(self, branch: Callable | tuple[Callable, ...] | list[Callable]):
170
+ assert not exists(self.branches), 'branch was already wrapped on init'
160
171
 
161
172
  def forward_and_add_residual(residual, *args, **kwargs):
162
173
  branch_input, add_residual = self.forward(residual)
163
174
 
164
- branch_output = branch(branch_input, *args, **kwargs)
175
+ if callable(branch):
176
+ branches = [branch]
177
+ else:
178
+ branches = branch
179
+
180
+ branch_inputs = rearrange(branch_input, '(br b) ... -> br b ...', br = len(branches))
181
+
182
+ branch_outputs = [fn(x, *args, **kwargs) for fn, x in zip(branches, branch_inputs)]
183
+
184
+ branch_output = torch.cat(branch_outputs)
165
185
 
166
186
  residual = add_residual(branch_output)
167
187
 
@@ -180,9 +200,13 @@ class HyperConnections(Module):
180
200
 
181
201
  return tree_unflatten((branch_out, *rest), tree_spec)
182
202
 
183
- if not exists(self.branch):
203
+ if not exists(self.branches):
184
204
  return branch_input, add_residual_fn
185
205
 
186
- branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
206
+ branch_inputs = rearrange(branch_input, '(br b) ... -> br b ...', br = len(self.branches))
207
+
208
+ branch_outputs = [fn(x, *branch_args, **branch_kwargs) for fn, x in zip(self.branches, branch_inputs)]
209
+
210
+ branch_output = torch.cat(branch_outputs)
187
211
 
188
212
  return add_residual_fn(branch_output)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.16
3
+ Version: 0.0.17
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,7 @@
1
+ hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
+ hyper_connections/hyper_connections.py,sha256=LoyiGGDH81MLUcF9WKT39lA9Y1CXUAd-SYM9tJGQ61A,8099
3
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=l8oKkeVouFWQZY_0IVNTtgQPSQTA9qZdgLrfotRS_5w,7297
4
+ hyper_connections-0.0.17.dist-info/METADATA,sha256=dB2ZvPv63hTNJB6FEbyLlqxer4L7epwFMOvEVNjNOGk,5076
5
+ hyper_connections-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ hyper_connections-0.0.17.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
7
+ hyper_connections-0.0.17.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
- hyper_connections/hyper_connections.py,sha256=LoyiGGDH81MLUcF9WKT39lA9Y1CXUAd-SYM9tJGQ61A,8099
3
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=cdYND9EdrZp1Kircrdaxd0-d8KV6VWPQeBce0k3q7Vo,6451
4
- hyper_connections-0.0.16.dist-info/METADATA,sha256=4n5JPHvfqf3IMsDCkw4h3uTGm4WRTnWVz3sMRwQFHTI,5076
5
- hyper_connections-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- hyper_connections-0.0.16.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
7
- hyper_connections-0.0.16.dist-info/RECORD,,