hyper-connections 0.0.16.tar.gz → 0.0.18.tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyper-connections
-Version: 0.0.16
+Version: 0.0.18
 Summary: Hyper-Connections
 Project-URL: Homepage, https://pypi.org/project/hyper-connections/
 Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -86,7 +86,8 @@ class HyperConnections(Module):
         branch: Module | None = None,
         layer_index = None,
         tanh = True,
-        channel_first = False
+        channel_first = False,
+        dropout = 0.
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -116,6 +117,10 @@ class HyperConnections(Module):
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
+        # dropouts
+
+        self.dropout = nn.Dropout(dropout)
+
         # channel first option
 
         self.channel_first = channel_first
@@ -184,7 +189,7 @@ class HyperConnections(Module):
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
-        return output
+        return self.dropout(output)
 
     def decorate_branch(self, branch: Callable):
         assert not exists(self.branch), 'branch was already wrapped on init'
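The four hunks above cover the package metadata and the base module, where the visible 0.0.18 change is a single new knob: a dropout probability, with the resulting nn.Dropout applied to the output in forward. A minimal usage sketch follows; the import path and the by-hand stream expand/reduce are assumptions, shown only to make the sketch self-contained:

import torch
from torch import nn
from einops import repeat, reduce

from hyper_connections import HyperConnections  # assumed import path

num_streams, dim = 4, 512

hyper_conn = HyperConnections(
    num_streams,
    dim = dim,
    branch = nn.Linear(dim, dim),
    dropout = 0.1  # new in 0.0.18 - nn.Dropout applied to the output
)

x = torch.randn(2, 1024, dim)

# residual streams ride along the batch dimension (by-hand expand/reduce,
# an assumption for illustration)
x = repeat(x, 'b ... -> (s b) ...', s = num_streams)

x = hyper_conn(x)  # runs the branch and mixes residual streams; dropout on the output

x = reduce(x, '(s b) ... -> b ...', 'sum', s = num_streams)

The remaining hunks touch a second source file (the diff does not show file names), the variant that supports num_branch_inputs.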
@@ -6,8 +6,8 @@ from random import randrange
 
 import torch
 from torch import nn
-from torch.nn import Module
 import torch.nn.functional as F
+from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
@@ -22,6 +22,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def identity(t):
     return t
 
@@ -35,10 +38,11 @@ class HyperConnections(Module):
         num_residual_streams,
         *,
         dim,
-        branch: Module | None = None,
+        branch: Module | tuple[Module, ...] | list[Module] | None = None,
         layer_index = None,
         tanh = True,
         channel_first = False,
+        dropout = 0.,
         num_branch_inputs = 1 # residuals will be linearly combined to multiple inputs, fed through the branch, then linearly combined back out to residuals
     ):
         """
@@ -46,7 +50,15 @@
         """
         super().__init__()
 
-        self.branch = branch
+        self.branches = None
+
+        if isinstance(branch, Module):
+            branch = [branch]
+
+        if exists(branch):
+            assert divisible_by(num_branch_inputs, len(branch))
+
+            self.branches = ModuleList(branch)
 
         # activation, seemingly results were wishy washy depending on using tanh or not
 
@@ -78,6 +90,10 @@ class HyperConnections(Module):
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim, num_branch_inputs))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
+        # dropout
+
+        self.dropout = nn.Dropout(dropout)
+
         # channel first option
 
         self.channel_first = channel_first
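In the constructor above, branch now accepts either a single Module or a sequence of them: a lone module is wrapped into a one-element list, and num_branch_inputs must divide evenly by the branch count. A sketch of that contract, noting the import path is a guess since the diff does not name the file defining this variant:

from torch import nn
from hyper_connections import HyperConnections  # illustrative import path

dim = 512

# a single branch works as before, wrapped into a ModuleList internally
hc = HyperConnections(4, dim = dim, branch = nn.Linear(dim, dim))

# two branches over four branch inputs - divisible_by(4, 2) holds
hc = HyperConnections(
    4,
    dim = dim,
    branch = [nn.Linear(dim, dim), nn.Linear(dim, dim)],
    num_branch_inputs = 4
)

# three branches over four inputs would trip the assert, since 4 % 3 != 0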
@@ -153,15 +169,24 @@ class HyperConnections(Module):
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
-        return output
+        return self.dropout(output)
 
-    def decorate_branch(self, branch: Callable):
-        assert not exists(self.branch), 'branch was already wrapped on init'
+    def decorate_branch(self, branch: Callable | tuple[Callable, ...] | list[Callable]):
+        assert not exists(self.branches), 'branch was already wrapped on init'
 
         def forward_and_add_residual(residual, *args, **kwargs):
             branch_input, add_residual = self.forward(residual)
 
-            branch_output = branch(branch_input, *args, **kwargs)
+            if callable(branch):
+                branches = [branch]
+            else:
+                branches = branch
+
+            branch_inputs = rearrange(branch_input, '(br b) ... -> br b ...', br = len(branches))
+
+            branch_outputs = [fn(x, *args, **kwargs) for fn, x in zip(branches, branch_inputs)]
+
+            branch_output = torch.cat(branch_outputs)
 
             residual = add_residual(branch_output)
 
@@ -180,9 +205,13 @@
 
             return tree_unflatten((branch_out, *rest), tree_spec)
 
-        if not exists(self.branch):
+        if not exists(self.branches):
             return branch_input, add_residual_fn
 
-        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
+        branch_inputs = rearrange(branch_input, '(br b) ... -> br b ...', br = len(self.branches))
+
+        branch_outputs = [fn(x, *branch_args, **branch_kwargs) for fn, x in zip(self.branches, branch_inputs)]
+
+        branch_output = torch.cat(branch_outputs)
 
         return add_residual_fn(branch_output)
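Both forward and decorate_branch now share the same split-apply-concat pattern: branch inputs are stacked along the batch dimension, split evenly across the branches, run through each one, and concatenated back in order. A self-contained sketch of just that mechanic, with made-up shapes:

import torch
from torch import nn
from einops import rearrange

branches = nn.ModuleList([nn.Linear(512, 512), nn.Linear(512, 512)])

# (num_branches * batch, seq, dim) - the linearly combined branch inputs
branch_input = torch.randn(2 * 3, 1024, 512)

# split the leading dim into one chunk per branch, as in the hunks above
branch_inputs = rearrange(branch_input, '(br b) ... -> br b ...', br = len(branches))

# each branch sees only its own chunk
branch_outputs = [fn(x) for fn, x in zip(branches, branch_inputs)]

# concatenate along batch, restoring (num_branches * batch, seq, dim)
branch_output = torch.cat(branch_outputs)

assert branch_output.shape == branch_input.shape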
@@ -1,6 +1,6 @@
 [project]
 name = "hyper-connections"
-version = "0.0.16"
+version = "0.0.18"
 description = "Hyper-Connections"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }