hyper-connections 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +23 -6
- hyper_connections/hyper_connections_channel_first.py +207 -0
- {hyper_connections-0.1.7.dist-info → hyper_connections-0.1.9.dist-info}/METADATA +1 -1
- hyper_connections-0.1.9.dist-info/RECORD +9 -0
- hyper_connections-0.1.7.dist-info/RECORD +0 -8
- {hyper_connections-0.1.7.dist-info → hyper_connections-0.1.9.dist-info}/WHEEL +0 -0
- {hyper_connections-0.1.7.dist-info → hyper_connections-0.1.9.dist-info}/licenses/LICENSE +0 -0
hyper_connections/hyper_connections.py

@@ -132,6 +132,7 @@ class HyperConnections(Module):
         channel_first = False,
         dropout = 0.,
         residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
+        add_branch_out_to_residual = True # will disable depth connections (weighted residual sum with beta) if set False
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -151,7 +152,7 @@ class HyperConnections(Module):
         self.num_residual_streams = num_residual_streams
         init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given

-        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+        # width connection

         init_alpha0 = torch.zeros((num_residual_streams, 1))
         init_alpha0[init_residual_index, 0] = 1.
@@ -160,8 +161,15 @@ class HyperConnections(Module):

         self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
-        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
-        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+        # depth connection related (beta)
+
+        self.add_branch_out_to_residual = add_branch_out_to_residual
+
+        if add_branch_out_to_residual:
+            self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+            self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+            self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)

         # dropouts

@@ -196,9 +204,12 @@ class HyperConnections(Module):

         # beta for weights from branch output back to residual streams

-        dc_weight = self.act(normed @ self.dynamic_beta_fn)
-        dynamic_beta = dc_weight * self.dynamic_beta_scale
-        beta = dynamic_beta + self.static_beta
+        beta = None
+
+        if self.add_branch_out_to_residual:
+            dc_weight = self.act(normed @ self.dynamic_beta_fn)
+            dynamic_beta = dc_weight * self.dynamic_beta_scale
+            beta = dynamic_beta + self.static_beta

         mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')

@@ -210,6 +221,8 @@ class HyperConnections(Module):
         return branch_input, maybe_transformed_residuals, dict(beta = beta)

     def depth_connection(self, branch_output, residuals, *, beta):
+        assert self.add_branch_out_to_residual
+
         # 'depth' connection

         if self.channel_first:
@@ -244,6 +257,10 @@ class HyperConnections(Module):
         branch_input, residuals, residual_kwargs = self.width_connection(residuals)

         def add_residual_fn(branch_out):
+
+            if not self.add_branch_out_to_residual:
+                return branch_out
+
             (branch_out, *rest), tree_spec = tree_flatten(branch_out)

             branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
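The new add_branch_out_to_residual flag effectively reduces the module to its width connection: the branch still receives a learned mixture of the residual streams, but its output is handed back unchanged instead of being summed into the streams via beta. Below is a minimal usage sketch; the import path and constructor arguments come from this diff, while the linear branch, the tensor shapes, and the convention of flattening the residual streams into the batch dimension are illustrative assumptions rather than something this diff specifies.

import torch
from torch import nn

from hyper_connections.hyper_connections import HyperConnections

# hypothetical branch module, purely for illustration
branch = nn.Linear(512, 512)

hc = HyperConnections(
    4,                                    # num_residual_streams
    dim = 512,
    branch = branch,
    add_branch_out_to_residual = False    # skip the depth connection (beta) entirely
)

# residual streams assumed flattened into the batch: (batch * streams, seq, dim)
residuals = torch.randn(2 * 4, 16, 512)

out = hc(residuals)  # just the branch output, since nothing is added back to the streams

With the flag left at its default of True, the same call would return the updated residual streams as before.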
hyper_connections/hyper_connections_channel_first.py

@@ -0,0 +1,207 @@
+from __future__ import annotations
+from typing import Callable
+
+from functools import partial
+from random import randrange
+
+import torch
+from torch import nn
+from torch.nn import Module
+import torch.nn.functional as F
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+from einops import rearrange, repeat, reduce, einsum
+from einops.layers.torch import Reduce, Rearrange
+
+from hyper_connections.hyper_connections import (
+    Residual,
+    RMSNorm
+)
+
+"""
+ein notation:
+b - batch
+d - feature dimension
+s - residual streams
+t - residual streams + num branch inputs
+"""
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def identity(t):
+    return t
+
+# main functions
+
+def get_expand_reduce_stream_functions(num_streams, disable = False):
+
+    if num_streams == 1 or disable:
+        return (nn.Identity(), nn.Identity())
+
+    expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
+    reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+    return expand_fn, reduce_fn
+
+def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
+
+    hyper_conn_klass = HyperConnections if not disable else Residual
+
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
+
+    return (init_hyper_conn_fn, *expand_reduce_fns)
+
+# norms
+
+class RMSNorm(Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.zeros(dim, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * (self.gamma + 1)
+
+# hyper connection residual streams
+
+class HyperConnections(Module):
+    def __init__(
+        self,
+        num_residual_streams,
+        *,
+        dim,
+        branch: Module | None = None,
+        layer_index = None,
+        tanh = True,
+        channel_first = True,
+        dropout = 0.,
+        residual_transform: Module | None = None, # to support resnet blocks where dimension in not equal to dimension out - usually a residual conv
+    ):
+        """
+        Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
+        """
+        super().__init__()
+
+        self.branch = branch
+
+        # activation, seemingly results were wishy washy depending on using tanh or not
+
+        self.act = nn.Tanh() if tanh else nn.Identity()
+
+        self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
+
+        assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
+
+        self.num_residual_streams = num_residual_streams
+        init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
+
+        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+        init_alpha0 = torch.zeros((num_residual_streams, 1))
+        init_alpha0[init_residual_index, 0] = 1.
+
+        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+        self.dynamic_alpha_fn = nn.Conv2d(dim, num_residual_streams + 1, 1, bias = False)
+        nn.init.zeros_(self.dynamic_alpha_fn.weight)
+
+        self.dynamic_beta_fn = nn.Sequential(
+            nn.Conv2d(dim, 1, 1, bias = False),
+            Rearrange('b 1 ... -> b ...')
+        )
+
+        nn.init.zeros_(self.dynamic_beta_fn[0].weight)
+
+        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+
+        # dropouts
+
+        self.dropout = nn.Dropout(dropout)
+
+        # maybe residual transform
+
+        self.residual_transform = default(residual_transform, nn.Identity())
+
+    def width_connection(self, residuals):
+
+        maybe_transformed_residuals = self.residual_transform(residuals)
+
+        # width connection
+
+        normed = self.norm(residuals)
+
+        # alpha for weighted sum of residuals going into branch
+
+        wc_weight = self.act(self.dynamic_alpha_fn(normed))
+        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+
+        dynamic_alpha = rearrange(dynamic_alpha, '(b s) ... -> b s ...', s = self.num_residual_streams)
+        alpha = dynamic_alpha + rearrange(self.static_alpha, 's t -> s t 1 1')
+
+        # beta for weights from branch output back to residual streams
+
+        dc_weight = self.act(self.dynamic_beta_fn(normed))
+        dynamic_beta = dc_weight * self.dynamic_beta_scale
+        dynamic_beta = rearrange(dynamic_beta, '(b s) ... -> b s ...', s = self.num_residual_streams)
+        beta = dynamic_beta + rearrange(self.static_beta, 's -> s 1 1')
+
+        residuals = rearrange(residuals, '(b s) ... -> b s ...', s = self.num_residual_streams)
+        mix_h = einsum(alpha, residuals, 'b s t ..., b s d ... -> b t d ...')
+
+        branch_input, residuals = mix_h[:, 0, ...], mix_h[:, 1:, ...]
+
+        return branch_input, maybe_transformed_residuals, dict(beta = beta)
+
+    def depth_connection(self, branch_output, residuals, *, beta):
+        # 'depth' connection
+
+        output = einsum(branch_output, beta, 'b d ..., b s ... -> b s d ...')
+        output = rearrange(output, 'b s d ... -> (b s) d ...')
+
+        residuals = residuals + output
+
+        return self.dropout(residuals)
+
+    def decorate_branch(self, branch: Callable):
+        assert not exists(self.branch), 'branch was already wrapped on init'
+
+        def forward_and_add_residual(residual, *args, **kwargs):
+            branch_input, add_residual = self.forward(residual)
+
+            branch_output = branch(branch_input, *args, **kwargs)
+
+            residual = add_residual(branch_output)
+
+            return residual
+
+        return forward_and_add_residual
+
+    def forward(self, residuals, *branch_args, **branch_kwargs):
+
+        branch_input, residuals, residual_kwargs = self.width_connection(residuals)
+
+        def add_residual_fn(branch_out):
+            (branch_out, *rest), tree_spec = tree_flatten(branch_out)
+
+            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
+
+            return tree_unflatten((branch_out, *rest), tree_spec)
+
+        if not exists(self.branch):
+            return branch_input, add_residual_fn
+
+        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
+
+        return add_residual_fn(branch_output)
+
+HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
+HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
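The channel-first variant mirrors the existing module but operates on feature maps shaped (batch, dim, *spatial): RMSNorm normalizes over the channel dimension and the dynamic alpha/beta projections are 1x1 convolutions. A rough usage sketch follows; the convolutional branch and tensor shapes are assumptions for illustration, and einops repeat/reduce are used directly to expand and merge the residual streams rather than the package's helper functions.

import torch
from torch import nn
from einops import repeat, reduce

from hyper_connections.hyper_connections_channel_first import HyperConnections

num_streams = 4

# hypothetical convolutional branch, purely for illustration
branch = nn.Conv2d(64, 64, 3, padding = 1)

hc = HyperConnections(num_streams, dim = 64, branch = branch)

images = torch.randn(2, 64, 32, 32)                                   # (batch, channels, height, width)

streams = repeat(images, 'b ... -> (b s) ...', s = num_streams)       # expand into residual streams
streams = hc(streams)                                                 # width connection -> branch -> depth connection
out = reduce(streams, '(b s) ... -> b ...', 'sum', s = num_streams)   # merge the streams back down

Aside from the channel-first layout and the convolution-based projections, the width/depth connection logic matches the existing sequence-oriented module.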
hyper_connections-0.1.9.dist-info/RECORD

@@ -0,0 +1,9 @@
+hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
+hyper_connections/hyper_connections.py,sha256=F81iJkGMpxgCZPaBTLf0c3CYIE-ROAVgZJWY3NlrsJw,11068
+hyper_connections/hyper_connections_channel_first.py,sha256=BojfO2dcT4jX1rlcU3kr0B6B_CjrkkS2AZU4ZXeWvh8,6769
+hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
+hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
+hyper_connections-0.1.9.dist-info/METADATA,sha256=XxicphOwzNfTmBLF4Py89MhTWSpJqVk1EG-DV1gpFvo,5230
+hyper_connections-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hyper_connections-0.1.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
+hyper_connections-0.1.9.dist-info/RECORD,,
hyper_connections-0.1.7.dist-info/RECORD

@@ -1,8 +0,0 @@
-hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
-hyper_connections/hyper_connections.py,sha256=L2e4DduzPGdH30NhfHuiSiVZTwXRgeZW2MDAZ0Z-TKk,10541
-hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=2JABz1slrF5_XP33L1CMNTmmixfoD464FtQpUADqneU,7806
-hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=UIKceEZEBLKFL5VuP5tR1KTDeZNIJEKjFuPAwXkcp0I,11282
-hyper_connections-0.1.7.dist-info/METADATA,sha256=YThD719ySS2H6ABQnrNHKmoWI9vaGh5d1H9mbMKehV0,5230
-hyper_connections-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hyper_connections-0.1.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
-hyper_connections-0.1.7.dist-info/RECORD,,
{hyper_connections-0.1.7.dist-info → hyper_connections-0.1.9.dist-info}/WHEEL: file without changes

{hyper_connections-0.1.7.dist-info → hyper_connections-0.1.9.dist-info}/licenses/LICENSE: file without changes