hyper-connections 0.0.12.tar.gz → 0.0.14.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hyper-connections
- Version: 0.0.12
+ Version: 0.0.14
  Summary: Hyper-Connections
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,178 @@
+ from __future__ import annotations
+ from typing import Callable
+
+ from functools import partial
+ from random import randrange
+
+ import torch
+ from torch import nn
+ from torch.nn import Module
+ import torch.nn.functional as F
+ from torch.utils._pytree import tree_flatten, tree_unflatten
+
+ from einops import rearrange, repeat, reduce, einsum
+
+ from hyper_connections.hyper_connections import Residual, StreamEmbed
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def identity(t):
+     return t
+
+ # main classes
+
+ # hyper connection residual streams
+
+ class HyperConnections(Module):
+     def __init__(
+         self,
+         num_residual_streams,
+         *,
+         dim,
+         branch: Module | None = None,
+         layer_index = None,
+         tanh = True,
+         channel_first = False,
+         num_branch_inputs = 1 # residuals will be linearly combined to multiple inputs, fed through the branch, then linearly combined back out to residuals
+     ):
+         """
+         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
+         """
+         super().__init__()
+
+         self.branch = branch
+
+         # activation, seemingly results were wishy washy depending on using tanh or not
+
+         self.act = nn.Tanh() if tanh else nn.Identity()
+
+         self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
+
+         self.num_residual_streams = num_residual_streams
+         self.num_branch_inputs = num_branch_inputs
+
+         init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
+
+         self.static_beta = nn.Parameter(torch.ones(num_residual_streams, num_branch_inputs))
+
+         init_alpha0 = torch.zeros((num_residual_streams, num_branch_inputs))
+         init_alpha0[init_residual_index, :] = 1.
+
+         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+         self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_branch_inputs))
+         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim, num_branch_inputs))
+         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+         # channel first option
+
+         self.channel_first = channel_first
+
+     @classmethod
+     def get_expand_reduce_stream_functions(cls, num_streams):
+         expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
+         reduce_fn = partial(reduce, pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
+
+         return expand_fn, reduce_fn
+
+     @classmethod
+     def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
+
+         hyper_conn_klass = cls if not disable else Residual
+
+         init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
+         expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams) if not disable else (identity, identity)
+
+         return (init_hyper_conn_fn, *expand_reduce_fns)
+
+     def width_connection(self, residuals):
+         num_streams, num_branch_inputs = self.num_residual_streams, self.num_branch_inputs
+
+         # width connection
+
+         if self.channel_first:
+             residuals = rearrange(residuals, 'b d ... -> b ... d')
+
+         residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = num_streams)
+
+         normed = self.norm(residuals)
+
+         # alpha for weighted sum of residuals going into branch
+
+         wc_weight = self.act(normed @ self.dynamic_alpha_fn)
+         dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+         alpha = dynamic_alpha + self.static_alpha
+
+         # beta for weights from branch output back to residual streams
+
+         dc_weight = self.act(normed @ self.dynamic_beta_fn)
+         dynamic_beta = dc_weight * self.dynamic_beta_scale
+
+         beta = dynamic_beta + self.static_beta
+
+         mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
+
+         branch_input, residuals = mix_h[..., :-num_streams, :], mix_h[..., -num_streams:, :]
+
+         branch_input = rearrange(branch_input, 'b ... i d -> (i b) ... d')
+
+         if self.channel_first:
+             branch_input = rearrange(branch_input, 'b ... d -> b d ...')
+
+         return branch_input, residuals, dict(beta = beta)
+
+     def depth_connection(self, branch_output, residuals, *, beta):
+         # 'depth' connection
+
+         if self.channel_first:
+             branch_output = rearrange(branch_output, 'b d ... -> b ... d')
+
+         branch_output = rearrange(branch_output, '(i b) ... -> i b ...', i = self.num_branch_inputs)
+
+         residuals = einsum(branch_output, beta, 'i b ... d, b ... s i -> b ... s d') + residuals
+
+         output = rearrange(residuals, 'b ... s d -> (b s) ... d')
+
+         if self.channel_first:
+             output = rearrange(output, 'b ... d -> b d ...')
+
+         return output
+
+     def decorate_branch(self, branch: Callable):
+         assert not exists(self.branch), 'branch was already wrapped on init'
+
+         def forward_and_add_residual(residual, *args, **kwargs):
+             branch_input, add_residual = self.forward(residual)
+
+             branch_output = branch(branch_input, *args, **kwargs)
+
+             residual = add_residual(branch_output)
+
+             return residual
+
+         return forward_and_add_residual
+
+     def forward(self, residuals, *branch_args, **branch_kwargs):
+
+         branch_input, residuals, residual_kwargs = self.width_connection(residuals)
+
+         def add_residual_fn(branch_out):
+             (branch_out, *rest), tree_spec = tree_flatten(branch_out)
+
+             branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)
+
+             return tree_unflatten((branch_out, *rest), tree_spec)
+
+         if not exists(self.branch):
+             return branch_input, add_residual_fn
+
+         branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
+
+         return add_residual_fn(branch_output)
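
For orientation, a minimal usage sketch of the added class. The import path is an assumption (the hunk above does not show the new file's name) and the shapes are illustrative; the classmethod and the expand/reduce helpers come directly from the added code:

    import torch
    from torch import nn

    # assumed module path; the diff does not name the new file
    from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections

    # 4 residual streams: expand copies the batch across streams, reduce sums them back
    init_hc, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)

    branch = nn.Linear(512, 512)   # stand-in branch module
    hc = init_hc(dim = 512, branch = branch, layer_index = 0)

    x = torch.randn(2, 1024, 512)  # (batch, seq, dim)
    x = expand_stream(x)           # (batch * 4, seq, dim)
    x = hc(x)                      # width connection -> branch -> depth connection
    x = reduce_stream(x)           # back to (batch, seq, dim)
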
@@ -1,6 +1,6 @@
  [project]
  name = "hyper-connections"
- version = "0.0.12"
+ version = "0.0.14"
  description = "Hyper-Connections"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
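
Continuing the sketch above: per the added `forward`, when no branch is wrapped at init the call instead returns the mixed branch input together with a closure that applies the depth connection, so the branch can be invoked externally (module path and shapes assumed as before):

    hc = init_hc(dim = 512)                # no branch wrapped at init

    x = expand_stream(torch.randn(2, 1024, 512))

    branch_input, add_residual = hc(x)     # width connection: mix streams into the branch input
    branch_output = branch(branch_input)   # run any branch module yourself
    x = add_residual(branch_output)        # depth connection: mix the output back into the streams

    # or equivalently, wrap an arbitrary callable after init
    wrapped_branch = hc.decorate_branch(branch)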