hyper-connections 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -79,8 +79,11 @@ class HyperConnections(Module):
79
79
 
80
80
  self.branch = branch
81
81
 
82
+ # activation, seemingly results were wishy washy depending on using tanh or not
83
+
82
84
  self.act = nn.Tanh() if tanh else nn.Identity()
83
- self.norm = nn.RMSNorm(dim)
85
+
86
+ self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
84
87
 
85
88
  self.num_residual_streams = num_residual_streams
86
89
  init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
@@ -163,6 +166,20 @@ class HyperConnections(Module):
163
166
 
164
167
  return output
165
168
 
169
+ def decorate_branch(self, branch: Callable):
170
+ assert not exists(self.branch), 'branch was already wrapped on init'
171
+
172
+ def forward_and_add_residual(residual, *args, **kwargs):
173
+ branch_input, add_residual = self.forward(residual)
174
+
175
+ branch_output = branch(branch_input)
176
+
177
+ residual = add_residual(branch_output)
178
+
179
+ return residual
180
+
181
+ return forward_and_add_residual
182
+
166
183
  def forward(self, residuals, *branch_args, **branch_kwargs):
167
184
 
168
185
  branch_input, residuals, residual_kwargs = self.width_connection(residuals)
@@ -188,16 +205,21 @@ class StreamEmbed(Module):
188
205
  self,
189
206
  num_streams,
190
207
  dim,
191
- channel_first = False
208
+ channel_first = False,
209
+ expand_to_streams = False
192
210
  ):
193
211
  super().__init__()
194
212
  self.channel_first = channel_first
195
213
  self.num_streams = num_streams
196
214
 
215
+ self.expand_to_streams = expand_to_streams
197
216
  self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))
198
217
 
199
218
  def forward(self, residuals):
200
219
 
220
+ if self.expand_to_streams:
221
+ residuals = repeat(residuals, 'b ... -> (b s) ...', s = self.num_streams)
222
+
201
223
  if self.channel_first:
202
224
  residuals = rearrange(residuals, '(b s) d ... -> b ... s d', s = self.num_streams)
203
225
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.7
3
+ Version: 0.0.9
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -130,6 +130,8 @@ branch_output = branch(branch_input)
130
130
 
131
131
  residual = add_residual(branch_output)
132
132
 
133
+ # or you can do it in one line as so -> residual = hyper_conn.decorate_branch(branch)(residual)
134
+
133
135
  # 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
134
136
 
135
137
  residual = reduce_stream(residual)
@@ -0,0 +1,6 @@
1
+ hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
+ hyper_connections/hyper_connections.py,sha256=YfEDcPcT93-S599gFUpaATtuUZ908vJ_pmjPeF4Po28,7558
3
+ hyper_connections-0.0.9.dist-info/METADATA,sha256=xur7rWt-ZdJU1XxXlpaO0D9aWDR1BgaVNhjfnQaedZQ,5075
4
+ hyper_connections-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hyper_connections-0.0.9.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
+ hyper_connections-0.0.9.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
- hyper_connections/hyper_connections.py,sha256=RBm0qEhQwCSlvtqNXo_YIkRAkMynCNBrN7xXt4rsRBc,6756
3
- hyper_connections-0.0.7.dist-info/METADATA,sha256=PTozroByBHtvwj8fFJhflo0H0GdwRUM__8aNP2LzuPY,4978
4
- hyper_connections-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hyper_connections-0.0.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
- hyper_connections-0.0.7.dist-info/RECORD,,