hyper-connections 0.0.23__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -196,6 +196,8 @@ class HyperConnections(Module):
196
196
  additional_input_paths = default(additional_input_paths, [])
197
197
  additional_input_paths = [one_path if isinstance(one_path, tuple) else (one_path, dim) for one_path in additional_input_paths]
198
198
 
199
+ assert all([isinstance(path, str) or path > 0 for (path, _) in additional_input_paths])
200
+
199
201
  self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
200
202
  self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
201
203
  self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
@@ -247,12 +249,14 @@ class HyperConnections(Module):
247
249
 
248
250
  # take care of additional inputs
249
251
 
252
+ branch_args = list(branch_args)
253
+
250
254
  for (path, *_), norm, proj, learned_static in zip(self.additional_input_paths, self.additional_norms, self.additional_to_dynamic_input, self.additional_static_input):
251
255
 
252
256
  # get the residual streams from additional arguments
253
257
 
254
258
  if isinstance(path, int):
255
- additional_residuals = branch_args[path]
259
+ additional_residuals = branch_args[path - 1]
256
260
  elif isinstance(path, str):
257
261
  additional_residuals = branch_kwargs[path]
258
262
 
@@ -280,7 +284,7 @@ class HyperConnections(Module):
280
284
  # set back transformed residual
281
285
 
282
286
  if isinstance(path, int):
283
- branch_args[path] = additional_residuals
287
+ branch_args[path - 1] = additional_residuals
284
288
  elif isinstance(path, str):
285
289
  branch_kwargs[path] = additional_residuals
286
290
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.23
3
+ Version: 0.0.24
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -1,8 +1,8 @@
1
1
  hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
2
  hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
3
3
  hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
4
- hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=wma0du8wO6Q5690aYk7TwIza8cT6Csf5c-8EtiIaIVI,11058
5
- hyper_connections-0.0.23.dist-info/METADATA,sha256=6Uhck5q8NtgJZYnYiTcPjPQyD1nd-m1crZOXyDEZ_ZU,5315
6
- hyper_connections-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- hyper_connections-0.0.23.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
8
- hyper_connections-0.0.23.dist-info/RECORD,,
4
+ hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=0a3fGZ8SjHL7uzXIVhNnjTzvN0WR41SG31iVcGdGVZ8,11204
5
+ hyper_connections-0.0.24.dist-info/METADATA,sha256=r5x-l4MtcKmP9tGX-0tbxSnstYm6ufinVaW0UpZP9cI,5315
6
+ hyper_connections-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ hyper_connections-0.0.24.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
8
+ hyper_connections-0.0.24.dist-info/RECORD,,