hyper-connections 0.1.6.tar.gz → 0.1.8.tar.gz
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/PKG-INFO +1 -1
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/hyper_connections/hyper_connections.py +18 -6
- hyper_connections-0.1.8/hyper_connections/hyper_connections_channel_first.py +207 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/pyproject.toml +1 -1
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/tests/test_hyper_connections.py +86 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/.gitignore +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/LICENSE +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/README.md +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/hyper-connections.png +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.1.6 → hyper_connections-0.1.8}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
@@ -73,16 +73,18 @@ class Residual(Module):
         self,
         *args,
         branch: Module | None = None,
+        residual_transform: Module | None = None,
         **kwargs
     ):
         super().__init__()
         self.branch = branch
+        self.residual_transform = default(residual_transform, nn.Identity())

     def width_connection(self, residuals):
         return residuals, residuals, dict()

     def depth_connection(self, branch_output, residuals):
-        return branch_output + residuals
+        return branch_output + self.residual_transform(residuals)

     def decorate_branch(self, branch: Callable):
         assert not exists(self.branch), 'branch was already wrapped on init'
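In short: the plain `Residual` fallback now accepts an optional `residual_transform`, applied to the residual just before the final add, so the skip path can match a branch that changes dimensionality. Below is a minimal sketch of the new behavior, calling `depth_connection` directly; the 512 to 256 shapes and the `nn.Linear` transform are illustrative assumptions, not taken from the diff.

import torch
from torch import nn
from hyper_connections.hyper_connections import Residual

# hypothetical setup: the branch emits 256 features while the residual carries 512,
# so a plain `branch_output + residuals` would not broadcast
res = Residual(1, residual_transform = nn.Linear(512, 256))

residuals = torch.randn(2, 16, 512)
branch_output = torch.randn(2, 16, 256)

# depth_connection now projects the residual before adding it to the branch output
out = res.depth_connection(branch_output, residuals)
assert out.shape == (2, 16, 256)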
@@ -128,7 +130,8 @@ class HyperConnections(Module):
         layer_index = None,
         tanh = True,
         channel_first = False,
-        dropout = 0
+        dropout = 0.,
+        residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
     ):
         """
         Appendix J, Algorithm 2 in - https://arxiv.org/abs/2409.19606
@@ -168,7 +171,14 @@ class HyperConnections(Module):

         self.channel_first = channel_first

+        # maybe residual transform
+
+        self.residual_transform = default(residual_transform, nn.Identity())
+
     def width_connection(self, residuals):
+
+        maybe_transformed_residuals = self.residual_transform(residuals)
+
         # width connection

         if self.channel_first:
@@ -197,7 +207,7 @@ class HyperConnections(Module):
         if self.channel_first:
             branch_input = rearrange(branch_input, 'b ... d -> b d ...')

-        return branch_input, residuals, dict(beta = beta)
+        return branch_input, maybe_transformed_residuals, dict(beta = beta)

     def depth_connection(self, branch_output, residuals, *, beta):
         # 'depth' connection
@@ -205,13 +215,15 @@ class HyperConnections(Module):
         if self.channel_first:
             branch_output = rearrange(branch_output, 'b d ... -> b ... d')

-
-        output = rearrange(
+        output = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d')
+        output = rearrange(output, 'b ... s d -> (b s) ... d')

         if self.channel_first:
             output = rearrange(output, 'b ... d -> b d ...')

-
+        residuals = residuals + output
+
+        return self.dropout(residuals)

     def decorate_branch(self, branch: Callable):
         assert not exists(self.branch), 'branch was already wrapped on init'
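Taken on its own, the rewritten depth connection broadcasts the single branch output back onto each of the s residual streams, weighted per position by beta, then folds the stream axis into the batch. A standalone shape check of that einsum/rearrange pair; the batch, sequence and feature sizes are chosen for illustration only.

import torch
from einops import rearrange, einsum

b, n, d, s = 2, 16, 512, 4  # batch, sequence length, feature dim, residual streams

branch_output = torch.randn(b, n, d)  # one branch output per original batch element
beta = torch.randn(b, n, s)           # per-position weight for each residual stream

# weight the branch output into s copies, one per residual stream
output = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d')
assert output.shape == (b, n, s, d)

# fold the stream axis into the batch, matching the (b s) residual layout
output = rearrange(output, 'b ... s d -> (b s) ... d')
assert output.shape == (b * s, n, d)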
@@ -0,0 +1,207 @@
+from __future__ import annotations
+from typing import Callable
+
+from functools import partial
+from random import randrange
+
+import torch
+from torch import nn
+from torch.nn import Module
+import torch.nn.functional as F
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+from einops import rearrange, repeat, reduce, einsum
+from einops.layers.torch import Reduce, Rearrange
+
+from hyper_connections.hyper_connections import (
+    Residual,
+    RMSNorm
+)
+
+"""
+ein notation:
+b - batch
+d - feature dimension
+s - residual streams
+t - residual streams + num branch inputs
+"""
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def identity(t):
+    return t
+
+# main functions
+
+def get_expand_reduce_stream_functions(num_streams, disable = False):
+
+    if num_streams == 1 or disable:
+        return (nn.Identity(), nn.Identity())
+
+    expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
+    reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+    return expand_fn, reduce_fn
+
+def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
+
+    hyper_conn_klass = HyperConnections if not disable else Residual
+
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
+
+    return (init_hyper_conn_fn, *expand_reduce_fns)
+
+# norms
+
+class RMSNorm(Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.zeros(dim, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * (self.gamma + 1)
+
+# hyper connection residual streams
+
+class HyperConnections(Module):
+    def __init__(
+        self,
+        num_residual_streams,
+        *,
+        dim,
+        branch: Module | None = None,
+        layer_index = None,
+        tanh = True,
+        channel_first = True,
+        dropout = 0.,
+        residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
+    ):
+        """
+        Appendix J, Algorithm 2 in - https://arxiv.org/abs/2409.19606
+        """
+        super().__init__()
+
+        self.branch = branch
+
+        # activation, seemingly results were wishy washy depending on using tanh or not
+
+        self.act = nn.Tanh() if tanh else nn.Identity()
+
+        self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
+
+        assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
+
+        self.num_residual_streams = num_residual_streams
+        init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
+
+        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+        init_alpha0 = torch.zeros((num_residual_streams, 1))
+        init_alpha0[init_residual_index, 0] = 1.
+
+        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+        self.dynamic_alpha_fn = nn.Conv2d(dim, num_residual_streams + 1, 1, bias = False)
+        nn.init.zeros_(self.dynamic_alpha_fn.weight)
+
+        self.dynamic_beta_fn = nn.Sequential(
+            nn.Conv2d(dim, 1, 1, bias = False),
+            Rearrange('b 1 ... -> b ...')
+        )
+
+        nn.init.zeros_(self.dynamic_beta_fn[0].weight)
+
+        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+
+        # dropouts
+
+        self.dropout = nn.Dropout(dropout)
+
+        # maybe residual transform
+
+        self.residual_transform = default(residual_transform, nn.Identity())
+
+    def width_connection(self, residuals):
+
+        maybe_transformed_residuals = self.residual_transform(residuals)
+
+        # width connection
+
+        normed = self.norm(residuals)
+
+        # alpha for weighted sum of residuals going into branch
+
+        wc_weight = self.act(self.dynamic_alpha_fn(normed))
+        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+
+        dynamic_alpha = rearrange(dynamic_alpha, '(b s) ... -> b s ...', s = self.num_residual_streams)
+        alpha = dynamic_alpha + rearrange(self.static_alpha, 's t -> s t 1 1')
+
+        # beta for weights from branch output back to residual streams
+
+        dc_weight = self.act(self.dynamic_beta_fn(normed))
+        dynamic_beta = dc_weight * self.dynamic_beta_scale
+        dynamic_beta = rearrange(dynamic_beta, '(b s) ... -> b s ...', s = self.num_residual_streams)
+        beta = dynamic_beta + rearrange(self.static_beta, 's -> s 1 1')
+
+        residuals = rearrange(residuals, '(b s) ... -> b s ...', s = self.num_residual_streams)
+        mix_h = einsum(alpha, residuals, 'b s t ..., b s d ... -> b t d ...')
+
+        branch_input, residuals = mix_h[:, 0, ...], mix_h[:, 1:, ...]
+
+        return branch_input, maybe_transformed_residuals, dict(beta = beta)
+
+    def depth_connection(self, branch_output, residuals, *, beta):
+        # 'depth' connection
+
+        output = einsum(branch_output, beta, 'b d ..., b s ... -> b s d ...')
+        output = rearrange(output, 'b s d ... -> (b s) d ...')
+
+        residuals = residuals + output
+
+        return self.dropout(residuals)
+
+    def decorate_branch(self, branch: Callable):
+        assert not exists(self.branch), 'branch was already wrapped on init'
+
+        def forward_and_add_residual(residual, *args, **kwargs):
+            branch_input, add_residual = self.forward(residual)
+
+            branch_output = branch(branch_input, *args, **kwargs)
+
+            residual = add_residual(branch_output)

+            return residual
+
+        return forward_and_add_residual
+
+    def forward(self, residuals, *branch_args, **branch_kwargs):
+
+        branch_input, residuals, residual_kwargs = self.width_connection(residuals)
+
+        def add_residual_fn(branch_out):
+            (branch_out, *rest), tree_spec = tree_flatten(branch_out)
+
+            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
+
+            return tree_unflatten((branch_out, *rest), tree_spec)
+
+        if not exists(self.branch):
+            return branch_input, add_residual_fn
+
+        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
+
+        return add_residual_fn(branch_output)
+
+HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
+HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
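The new module mirrors the base class but keeps tensors in channel-first (batch, channel, spatial) layout throughout: the dynamic alpha/beta projections become 1x1 convolutions and the per-forward rearrange round trips are gone. Below is a minimal sketch of driving it directly without a wrapped branch; all shapes here are illustrative. The test added below exercises the wrapped-branch path.

import torch
from hyper_connections.hyper_connections_channel_first import HyperConnections

streams = 4
hc = HyperConnections(streams, dim = 512)

# the streams are folded into the batch: (batch * streams, channels, height, width)
residuals = torch.randn(2 * streams, 512, 16, 16)

# with no branch module given, forward returns the mixed branch input
# plus a closure that adds the branch output back onto all streams
branch_input, add_residual = hc(residuals)
assert branch_input.shape == (2, 512, 16, 16)

branch_output = branch_input  # stand-in for a real branch's output
residuals = add_residual(branch_output)
assert residuals.shape == (2 * streams, 512, 16, 16)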
@@ -136,3 +136,89 @@ def test_multi_input_hyper_connections(disable):
     residual = reduce_stream(residual)

     assert residual.shape == (3, 1024, 512)
+
+@pytest.mark.parametrize('disable', (False, True))
+def test_residual_transform(disable):
+
+    # a single branch layer
+
+    branch = nn.Sequential(
+        nn.Conv2d(512, 512, 3, padding = 1),
+        nn.SiLU(),
+        nn.Conv2d(512, 256, 3, padding = 1)
+    )
+
+    residual_fn = nn.Conv2d(512, 256, 1)
+
+    # before
+
+    residual = torch.randn(2, 512, 16, 16)
+
+    before_residual = branch(residual) + residual_fn(residual)
+
+    # after, say 4 streams in paper
+
+    from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
+
+    # 1. wrap your branch function
+
+    hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch, channel_first = True, residual_transform = residual_fn)
+
+    # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
+
+    residual = expand_stream(residual)
+
+    # 3. forward your residual as usual into the wrapped branch function(s)
+
+    residual = hyper_conn_branch(residual)
+
+    # 4. reduce 4 streams with a summation, this has to be done after your for-loop trunk. for transformer, unsure whether to do before or after final norm
+
+    after_residual = reduce_stream(residual)
+
+    assert before_residual.shape == after_residual.shape
+
+@pytest.mark.parametrize('disable', (False, True))
+def test_channel_first_hyper_connection(disable):
+
+    # a single branch layer
+
+    branch = nn.Sequential(
+        nn.Conv2d(512, 512, 3, padding = 1),
+        nn.SiLU(),
+        nn.Conv2d(512, 256, 3, padding = 1)
+    )
+
+    residual_fn = nn.Conv2d(512, 256, 1)
+
+    # before
+
+    residual = torch.randn(2, 512, 16, 16)
+
+    before_residual = branch(residual) + residual_fn(residual)
+
+    # after, say 4 streams in paper
+
+    from hyper_connections.hyper_connections_channel_first import get_init_and_expand_reduce_stream_functions
+
+    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
+
+    # 1. wrap your branch function
+
+    hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch, residual_transform = residual_fn)
+
+    # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
+
+    residual = expand_stream(residual)
+
+    # 3. forward your residual as usual into the wrapped branch function(s)
+
+    residual = hyper_conn_branch(residual)
+
+    # 4. reduce 4 streams with a summation, this has to be done after your for-loop trunk. for transformer, unsure whether to do before or after final norm
+
+    after_residual = reduce_stream(residual)
+
+    assert before_residual.shape == after_residual.shape