hyper-connections 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/__init__.py +2 -1
- hyper_connections/hyper_connections.py +36 -0
- {hyper_connections-0.0.6.dist-info → hyper_connections-0.0.8.dist-info}/METADATA +2 -2
- hyper_connections-0.0.8.dist-info/RECORD +6 -0
- hyper_connections-0.0.6.dist-info/RECORD +0 -6
- {hyper_connections-0.0.6.dist-info → hyper_connections-0.0.8.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.6.dist-info → hyper_connections-0.0.8.dist-info}/licenses/LICENSE +0 -0
hyper_connections/hyper_connections.py
CHANGED
|
@@ -180,3 +180,39 @@ class HyperConnections(Module):
|
|
|
180
180
|
branch_output = add_residual_fn(branch_output)
|
|
181
181
|
|
|
182
182
|
return tree_unflatten((branch_output, *rest), tree_spec)
|
|
183
|
+
|
|
184
|
+
# stream embed
|
|
185
|
+
|
|
186
|
+
class StreamEmbed(Module):
    """Add a learned per-stream embedding to residuals.

    The input is expected to carry the streams folded into its leading axis
    as `(batch * num_streams)`. With `channel_first = True` the feature axis
    is read from position 1, otherwise from the last position. The output
    has the same layout as the (optionally expanded) input.
    """

    def __init__(
        self,
        num_streams,
        dim,
        channel_first = False,
        expand_to_streams = False
    ):
        super().__init__()
        self.num_streams = num_streams
        self.channel_first = channel_first
        self.expand_to_streams = expand_to_streams

        # one learned vector of size `dim` per stream, starting at zeros
        self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))

    def forward(self, residuals):
        streams = self.num_streams

        # optionally replicate every leading-axis entry once per stream
        if self.expand_to_streams:
            residuals = repeat(residuals, 'b ... -> (b s) ...', s = streams)

        # layout with streams folded into the leading axis, and the layout
        # that exposes (stream, feature) as the two trailing axes
        folded = '(b s) d ...' if self.channel_first else '(b s) ... d'
        unfolded = 'b ... s d'

        # unfold so the (num_streams, dim) parameter broadcasts over the rest
        per_stream = rearrange(residuals, f'{folded} -> {unfolded}', s = streams)
        embedded = per_stream + self.stream_embed

        # fold the streams back into the leading axis
        return rearrange(embedded, f'{unfolded} -> {folded}', s = streams)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.0.8
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -138,7 +138,7 @@ residual = reduce_stream(residual)
|
|
|
138
138
|
To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
|
|
139
139
|
|
|
140
140
|
```python
|
|
141
|
-
HyperConnections.get_init_and_expand_reduce_stream_functions(4,
|
|
141
|
+
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
142
142
|
```
|
|
143
143
|
|
|
144
144
|
## Citation
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=fdcr0DODcIQ1eggy7pa6faX6MqNIZST_q2aDMevViig,6964
|
|
3
|
+
hyper_connections-0.0.8.dist-info/METADATA,sha256=2q_q0AHjDyFHzliHiRXXTBaI9iA1uqECGAH_HMZmGis,4978
|
|
4
|
+
hyper_connections-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hyper_connections-0.0.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
+
hyper_connections-0.0.8.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=yEc-yNlGq084y0pR0_VVGLr-sH4ye-eVX0RNz7sTPCo,87
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=YibHh3ocMhkCWhEu8EF554HtGm7i4SPH5ChSMEyFPlI,5843
|
|
3
|
-
hyper_connections-0.0.6.dist-info/METADATA,sha256=c_2oOz7OtvUeLt71AHeO5AXj8d_oIOj0Qo1fcBvCN1A,4979
|
|
4
|
-
hyper_connections-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hyper_connections-0.0.6.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
6
|
-
hyper_connections-0.0.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|