hyper-connections 0.0.24__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -145,7 +145,7 @@ class Residual(Module):
145
145
 
146
146
  # hyper connection with multiple input streams
147
147
 
148
- InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int` + 1] and `str` points to **kwargs[`str`]
148
+ InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`] and `str` points to **kwargs[`str`] - `int` needs to be > 0, as 0 is the default input residual stream
149
149
 
150
150
  class HyperConnections(Module):
151
151
  @beartype
@@ -185,7 +185,7 @@ class HyperConnections(Module):
185
185
  init_alpha0 = torch.zeros((num_residual_streams, 1))
186
186
  init_alpha0[init_residual_index, 0] = 1.
187
187
 
188
- self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
188
+ self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
189
189
  self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
190
190
 
191
191
  self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
@@ -200,7 +200,7 @@ class HyperConnections(Module):
200
200
 
201
201
  self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
202
202
  self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
203
- self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
203
+ self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])
204
204
 
205
205
  self.additional_input_paths = additional_input_paths
206
206
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.24
3
+ Version: 0.1.0
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,8 @@
1
+ hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
+ hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
3
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
4
+ hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=rSKxtJReg0H-eSCFxdruf2HKU-0-iZ4_x5Fcla3sa3Y,11317
5
+ hyper_connections-0.1.0.dist-info/METADATA,sha256=LNCK8n3-8qGa3twXYGXBg3scYDzw8KGCQDXiRQrV14Q,5314
6
+ hyper_connections-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ hyper_connections-0.1.0.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
8
+ hyper_connections-0.1.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
- hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
3
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
4
- hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=0a3fGZ8SjHL7uzXIVhNnjTzvN0WR41SG31iVcGdGVZ8,11204
5
- hyper_connections-0.0.24.dist-info/METADATA,sha256=r5x-l4MtcKmP9tGX-0tbxSnstYm6ufinVaW0UpZP9cI,5315
6
- hyper_connections-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- hyper_connections-0.0.24.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
8
- hyper_connections-0.0.24.dist-info/RECORD,,