hyper-connections 0.0.6__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.6 → hyper_connections-0.0.7}/PKG-INFO +2 -2
- {hyper_connections-0.0.6 → hyper_connections-0.0.7}/README.md +1 -1
- {hyper_connections-0.0.6 → hyper_connections-0.0.7}/hyper_connections/__init__.py +2 -1
- {hyper_connections-0.0.6 → hyper_connections-0.0.7}/hyper_connections/hyper_connections.py +31 -0
- {hyper_connections-0.0.6 → hyper_connections-0.0.7}/pyproject.toml +1 -1
- {hyper_connections-0.0.6 → hyper_connections-0.0.7}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.6 → hyper_connections-0.0.7}/.gitignore +0 -0
- {hyper_connections-0.0.6 → hyper_connections-0.0.7}/LICENSE +0 -0
- {hyper_connections-0.0.6 → hyper_connections-0.0.7}/hyper-connections.png +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hyper-connections
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: Hyper-Connections
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hyper-connections/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hyper-connections
|
|
@@ -138,7 +138,7 @@ residual = reduce_stream(residual)
|
|
|
138
138
|
To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
|
|
139
139
|
|
|
140
140
|
```python
|
|
141
|
-
HyperConnections.get_init_and_expand_reduce_stream_functions(4,
|
|
141
|
+
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
142
142
|
```
|
|
143
143
|
|
|
144
144
|
## Citation
|
|
@@ -95,7 +95,7 @@ residual = reduce_stream(residual)
|
|
|
95
95
|
To compare hyper connections to plain residual without changing the code, just pass `disable = True` when fetching the functions
|
|
96
96
|
|
|
97
97
|
```python
|
|
98
|
-
HyperConnections.get_init_and_expand_reduce_stream_functions(4,
|
|
98
|
+
HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
|
|
99
99
|
```
|
|
100
100
|
|
|
101
101
|
## Citation
|
|
@@ -180,3 +180,34 @@ class HyperConnections(Module):
|
|
|
180
180
|
branch_output = add_residual_fn(branch_output)
|
|
181
181
|
|
|
182
182
|
return tree_unflatten((branch_output, *rest), tree_spec)
|
|
183
|
+
|
|
184
|
+
# stream embed
|
|
185
|
+
|
|
186
|
+
class StreamEmbed(Module):
    """Add a learned embedding vector to each residual stream.

    The input keeps the streams folded into its leading dimension, laid out as
    `(batch * num_streams, ...)` with the stream index as the inner factor. That
    layout is what the `(b s)` grouping in the einops patterns below encodes.
    The streams are unfolded, the per-stream embedding is added, and the tensor
    is folded back to its original layout.

    Args:
        num_streams: number of residual streams folded into the batch dimension.
        dim: feature size of each stream embedding.
        channel_first: if True, the feature dimension directly follows the
            batch dimension (`(b s) d ...`). Otherwise it is the last
            dimension (`(b s) ... d`).
    """

    def __init__(
        self,
        num_streams,
        dim,
        channel_first = False
    ):
        super().__init__()
        self.channel_first = channel_first
        self.num_streams = num_streams

        # one learnable vector per stream, starting at zero so the module is
        # initially an identity on its input
        self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))

    def forward(self, residuals):
        """Return `residuals` with each stream's embedding added, same layout as the input."""

        # pick the unfold / refold patterns for the configured feature-axis position
        if self.channel_first:
            unfold, refold = '(b s) d ... -> b ... s d', 'b ... s d -> (b s) d ...'
        else:
            unfold, refold = '(b s) ... d -> b ... s d', 'b ... s d -> (b s) ... d'

        streams = self.num_streams

        # move streams and features to the last two axes so the (num_streams, dim)
        # parameter broadcasts over every remaining axis
        per_stream = rearrange(residuals, unfold, s = streams)
        per_stream = per_stream + self.stream_embed

        return rearrange(per_stream, refold, s = streams)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|