hyper-connections 0.0.22.tar.gz → 0.0.23.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/PKG-INFO +1 -1
- hyper_connections-0.0.23/hyper_connections/hyper_connections_with_multi_input_streams.py +338 -0
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/pyproject.toml +1 -1
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/.gitignore +0 -0
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/LICENSE +0 -0
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/README.md +0 -0
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/hyper-connections.png +0 -0
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.0.22 → hyper_connections-0.0.23}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
hyper_connections-0.0.23/hyper_connections/hyper_connections_with_multi_input_streams.py
ADDED
@@ -0,0 +1,338 @@
from __future__ import annotations
from typing import Callable

from functools import partial
from random import randrange

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_flatten, tree_unflatten

from einops import rearrange, repeat, reduce, einsum
from einops.layers.torch import Rearrange

from beartype import beartype

"""
ein notation:
b - batch
d - feature dimension
s - residual streams
t - residual streams + num branch inputs
"""

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def identity(t):
    return t

# main functions

def get_expand_reduce_stream_functions(num_streams, disable = False):

    if disable:
        return (identity, identity)

    expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
    reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)

    return expand_fn, reduce_fn

def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):

    hyper_conn_klass = HyperConnections if not disable else Residual

    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
    expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)

    return (init_hyper_conn_fn, *expand_reduce_fns)

# norms

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)

class ProjActScale(Module):
    def __init__(
        self,
        dim,
        dim_out,
        activation: Module = nn.Identity(),
        scale_init: float = 1e-2,
        squeeze_output = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.proj = nn.Linear(dim, dim_out, bias = False)
        nn.init.zeros_(self.proj.weight)

        self.act = activation
        self.scale = nn.Parameter(torch.ones(()) * scale_init)
        self.maybe_squeeze = Rearrange('... 1 -> ...') if squeeze_output else nn.Identity()

    def forward(self, x):
        out = self.proj(x)
        out = self.act(out)
        return self.maybe_squeeze(out * self.scale)

# main classes

# residual base class

class Residual(Module):
    @beartype
    def __init__(
        self,
        *args,
        branch: Module | None = None,
        **kwargs
    ):
        super().__init__()
        self.branch = branch

    def width_connection(self, residuals, *args, **kwargs):
        return residuals, residuals, dict()

    def depth_connection(self, branch_output, residuals):
        return branch_output + residuals

    def decorate_branch(self, branch: Callable):
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual, *args, **kwargs)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

    def forward(self, residuals, *branch_args, **branch_kwargs):

        branch_input, residuals, residual_kwargs = self.width_connection(residuals, *branch_args, **branch_kwargs)

        def add_residual_fn(branch_out):
            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        if not exists(self.branch):
            return branch_input, add_residual_fn

        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)

# hyper connection with multiple input streams

InputPathType = int | str # path to an additional residual stream: `int` indexes the positional args passed after the main residuals, `str` names a keyword argument

class HyperConnections(Module):
    @beartype
    def __init__(
        self,
        num_residual_streams,
        *,
        dim,
        additional_input_paths: (
            list[InputPathType | tuple[InputPathType, int]] # if an additional input has a different feature dimension, the second tuple element is that dimension
            | None
        ) = None,
        branch: Module | None = None,
        layer_index = None,
        tanh = True,
        channel_first = False,
        dropout = 0.
    ):
        """
        Appendix J, Algorithm 2 in https://arxiv.org/abs/2409.19606
        """
        super().__init__()

        self.branch = branch

        # activation - results were seemingly wishy washy depending on using tanh or not

        act = nn.Tanh() if tanh else nn.Identity()

        self.num_residual_streams = num_residual_streams
        assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'

        self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now

        init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given

        init_alpha0 = torch.zeros((num_residual_streams, 1))
        init_alpha0[init_residual_index, 0] = 1.

        self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

        self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))

        # additional input residual streams

        additional_input_paths = default(additional_input_paths, [])
        additional_input_paths = [one_path if isinstance(one_path, tuple) else (one_path, dim) for one_path in additional_input_paths]

        self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
        self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _, dim in additional_input_paths])
        self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0].clone()) for _ in additional_input_paths])

        self.additional_input_paths = additional_input_paths

        # dropouts

        self.dropout = nn.Dropout(dropout)

        # channel first option

        self.channel_first = channel_first

    def width_connection(
        self,
        residuals,
        *branch_args,
        **branch_kwargs
    ):
        branch_args = list(branch_args) # mutable copy, as additional residual streams may be written back below

        transpose = self.channel_first

        # width connection

        if transpose:
            residuals = rearrange(residuals, 'b d ... -> b ... d')

        residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)

        normed = self.norm(residuals)

        # alpha for weighted sum of residuals going into branch

        dynamic_alpha = self.dynamic_alpha_and_branch_input(normed)
        alpha = dynamic_alpha + self.static_alpha

        # beta for weights from branch output back to residual streams

        dynamic_beta = self.dynamic_beta(normed)
        beta = dynamic_beta + self.static_beta

        mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')

        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]

        if transpose:
            branch_input = rearrange(branch_input, 'b ... d -> b d ...')

        # take care of additional inputs

        for (path, *_), norm, proj, learned_static in zip(self.additional_input_paths, self.additional_norms, self.additional_to_dynamic_input, self.additional_static_input):

            # get the residual streams from additional arguments

            if isinstance(path, int):
                additional_residuals = branch_args[path]
            elif isinstance(path, str):
                additional_residuals = branch_kwargs[path]

            assert torch.is_tensor(additional_residuals)

            # handle channel first

            if transpose:
                additional_residuals = rearrange(additional_residuals, 'b d ... -> b ... d')

            additional_residuals = rearrange(additional_residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)

            # norm, project to per-stream mixing weights, add learned static weights

            additional_mix = proj(norm(additional_residuals))
            additional_mix = additional_mix + learned_static

            additional_residuals = einsum(additional_mix, additional_residuals, '... s, ... s d -> ... d')

            # transpose out

            if transpose:
                additional_residuals = rearrange(additional_residuals, 'b ... d -> b d ...')

            # set back transformed residual

            if isinstance(path, int):
                branch_args[path] = additional_residuals
            elif isinstance(path, str):
                branch_kwargs[path] = additional_residuals

        return ([branch_input, *branch_args], branch_kwargs), residuals, dict(beta = beta)

    def depth_connection(self, branch_output, residuals, *, beta):
        # 'depth' connection

        if self.channel_first:
            branch_output = rearrange(branch_output, 'b d ... -> b ... d')

        residuals = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d') + residuals
        output = rearrange(residuals, 'b ... s d -> (b s) ... d')

        if self.channel_first:
            output = rearrange(output, 'b ... d -> b d ...')

        return self.dropout(output)

    def decorate_branch(self, branch: Callable):
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            ([branch_input, *args], kwargs), add_residual = self.forward(residual, *args, **kwargs)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

    def forward(self, residuals, *branch_args, **branch_kwargs):

        (branch_args, branch_kwargs), residuals, residual_kwargs = self.width_connection(residuals, *branch_args, **branch_kwargs)

        def add_residual_fn(branch_out):
            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        if not exists(self.branch):
            return (branch_args, branch_kwargs), add_residual_fn

        branch_output = self.branch(*branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)

# add static methods

HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
{hyper_connections-0.0.22 → hyper_connections-0.0.23}/hyper_connections/hyper_connections.py
RENAMED
File without changes

The files listed above with +0 -0 are likewise carried over unchanged, renamed from hyper_connections-0.0.22 to hyper_connections-0.0.23.
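For orientation, below is a minimal usage sketch of the new module, based only on the classes and signatures visible in the diff above; it is not taken from the package README. The toy Branch module, the 'context' keyword used as the additional input path, and the dimensions are illustrative assumptions, and the import path simply follows the file's location in the sdist.

import torch
from torch import nn

from hyper_connections.hyper_connections_with_multi_input_streams import get_init_and_expand_reduce_stream_functions

class Branch(nn.Module):
    # toy branch that consumes the main input plus a second stream passed by keyword
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, *, context):
        return self.proj(x) + context

dim, num_streams = 64, 4

# the factory returns an init function plus expand / reduce helpers for the residual streams
init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(num_streams)

# wrap the branch, declaring 'context' as an additional input residual stream
hyper_conn = init_hyper_conn(
    dim = dim,
    branch = Branch(dim),
    additional_input_paths = ['context'],
    layer_index = 0
)

x = torch.randn(2, 8, dim)
context = torch.randn(2, 8, dim)

# both the main residuals and the additional input are expanded to num_streams copies
x, context = expand_streams(x), expand_streams(context)

out = hyper_conn(x, context = context)

# sum the streams back down after the final layer
out = reduce_streams(out)
print(out.shape) # torch.Size([2, 8, 64])

The diff also exposes a decorate_branch method, so an existing callable could alternatively be wrapped after construction rather than passed in as branch at init.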