hyper-connections 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -188,16 +188,21 @@ class StreamEmbed(Module):
188
188
  self,
189
189
  num_streams,
190
190
  dim,
191
- channel_first = False
191
+ channel_first = False,
192
+ expand_to_streams = False
192
193
  ):
193
194
  super().__init__()
194
195
  self.channel_first = channel_first
195
196
  self.num_streams = num_streams
196
197
 
198
+ self.expand_to_streams = expand_to_streams
197
199
  self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))
198
200
 
199
201
  def forward(self, residuals):
200
202
 
203
+ if self.expand_to_streams:
204
+ residuals = repeat(residuals, 'b ... -> (b s) ...', s = self.num_streams)
205
+
201
206
  if self.channel_first:
202
207
  residuals = rearrange(residuals, '(b s) d ... -> b ... s d', s = self.num_streams)
203
208
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.7
3
+ Version: 0.0.8
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,6 @@
1
+ hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
+ hyper_connections/hyper_connections.py,sha256=fdcr0DODcIQ1eggy7pa6faX6MqNIZST_q2aDMevViig,6964
3
+ hyper_connections-0.0.8.dist-info/METADATA,sha256=2q_q0AHjDyFHzliHiRXXTBaI9iA1uqECGAH_HMZmGis,4978
4
+ hyper_connections-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hyper_connections-0.0.8.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
+ hyper_connections-0.0.8.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
- hyper_connections/hyper_connections.py,sha256=RBm0qEhQwCSlvtqNXo_YIkRAkMynCNBrN7xXt4rsRBc,6756
3
- hyper_connections-0.0.7.dist-info/METADATA,sha256=PTozroByBHtvwj8fFJhflo0H0GdwRUM__8aNP2LzuPY,4978
4
- hyper_connections-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hyper_connections-0.0.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
- hyper_connections-0.0.7.dist-info/RECORD,,