hyper-connections 0.0.22__tar.gz → 0.0.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hyper-connections
- Version: 0.0.22
+ Version: 0.0.23
  Summary: Hyper-Connections
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,338 @@
+ from __future__ import annotations
+ from typing import Callable
+
+ from functools import partial
+ from random import randrange
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList
+ from torch.utils._pytree import tree_flatten, tree_unflatten
+
+ from einops import rearrange, repeat, reduce, einsum
+ from einops.layers.torch import Rearrange
+
+ from beartype import beartype
+
+ """
+ ein notation:
+ b - batch
+ d - feature dimension
+ s - residual streams
+ t - residual streams + num branch inputs
+ """
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def identity(t):
+     return t
+
+ # main functions
+
+ def get_expand_reduce_stream_functions(num_streams, disable = False):
+
+     if disable:
+         return (identity, identity)
+
+     expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
+     reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+     return expand_fn, reduce_fn
+
+ def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
+
+     hyper_conn_klass = HyperConnections if not disable else Residual
+
+     init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+     expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
+
+     return (init_hyper_conn_fn, *expand_reduce_fns)
+
+ # norms
+
+ class RMSNorm(Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.scale = dim ** 0.5
+         self.gamma = nn.Parameter(torch.zeros(dim))
+
+     def forward(self, x):
+         return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)
+
+ class ProjActScale(Module):
+     def __init__(
+         self,
+         dim,
+         dim_out,
+         activation: Module = nn.Identity(),
+         scale_init: float = 1e-2,
+         squeeze_output = False
+     ):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+
+         self.proj = nn.Linear(dim, dim_out, bias = False)
+         nn.init.zeros_(self.proj.weight)
+
+         self.act = activation
+         self.scale = nn.Parameter(torch.ones(()) * scale_init)
+         self.maybe_squeeze = Rearrange('... 1 -> ...') if squeeze_output else nn.Identity()
+
+     def forward(self, x):
+         out = self.proj(x)
+         out = self.act(out)
+         return self.maybe_squeeze(out * self.scale)
+
+ # main classes
+
+ # residual base class
+
+ class Residual(Module):
+     @beartype
+     def __init__(
+         self,
+         *args,
+         branch: Module | None = None,
+         **kwargs
+     ):
+         super().__init__()
+         self.branch = branch
+
+     def width_connection(self, residuals, *args, **kwargs):
+         return residuals, residuals, dict()
+
+     def depth_connection(self, branch_output, residuals):
+         return branch_output + residuals
+
+     def decorate_branch(self, branch: Callable):
+         assert not exists(self.branch), 'branch was already wrapped on init'
+
+         def forward_and_add_residual(residual, *args, **kwargs):
+             branch_input, add_residual = self.forward(residual, *args, **kwargs)
+
+             branch_output = branch(branch_input, *args, **kwargs)
+
+             residual = add_residual(branch_output)
+
+             return residual
+
+         return forward_and_add_residual
+
+     def forward(self, residuals, *branch_args, **branch_kwargs):
+
+         branch_input, residuals, residual_kwargs = self.width_connection(residuals, *branch_args, **branch_kwargs)
+
+         def add_residual_fn(branch_out):
+             (branch_out, *rest), tree_spec = tree_flatten(branch_out)
+
+             branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
+
+             return tree_unflatten((branch_out, *rest), tree_spec)
+
+         if not exists(self.branch):
+             return branch_input, add_residual_fn
+
+         branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
+
+         return add_residual_fn(branch_output)
+
+ # hyper connection with multiple input streams
+
+ InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int` + 1] and `str` points to **kwargs[`str`]
+
+ class HyperConnections(Module):
+     @beartype
+     def __init__(
+         self,
+         num_residual_streams,
+         *,
+         dim,
+         additional_input_paths: (
+             list[InputPathType |
+             tuple[InputPathType, int]] # if the second residual has different dimensions, second tuple element is the dimension
+             | None
+         ) = None,
+         branch: Module | None = None,
+         layer_index = None,
+         tanh = True,
+         channel_first = False,
+         dropout = 0.
+     ):
+ """
169
+ Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
170
+ """
171
+         super().__init__()
+
+         self.branch = branch
+         act = nn.Tanh() if tanh else nn.Identity()
+
+         self.num_residual_streams = num_residual_streams
+         assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
+
+         # activation, seemingly results were wishy washy depending on using tanh or not
+
+         self.norm = RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
+
+         init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
+
+         init_alpha0 = torch.zeros((num_residual_streams, 1))
+         init_alpha0[init_residual_index, 0] = 1.
+
+         self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
+         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+         self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
+         self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+         # additional input residual streams
+
+         additional_input_paths = default(additional_input_paths, [])
+         additional_input_paths = [one_path if isinstance(one_path, tuple) else (one_path, dim) for one_path in additional_input_paths]
+
+         self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
+         self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _, dim in additional_input_paths])
+         self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0].clone()) for _ in additional_input_paths]) # one static input weight per additional path, so the zip in width_connection covers every path
+
+         self.additional_input_paths = additional_input_paths
+
+         # dropouts
+
+         self.dropout = nn.Dropout(dropout)
+
+         # channel first option
+
+         self.channel_first = channel_first
+
+     def width_connection(
+         self,
+         residuals,
+         *branch_args,
+         **branch_kwargs
+     ):
+
+         transpose = self.channel_first
+
+         # width connection
+
+         if transpose:
+             residuals = rearrange(residuals, 'b d ... -> b ... d')
+
+         residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)
+
+         normed = self.norm(residuals)
+
+         # alpha for weighted sum of residuals going into branch
+
+         dynamic_alpha = self.dynamic_alpha_and_branch_input(normed)
+         alpha = dynamic_alpha + self.static_alpha
+
+         # beta for weights from branch output back to residual streams
+
+         dynamic_beta = self.dynamic_beta(normed)
+         beta = dynamic_beta + self.static_beta
+
+         mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
+
+         branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+         if transpose:
+             branch_input = rearrange(branch_input, 'b ... d -> b d ...')
+
+         # take care of additional inputs
+
+         branch_args = list(branch_args) # make mutable so transformed additional residuals can be written back by integer path below
+
+         for (path, *_), norm, proj, learned_static in zip(self.additional_input_paths, self.additional_norms, self.additional_to_dynamic_input, self.additional_static_input):
+
+             # get the residual streams from additional arguments
+
+             if isinstance(path, int):
+                 additional_residuals = branch_args[path]
+             elif isinstance(path, str):
+                 additional_residuals = branch_kwargs[path]
+
+             assert torch.is_tensor(additional_residuals)
+
+             # handle channel first
+
+             if transpose:
+                 additional_residuals = rearrange(additional_residuals, 'b d ... -> b ... d')
+
+             additional_residuals = rearrange(additional_residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)
+
+             # norm
+
+             additional_mix = proj(norm(additional_residuals))
+             additional_mix = additional_mix + learned_static
+
+             additional_residuals = einsum(additional_mix, additional_residuals, '... s, ... s d -> ... d')
+
+             # transpose out
+
+             if transpose:
+                 additional_residuals = rearrange(additional_residuals, 'b ... d -> b d ...')
+
+             # set back transformed residual
+
+             if isinstance(path, int):
+                 branch_args[path] = additional_residuals
+             elif isinstance(path, str):
+                 branch_kwargs[path] = additional_residuals
+
+         return ([branch_input, *branch_args], branch_kwargs), residuals, dict(beta = beta)
+
+     def depth_connection(self, branch_output, residuals, *, beta):
+         # 'depth' connection
+
+         if self.channel_first:
+             branch_output = rearrange(branch_output, 'b d ... -> b ... d')
+
+         residuals = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d') + residuals
+         output = rearrange(residuals, 'b ... s d -> (b s) ... d')
+
+         if self.channel_first:
+             output = rearrange(output, 'b ... d -> b d ...')
+
+         return self.dropout(output)
+
+     def decorate_branch(self, branch: Callable):
+         assert not exists(self.branch), 'branch was already wrapped on init'
+
+         def forward_and_add_residual(residual, *args, **kwargs):
+             ([branch_input, *args], kwargs), add_residual = self.forward(residual, *args, **kwargs)
+
+             branch_output = branch(branch_input, *args, **kwargs)
+
+             residual = add_residual(branch_output)
+
+             return residual
+
+         return forward_and_add_residual
+
+     def forward(self, residuals, *branch_args, **branch_kwargs):
+
+         (branch_args, branch_kwargs), residuals, residual_kwargs = self.width_connection(residuals, *branch_args, **branch_kwargs)
+
+         def add_residual_fn(branch_out):
+             (branch_out, *rest), tree_spec = tree_flatten(branch_out)
+
+             branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
+
+             return tree_unflatten((branch_out, *rest), tree_spec)
+
+         if not exists(self.branch):
+             return (branch_args, branch_kwargs), add_residual_fn
+
+         branch_output = self.branch(*branch_args, **branch_kwargs)
+
+         return add_residual_fn(branch_output)
+
+ # add static methods
+
+ HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
+ HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
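For orientation only (this usage note is not part of the released diff): the module added above could be wired into a model roughly as follows. This is a minimal sketch based on the code shown; the package-root import path, the feed-forward branch, and the tensor shapes are assumptions for illustration.

import torch
from torch import nn

# assumed import path - the helpers are expected to be exposed by the package
from hyper_connections import get_init_and_expand_reduce_stream_functions

num_streams, dim = 4, 512

# returns (init_fn, expand_fn, reduce_fn); init_fn is the connection class pre-bound to num_streams
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams)

# a stand-in branch module (assumption for illustration)
ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

# wrap the branch; width_connection mixes the streams into the branch input,
# depth_connection writes the branch output back to every stream
hyper_ff = init_hyper_conn(dim = dim, branch = ff, layer_index = 0)

x = torch.randn(2, 1024, dim)

x = expand_stream(x)   # (b, n, d) -> (b * s, n, d), one copy per residual stream
x = hyper_ff(x)        # residual block with learned width/depth connections
x = reduce_stream(x)   # (b * s, n, d) -> (b, n, d), summed over streams

The `additional_input_paths` argument would let a second tensor argument (for example a context passed positionally or by keyword) be treated as its own set of residual streams; that path is not exercised in this sketch.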
@@ -1,6 +1,6 @@
  [project]
  name = "hyper-connections"
- version = "0.0.22"
+ version = "0.0.23"
  description = "Hyper-Connections"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }