hyper-connections 0.0.22-py3-none-any.whl → 0.0.24-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,342 @@
+ from __future__ import annotations
+ from typing import Callable
+
+ from functools import partial
+ from random import randrange
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList
+ from torch.utils._pytree import tree_flatten, tree_unflatten
+
+ from einops import rearrange, repeat, reduce, einsum
+ from einops.layers.torch import Rearrange
+
+ from beartype import beartype
+
+ """
+ ein notation:
+ b - batch
+ d - feature dimension
+ s - residual streams
+ t - residual streams + num branch inputs
+ """
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def identity(t):
+     return t
+
+ # main functions
+
+ def get_expand_reduce_stream_functions(num_streams, disable = False):
+
+     if disable:
+         return (identity, identity)
+
+     expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
+     reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+     return expand_fn, reduce_fn
+
+ def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
+
+     hyper_conn_klass = HyperConnections if not disable else Residual
+
+     init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+     expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
+
+     return (init_hyper_conn_fn, *expand_reduce_fns)
+
+ # norms
+
+ class RMSNorm(Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.scale = dim ** 0.5
+         self.gamma = nn.Parameter(torch.zeros(dim))
+
+     def forward(self, x):
+         return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1)
+
+ class ProjActScale(Module):
+     def __init__(
+         self,
+         dim,
+         dim_out,
+         activation: Module = nn.Identity(),
+         scale_init: float = 1e-2,
+         squeeze_output = False
+     ):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+
+         self.proj = nn.Linear(dim, dim_out, bias = False)
+         nn.init.zeros_(self.proj.weight)
+
+         self.act = activation
+         self.scale = nn.Parameter(torch.ones(()) * scale_init)
+         self.maybe_squeeze = Rearrange('... 1 -> ...') if squeeze_output else nn.Identity()
+
+     def forward(self, x):
+         out = self.proj(x)
+         out = self.act(out)
+         return self.maybe_squeeze(out * self.scale)
+
+ # main classes
+
+ # residual base class
+
+ class Residual(Module):
+     @beartype
+     def __init__(
+         self,
+         *args,
+         branch: Module | None = None,
+         **kwargs
+     ):
+         super().__init__()
+         self.branch = branch
+
+     def width_connection(self, residuals, *args, **kwargs):
+         return residuals, residuals, dict()
+
+     def depth_connection(self, branch_output, residuals):
+         return branch_output + residuals
+
+     def decorate_branch(self, branch: Callable):
+         assert not exists(self.branch), 'branch was already wrapped on init'
+
+         def forward_and_add_residual(residual, *args, **kwargs):
+             branch_input, add_residual = self.forward(residual, *args, **kwargs)
+
+             branch_output = branch(branch_input, *args, **kwargs)
+
+             residual = add_residual(branch_output)
+
+             return residual
+
+         return forward_and_add_residual
+
+     def forward(self, residuals, *branch_args, **branch_kwargs):
+
+         branch_input, residuals, residual_kwargs = self.width_connection(residuals, *branch_args, **branch_kwargs)
+
+         def add_residual_fn(branch_out):
+             (branch_out, *rest), tree_spec = tree_flatten(branch_out)
+
+             branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
+
+             return tree_unflatten((branch_out, *rest), tree_spec)
+
+         if not exists(self.branch):
+             return branch_input, add_residual_fn
+
+         branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
+
+         return add_residual_fn(branch_output)
+
+ # hyper connection with multiple input streams
+
+ InputPathType = int | str # path to an additional residual stream - an `int` i selects the i-th positional branch argument (branch_args[i - 1], 1-indexed, main residuals excluded), a `str` selects branch_kwargs[str]
+
+ class HyperConnections(Module):
+     @beartype
+     def __init__(
+         self,
+         num_residual_streams,
+         *,
+         dim,
+         additional_input_paths: (
+             list[InputPathType |
+                  tuple[InputPathType, int]] # if an additional input has a different feature dimension, the second tuple element is that dimension
+             | None
+         ) = None,
+         branch: Module | None = None,
+         layer_index = None,
+         tanh = True,
+         channel_first = False,
+         dropout = 0.
+     ):
+         """
+         Appendix J, Algorithm 2 in https://arxiv.org/abs/2409.19606
+         """
+         super().__init__()
+
+         self.branch = branch
+         act = nn.Tanh() if tanh else nn.Identity()
+
+         self.num_residual_streams = num_residual_streams
+         assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
+
+         # activation - reported results were mixed depending on whether tanh was used
+
+         self.norm = RMSNorm(dim) # the paper used layernorm, but rmsnorm is fine given what we know now
+
+         init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # pick a random residual stream if the layer index is not given
+
+         init_alpha0 = torch.zeros((num_residual_streams, 1))
+         init_alpha0[init_residual_index, 0] = 1.
+
+         self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
+         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+         self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
+         self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+         # additional input residual streams
+
+         additional_input_paths = default(additional_input_paths, [])
+         additional_input_paths = [one_path if isinstance(one_path, tuple) else (one_path, dim) for one_path in additional_input_paths]
+
+         assert all([isinstance(path, str) or path > 0 for (path, _) in additional_input_paths])
+
+         self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
+         self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _, dim in additional_input_paths])
+         self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0].clone()) for _ in additional_input_paths]) # one static mix weight per additional input path
+
+         self.additional_input_paths = additional_input_paths
+
+         # dropouts
+
+         self.dropout = nn.Dropout(dropout)
+
+         # channel first option
+
+         self.channel_first = channel_first
+
+     def width_connection(
+         self,
+         residuals,
+         *branch_args,
+         **branch_kwargs
+     ):
+
+         transpose = self.channel_first
+
+         # width connection
+
+         if transpose:
+             residuals = rearrange(residuals, 'b d ... -> b ... d')
+
+         residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)
+
+         normed = self.norm(residuals)
+
+         # alpha for weighted sum of residuals going into branch
+
+         dynamic_alpha = self.dynamic_alpha_and_branch_input(normed)
+         alpha = dynamic_alpha + self.static_alpha
+
+         # beta for weights from branch output back to residual streams
+
+         dynamic_beta = self.dynamic_beta(normed)
+         beta = dynamic_beta + self.static_beta
+
+         mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
+
+         branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+         if transpose:
+             branch_input = rearrange(branch_input, 'b ... d -> b d ...')
+
+         # take care of additional inputs
+
+         branch_args = list(branch_args)
+
+         for (path, *_), norm, proj, learned_static in zip(self.additional_input_paths, self.additional_norms, self.additional_to_dynamic_input, self.additional_static_input):
+
+             # get the residual streams from additional arguments
+
+             if isinstance(path, int):
+                 additional_residuals = branch_args[path - 1]
+             elif isinstance(path, str):
+                 additional_residuals = branch_kwargs[path]
+
+             assert torch.is_tensor(additional_residuals)
+
+             # handle channel first
+
+             if transpose:
+                 additional_residuals = rearrange(additional_residuals, 'b d ... -> b ... d')
+
+             additional_residuals = rearrange(additional_residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)
+
+             # norm
+
+             additional_mix = proj(norm(additional_residuals))
+             additional_mix = additional_mix + learned_static
+
+             additional_residuals = einsum(additional_mix, additional_residuals, '... s, ... s d -> ... d')
+
+             # transpose out
+
+             if transpose:
+                 additional_residuals = rearrange(additional_residuals, 'b ... d -> b d ...')
+
+             # set back transformed residual
+
+             if isinstance(path, int):
+                 branch_args[path - 1] = additional_residuals
+             elif isinstance(path, str):
+                 branch_kwargs[path] = additional_residuals
+
+         return ([branch_input, *branch_args], branch_kwargs), residuals, dict(beta = beta)
+
+     def depth_connection(self, branch_output, residuals, *, beta):
+         # 'depth' connection
+
+         if self.channel_first:
+             branch_output = rearrange(branch_output, 'b d ... -> b ... d')
+
+         residuals = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d') + residuals
+         output = rearrange(residuals, 'b ... s d -> (b s) ... d')
+
+         if self.channel_first:
+             output = rearrange(output, 'b ... d -> b d ...')
+
+         return self.dropout(output)
+
+     def decorate_branch(self, branch: Callable):
+         assert not exists(self.branch), 'branch was already wrapped on init'
+
+         def forward_and_add_residual(residual, *args, **kwargs):
+             ([branch_input, *args], kwargs), add_residual = self.forward(residual, *args, **kwargs)
+
+             branch_output = branch(branch_input, *args, **kwargs)
+
+             residual = add_residual(branch_output)
+
+             return residual
+
+         return forward_and_add_residual
+
+     def forward(self, residuals, *branch_args, **branch_kwargs):
+
+         (branch_args, branch_kwargs), residuals, residual_kwargs = self.width_connection(residuals, *branch_args, **branch_kwargs)
+
+         def add_residual_fn(branch_out):
+             (branch_out, *rest), tree_spec = tree_flatten(branch_out)
+
+             branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
+
+             return tree_unflatten((branch_out, *rest), tree_spec)
+
+         if not exists(self.branch):
+             return (branch_args, branch_kwargs), add_residual_fn
+
+         branch_output = self.branch(*branch_args, **branch_kwargs)
+
+         return add_residual_fn(branch_output)
+
+ # add static methods
+
+ HyperConnections.get_expand_reduce_stream_functions = staticmethod(get_expand_reduce_stream_functions)
+ HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(get_init_and_expand_reduce_stream_functions)
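
A minimal usage sketch of the new module follows, assuming only the functions shown in the diff above and the module path listed in the RECORD section below; it is an editor's illustration, not part of the published wheel. `get_init_and_expand_reduce_stream_functions` returns a partially applied `HyperConnections` constructor together with the stream expand/reduce helpers.

# minimal usage sketch - not part of the published wheel;
# module path assumed from the RECORD listing below
import torch
from torch import nn

from hyper_connections.hyper_connections_with_multi_input_streams import (
    get_init_and_expand_reduce_stream_functions
)

num_streams = 4

# partial constructor for HyperConnections, plus stream expand / reduce helpers
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams)

# wrap any branch module - a plain linear layer stands in for attention / feedforward here
residual_fn = init_hyper_conn(dim = 512, branch = nn.Linear(512, 512), layer_index = 0)

x = torch.randn(2, 1024, 512)

x = expand_stream(x)   # (b, n, d) -> (b * s, n, d), input copied across s streams
x = residual_fn(x)     # width connection -> branch -> depth connection
x = reduce_stream(x)   # (b * s, n, d) -> (b, n, d), summed over streams

Passing `disable = True` to the helper swaps in the plain `Residual` class and identity expand/reduce functions, so the same training code can toggle hyper-connections off.
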
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hyper-connections
- Version: 0.0.22
+ Version: 0.0.24
  Summary: Hyper-Connections
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,8 @@
+ hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
+ hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
+ hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=0a3fGZ8SjHL7uzXIVhNnjTzvN0WR41SG31iVcGdGVZ8,11204
+ hyper_connections-0.0.24.dist-info/METADATA,sha256=r5x-l4MtcKmP9tGX-0tbxSnstYm6ufinVaW0UpZP9cI,5315
+ hyper_connections-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ hyper_connections-0.0.24.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
+ hyper_connections-0.0.24.dist-info/RECORD,,
@@ -1,7 +0,0 @@
- hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
- hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
- hyper_connections-0.0.22.dist-info/METADATA,sha256=uMrTDUeNCoLpQs89yjMvadzz8r4JLQpky0zQ_Di2H7I,5315
- hyper_connections-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- hyper_connections-0.0.22.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
- hyper_connections-0.0.22.dist-info/RECORD,,
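
The feature that distinguishes the new module from the existing `hyper_connections.py` is `additional_input_paths`: extra branch inputs (for example a cross-attention context) can ride their own residual streams and are mixed back down to a single tensor before the branch runs. A speculative sketch under the same assumptions as above; `CrossBranch` is a hypothetical stand-in for a real cross-attention block:

# speculative sketch - `CrossBranch` is hypothetical; only the
# HyperConnections API shown in the diff above is assumed
import torch
from torch import nn

from hyper_connections.hyper_connections_with_multi_input_streams import (
    HyperConnections,
    get_expand_reduce_stream_functions
)

num_streams = 4
dim = 512

expand_stream, reduce_stream = get_expand_reduce_stream_functions(num_streams)

class CrossBranch(nn.Module):
    # toy branch taking the main input plus a `context` keyword argument
    def __init__(self, dim):
        super().__init__()
        self.to_out = nn.Linear(dim * 2, dim)

    def forward(self, x, context = None):
        return self.to_out(torch.cat((x, context), dim = -1))

residual_fn = HyperConnections(
    num_streams,
    dim = dim,
    branch = CrossBranch(dim),
    additional_input_paths = ['context'],  # str path -> branch_kwargs['context']
    layer_index = 0
)

# both the main input and the additional input are expanded to s streams
x = expand_stream(torch.randn(2, 1024, dim))
context = expand_stream(torch.randn(2, 1024, dim))

out = reduce_stream(residual_fn(x, context = context))  # (2, 1024, dim)
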