hyper-connections 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections.py +24 -2
- {hyper_connections-0.0.7.dist-info → hyper_connections-0.0.9.dist-info}/METADATA +3 -1
- hyper_connections-0.0.9.dist-info/RECORD +6 -0
- hyper_connections-0.0.7.dist-info/RECORD +0 -6
- {hyper_connections-0.0.7.dist-info → hyper_connections-0.0.9.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.7.dist-info → hyper_connections-0.0.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -79,8 +79,11 @@ class HyperConnections(Module):
|
|
|
79
79
|
|
|
80
80
|
self.branch = branch
|
|
81
81
|
|
|
82
|
+
# activation, seemingly results were wishy washy depending on using tanh or not
|
|
83
|
+
|
|
82
84
|
self.act = nn.Tanh() if tanh else nn.Identity()
|
|
83
|
-
|
|
85
|
+
|
|
86
|
+
self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
84
87
|
|
|
85
88
|
self.num_residual_streams = num_residual_streams
|
|
86
89
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
@@ -163,6 +166,20 @@ class HyperConnections(Module):
|
|
|
163
166
|
|
|
164
167
|
return output
|
|
165
168
|
|
|
169
|
+
def decorate_branch(self, branch: Callable):
|
|
170
|
+
assert not exists(self.branch), 'branch was already wrapped on init'
|
|
171
|
+
|
|
172
|
+
def forward_and_add_residual(residual, *args, **kwargs):
|
|
173
|
+
branch_input, add_residual = self.forward(residual)
|
|
174
|
+
|
|
175
|
+
branch_output = branch(branch_input)
|
|
176
|
+
|
|
177
|
+
residual = add_residual(branch_output)
|
|
178
|
+
|
|
179
|
+
return residual
|
|
180
|
+
|
|
181
|
+
return forward_and_add_residual
|
|
182
|
+
|
|
166
183
|
def forward(self, residuals, *branch_args, **branch_kwargs):
|
|
167
184
|
|
|
168
185
|
branch_input, residuals, residual_kwargs = self.width_connection(residuals)
|
|
@@ -188,16 +205,21 @@ class StreamEmbed(Module):
|
|
|
188
205
|
self,
|
|
189
206
|
num_streams,
|
|
190
207
|
dim,
|
|
191
|
-
channel_first = False
|
|
208
|
+
channel_first = False,
|
|
209
|
+
expand_to_streams = False
|
|
192
210
|
):
|
|
193
211
|
super().__init__()
|
|
194
212
|
self.channel_first = channel_first
|
|
195
213
|
self.num_streams = num_streams
|
|
196
214
|
|
|
215
|
+
self.expand_to_streams = expand_to_streams
|
|
197
216
|
self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))
|
|
198
217
|
|
|
199
218
|
def forward(self, residuals):
|
|
200
219
|
|
|
220
|
+
if self.expand_to_streams:
|
|
221
|
+
residuals = repeat(residuals, 'b ... -> (b s) ...', s = self.num_streams)
|
|
222
|
+
|
|
201
223
|
if self.channel_first:
|
|
202
224
|
residuals = rearrange(residuals, '(b s) d ... -> b ... s d', s = self.num_streams)
|
|
203
225
|
else:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.7
|
|
3
|
+
Version: 0.0.9
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -130,6 +130,8 @@ branch_output = branch(branch_input)
|
|
|
130
130
|
|
|
131
131
|
residual = add_residual(branch_output)
|
|
132
132
|
|
|
133
|
+
# or you can do it in one line as so -> residual = hyper_conn.decorate_branch(branch)(residual)
|
|
134
|
+
|
|
133
135
|
# 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
|
|
134
136
|
|
|
135
137
|
residual = reduce_stream(residual)
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=YfEDcPcT93-S599gFUpaATtuUZ908vJ_pmjPeF4Po28,7558
|
|
3
|
+
hyper_connections-0.0.9.dist-info/METADATA,sha256=xur7rWt-ZdJU1XxXlpaO0D9aWDR1BgaVNhjfnQaedZQ,5075
|
|
4
|
+
hyper_connections-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hyper_connections-0.0.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
+
hyper_connections-0.0.9.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=RBm0qEhQwCSlvtqNXo_YIkRAkMynCNBrN7xXt4rsRBc,6756
|
|
3
|
-
hyper_connections-0.0.7.dist-info/METADATA,sha256=PTozroByBHtvwj8fFJhflo0H0GdwRUM__8aNP2LzuPY,4978
|
|
4
|
-
hyper_connections-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hyper_connections-0.0.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
-
hyper_connections-0.0.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|