hyper-connections 0.0.16.tar.gz → 0.0.18.tar.gz
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/PKG-INFO +1 -1
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/hyper_connections/hyper_connections.py +7 -2
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +38 -9
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/pyproject.toml +1 -1
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/.gitignore +0 -0
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/LICENSE +0 -0
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/README.md +0 -0
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/hyper-connections.png +0 -0
- {hyper_connections-0.0.16 → hyper_connections-0.0.18}/hyper_connections/__init__.py +0 -0
{hyper_connections-0.0.16 → hyper_connections-0.0.18}/hyper_connections/hyper_connections.py
RENAMED

```diff
@@ -86,7 +86,8 @@ class HyperConnections(Module):
         branch: Module | None = None,
         layer_index = None,
         tanh = True,
-        channel_first = False
+        channel_first = False,
+        dropout = 0.
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -116,6 +117,10 @@ class HyperConnections(Module):
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
+        # dropouts
+
+        self.dropout = nn.Dropout(dropout)
+
         # channel first option
 
         self.channel_first = channel_first
@@ -184,7 +189,7 @@ class HyperConnections(Module):
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
-        return output
+        return self.dropout(output)
 
     def decorate_branch(self, branch: Callable):
         assert not exists(self.branch), 'branch was already wrapped on init'
```
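The net effect is a new `dropout` argument applied to the mixed output before it is returned. A minimal usage sketch, assuming the constructor signature shown above; the top-level import and the streams-folded-into-batch tensor layout are assumptions drawn from the package's usual conventions, not confirmed by this diff:

```python
import torch
import torch.nn.functional as F

from hyper_connections import HyperConnections  # assumed re-export from __init__.py

# hypothetical setup: 4 residual streams, feature dimension 512
hyper_conn = HyperConnections(
    4,
    dim = 512,
    dropout = 0.1  # new in this release; defaults to 0., preserving old behavior
)

# residual streams folded into the batch dimension: (streams * batch, seq, dim) - assumed layout
residual = torch.randn(4 * 2, 16, 512)

# with no branch wrapped on init, forward returns the branch input plus a
# callback that mixes the branch output back into the residual streams
branch_input, add_residual = hyper_conn(residual)

branch_output = F.silu(branch_input)    # stand-in for a real branch module
residual = add_residual(branch_output)  # per the diff, dropout hits the mixed output
```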
{hyper_connections-0.0.16 → hyper_connections-0.0.18}/hyper_connections/hyper_connections_with_multi_branch_inputs.py
RENAMED

```diff
@@ -6,8 +6,8 @@ from random import randrange
 
 import torch
 from torch import nn
-from torch.nn import Module
 import torch.nn.functional as F
+from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat, reduce, einsum
@@ -22,6 +22,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def identity(t):
     return t
 
@@ -35,10 +38,11 @@ class HyperConnections(Module):
         num_residual_streams,
         *,
         dim,
-        branch: Module | None = None,
+        branch: Module | tuple[Module, ...] | list[Module] | None = None,
         layer_index = None,
         tanh = True,
         channel_first = False,
+        dropout = 0.,
         num_branch_inputs = 1 # residuals will be linearly combined to multiple inputs, fed through the branch, then linearly combined back out to residuals
     ):
         """
@@ -46,7 +50,15 @@ class HyperConnections(Module):
         """
         super().__init__()
 
-        self.branch = branch
+        self.branches = None
+
+        if isinstance(branch, Module):
+            branch = [branch]
+
+        if exists(branch):
+            assert divisible_by(num_branch_inputs, len(branch))
+
+            self.branches = ModuleList(branch)
 
         # activation, seemingly results were wishy washy depending on using tanh or not
 
@@ -78,6 +90,10 @@ class HyperConnections(Module):
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim, num_branch_inputs))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
+        # dropout
+
+        self.dropout = nn.Dropout(dropout)
+
         # channel first option
 
         self.channel_first = channel_first
@@ -153,15 +169,24 @@ class HyperConnections(Module):
         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')
 
-        return output
+        return self.dropout(output)
 
-    def decorate_branch(self, branch: Callable):
-        assert not exists(self.branch), 'branch was already wrapped on init'
+    def decorate_branch(self, branch: Callable | tuple[Callable, ...] | list[Callable]):
+        assert not exists(self.branches), 'branch was already wrapped on init'
 
         def forward_and_add_residual(residual, *args, **kwargs):
             branch_input, add_residual = self.forward(residual)
 
-            branch_output = self.branch(branch_input, *args, **kwargs)
+            if callable(branch):
+                branches = [branch]
+            else:
+                branches = branch
+
+            branch_inputs = rearrange(branch_input, '(br b) ... -> br b ...', br = len(branches))
+
+            branch_outputs = [fn(x, *args, **kwargs) for fn, x in zip(branches, branch_inputs)]
+
+            branch_output = torch.cat(branch_outputs)
 
             residual = add_residual(branch_output)
 
@@ -180,9 +205,13 @@ class HyperConnections(Module):
 
             return tree_unflatten((branch_out, *rest), tree_spec)
 
-        if not exists(self.branch):
+        if not exists(self.branches):
             return branch_input, add_residual_fn
 
-        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
+        branch_inputs = rearrange(branch_input, '(br b) ... -> br b ...', br = len(self.branches))
+
+        branch_outputs = [fn(x, *branch_args, **branch_kwargs) for fn, x in zip(self.branches, branch_inputs)]
+
+        branch_output = torch.cat(branch_outputs)
 
         return add_residual_fn(branch_output)
```
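Taken together, the multi-branch variant now accepts a tuple or list of branch modules (stored in a `ModuleList`), requires `num_branch_inputs` to be divisible by the branch count, and both `forward` and `decorate_branch` split the combined branch input across the branches along the batch dimension, concatenating the outputs before mixing them back into the residual streams. A minimal sketch under those semantics; the import path and tensor layout are assumptions, not confirmed by this diff:

```python
import torch
from torch import nn

# class name and file path per this diff; whether it is re-exported elsewhere is unknown
from hyper_connections.hyper_connections_with_multi_branch_inputs import HyperConnections

# hypothetical setup: two branch modules sharing two branch inputs
# (2 % 2 == 0, satisfying the new divisible_by(num_branch_inputs, len(branch)) assert)
hyper_conn = HyperConnections(
    4,                 # num_residual_streams
    dim = 512,
    branch = [nn.Linear(512, 512), nn.Linear(512, 512)],  # new: list of Modules accepted
    num_branch_inputs = 2,
    dropout = 0.1
)

# residual streams folded into the batch dimension: (streams * batch, seq, dim) - assumed layout
residual = torch.randn(4 * 2, 16, 512)

# with branches wrapped on init, forward runs each branch on its slice of the
# branch input, concatenates the outputs, and adds them back to the residuals
residual = hyper_conn(residual)
```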
{hyper_connections-0.0.16 → hyper_connections-0.0.18}/.github/workflows/python-publish.yml
File without changes

{hyper_connections-0.0.16 → hyper_connections-0.0.18}/.gitignore
File without changes

{hyper_connections-0.0.16 → hyper_connections-0.0.18}/LICENSE
File without changes

{hyper_connections-0.0.16 → hyper_connections-0.0.18}/README.md
File without changes

{hyper_connections-0.0.16 → hyper_connections-0.0.18}/hyper-connections.png
File without changes

{hyper_connections-0.0.16 → hyper_connections-0.0.18}/hyper_connections/__init__.py
File without changes