hyper-connections 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections_with_multi_branch_inputs.py +178 -0
- {hyper_connections-0.0.12.dist-info → hyper_connections-0.0.14.dist-info}/METADATA +1 -1
- hyper_connections-0.0.14.dist-info/RECORD +7 -0
- hyper_connections-0.0.12.dist-info/RECORD +0 -6
- {hyper_connections-0.0.12.dist-info → hyper_connections-0.0.14.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.12.dist-info → hyper_connections-0.0.14.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
from functools import partial
|
|
5
|
+
from random import randrange
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import nn
|
|
9
|
+
from torch.nn import Module
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
12
|
+
|
|
13
|
+
from einops import rearrange, repeat, reduce, einsum
|
|
14
|
+
|
|
15
|
+
from hyper_connections.hyper_connections import Residual, StreamEmbed
|
|
16
|
+
|
|
17
|
+
# helper functions
|
|
18
|
+
|
|
19
|
+
def exists(v):
    # a value "exists" unless it is the None sentinel
    return False if v is None else True
|
|
21
|
+
|
|
22
|
+
def default(v, d):
    # fall back to `d` only when no value was provided (None is the sentinel)
    if v is not None:
        return v
    return d
|
|
24
|
+
|
|
25
|
+
def identity(t):
    """No-op passthrough: returns its argument unchanged."""
    return t
|
|
27
|
+
|
|
28
|
+
# main classes
|
|
29
|
+
|
|
30
|
+
# hyper connection residual streams
|
|
31
|
+
|
|
32
|
+
class HyperConnections(Module):
    """Hyper-connection residual-stream wrapper, generalized to multiple branch inputs.

    Maintains ``num_residual_streams`` parallel residual streams (folded into the
    batch dimension as ``(b s)``). ``width_connection`` mixes the streams into
    ``num_branch_inputs`` branch inputs plus updated residuals via a learned
    (static + dynamic) ``alpha``; ``depth_connection`` projects the branch output
    back onto the streams via a learned ``beta``.
    """

    def __init__(
        self,
        num_residual_streams,
        *,
        dim,
        branch: Module | None = None,
        layer_index = None,
        tanh = True,
        channel_first = False,
        num_branch_inputs = 1 # residuals will be linearly combined to multiple inputs, fed through the branch, then linearly combined back out to residuals
    ):
        """
        Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
        """
        super().__init__()

        # optional wrapped branch module; if None, forward() returns the branch
        # input plus a callback instead of running the branch itself
        self.branch = branch

        # activation, seemingly results were wishy washy depending on using tanh or not

        self.act = nn.Tanh() if tanh else nn.Identity()

        self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now

        self.num_residual_streams = num_residual_streams
        self.num_branch_inputs = num_branch_inputs

        init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given

        # static beta: (streams, branch_inputs) ones — initially each branch output
        # is added to every stream with weight 1
        self.static_beta = nn.Parameter(torch.ones(num_residual_streams, num_branch_inputs))

        # static alpha column layout is [branch_inputs | streams]; the first
        # num_branch_inputs columns select the chosen stream as the initial branch
        # input, the trailing identity preserves the residual streams.
        # NOTE: width_connection's mix_h slicing relies on exactly this layout.
        init_alpha0 = torch.zeros((num_residual_streams, num_branch_inputs))
        init_alpha0[init_residual_index, :] = 1.

        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

        # data-dependent corrections to alpha/beta, computed from the normed
        # residuals; zero-initialized with small scales so training starts at the
        # static (plain residual) behavior
        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_branch_inputs))
        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim, num_branch_inputs))
        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)

        # channel first option

        self.channel_first = channel_first

    @classmethod
    def get_expand_reduce_stream_functions(cls, num_streams):
        """Return (expand_fn, reduce_fn) that replicate a batch into `num_streams`
        folded-in streams and sum them back out, respectively."""
        expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
        reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)

        return expand_fn, reduce_fn

    @classmethod
    def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
        """Convenience: (constructor partial, expand_fn, reduce_fn).

        With disable=True, swaps in the plain Residual class and identity
        expand/reduce so call sites need no branching.
        """

        hyper_conn_klass = cls if not disable else Residual

        init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
        expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams) if not disable else (identity, identity)

        return (init_hyper_conn_fn, *expand_reduce_fns)

    def width_connection(self, residuals):
        """Mix residual streams into branch inputs and carried-forward residuals.

        Expects residuals of shape (batch * streams, ..., dim) — or channel-first
        (batch * streams, dim, ...). Returns (branch_input, residuals, kwargs)
        where kwargs carries `beta` for the matching depth_connection call.
        """
        num_streams, num_branch_inputs = self.num_residual_streams, self.num_branch_inputs

        # width connection

        if self.channel_first:
            residuals = rearrange(residuals, 'b d ... -> b ... d')

        # unfold the streams out of the batch dimension
        residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = num_streams)

        normed = self.norm(residuals)

        # alpha for weighted sum of residuals going into branch

        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
        alpha = dynamic_alpha + self.static_alpha

        # beta for weights from branch output back to residual streams

        dc_weight = self.act(normed @ self.dynamic_beta_fn)
        dynamic_beta = dc_weight * self.dynamic_beta_scale

        beta = dynamic_beta + self.static_beta

        # (streams -> branch_inputs + streams) linear combination per position
        mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')

        # split per static_alpha's column layout: leading columns are the branch
        # inputs, trailing num_streams columns are the updated residual streams
        branch_input, residuals = mix_h[..., :-num_streams, :], mix_h[..., -num_streams:, :]

        # fold the branch inputs into the batch so the branch sees a plain batch;
        # depth_connection undoes this with the same (i b) layout
        branch_input = rearrange(branch_input, 'b ... i d -> (i b) ... d')

        if self.channel_first:
            branch_input = rearrange(branch_input, 'b ... d -> b d ...')

        return branch_input, residuals, dict(beta = beta)

    def depth_connection(self, branch_output, residuals, *, beta):
        """Project branch output back onto the residual streams (weighted by beta)
        and re-fold streams into the batch dimension."""
        # 'depth' connection

        if self.channel_first:
            branch_output = rearrange(branch_output, 'b d ... -> b ... d')

        # recover the branch-input axis folded into the batch by width_connection
        branch_output = rearrange(branch_output, '(i b) ... -> i b ...', i = self.num_branch_inputs)

        residuals = einsum(branch_output, beta, 'i b ... d, b ... s i -> b ... s d') + residuals

        output = rearrange(residuals, 'b ... s d -> (b s) ... d')

        if self.channel_first:
            output = rearrange(output, 'b ... d -> b d ...')

        return output

    def decorate_branch(self, branch: Callable):
        """Wrap a branch callable so it is invoked between width and depth
        connections. Only valid when no branch was passed at init."""
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

    def forward(self, residuals, *branch_args, **branch_kwargs):
        """Apply the hyper-connection around self.branch.

        If no branch was given, returns (branch_input, add_residual_fn) so the
        caller can run its own branch and close the residual afterwards.
        """

        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            # branch may return a tuple/structure; only its first leaf is treated
            # as the tensor to fold back into the residual streams
            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        if not exists(self.branch):
            return branch_input, add_residual_fn

        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=-cqKAGkvSf9NKP-S5cmZU9h59qyPbDq27CSDTmZ5fm8,8042
|
|
3
|
+
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=jn-wxAgECVUpiyf1elTtwcr-wvzCQQVS32rxted7XqU,6091
|
|
4
|
+
hyper_connections-0.0.14.dist-info/METADATA,sha256=mM6YzyPTNDYZqR-l8cwqxnZfYngaNf0LunycWmIoP4Y,5076
|
|
5
|
+
hyper_connections-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
hyper_connections-0.0.14.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
7
|
+
hyper_connections-0.0.14.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=-cqKAGkvSf9NKP-S5cmZU9h59qyPbDq27CSDTmZ5fm8,8042
|
|
3
|
-
hyper_connections-0.0.12.dist-info/METADATA,sha256=QjKTCkJf0w2fzs5WHyRmg_8kHLetZ2Nm_hiLzArTg1w,5076
|
|
4
|
-
hyper_connections-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hyper_connections-0.0.12.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
-
hyper_connections-0.0.12.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|