hyper-connections 0.0.23__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -145,7 +145,7 @@ class Residual(Module):
145
145
 
146
146
  # hyper connection with multiple input streams
147
147
 
148
- InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int` + 1] and `str` points to **kwargs[`str`]
148
+ InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`] and `str` points to **kwargs[`str`] - `int` needs to be > 0, as 0 is the default input residual stream
149
149
 
150
150
  class HyperConnections(Module):
151
151
  @beartype
@@ -185,7 +185,7 @@ class HyperConnections(Module):
185
185
  init_alpha0 = torch.zeros((num_residual_streams, 1))
186
186
  init_alpha0[init_residual_index, 0] = 1.
187
187
 
188
- self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
188
+ self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
189
189
  self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
190
190
 
191
191
  self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
@@ -196,9 +196,11 @@ class HyperConnections(Module):
196
196
  additional_input_paths = default(additional_input_paths, [])
197
197
  additional_input_paths = [one_path if isinstance(one_path, tuple) else (one_path, dim) for one_path in additional_input_paths]
198
198
 
199
+ assert all([isinstance(path, str) or path > 0 for (path, _) in additional_input_paths])
200
+
199
201
  self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
200
202
  self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
201
- self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
203
+ self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])
202
204
 
203
205
  self.additional_input_paths = additional_input_paths
204
206
 
@@ -247,12 +249,14 @@ class HyperConnections(Module):
247
249
 
248
250
  # take care of additional inputs
249
251
 
252
+ branch_args = list(branch_args)
253
+
250
254
  for (path, *_), norm, proj, learned_static in zip(self.additional_input_paths, self.additional_norms, self.additional_to_dynamic_input, self.additional_static_input):
251
255
 
252
256
  # get the residual streams from additional arguments
253
257
 
254
258
  if isinstance(path, int):
255
- additional_residuals = branch_args[path]
259
+ additional_residuals = branch_args[path - 1]
256
260
  elif isinstance(path, str):
257
261
  additional_residuals = branch_kwargs[path]
258
262
 
@@ -280,7 +284,7 @@ class HyperConnections(Module):
280
284
  # set back transformed residual
281
285
 
282
286
  if isinstance(path, int):
283
- branch_args[path] = additional_residuals
287
+ branch_args[path - 1] = additional_residuals
284
288
  elif isinstance(path, str):
285
289
  branch_kwargs[path] = additional_residuals
286
290
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.23
3
+ Version: 0.1.0
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,8 @@
1
+ hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
+ hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
3
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
4
+ hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=rSKxtJReg0H-eSCFxdruf2HKU-0-iZ4_x5Fcla3sa3Y,11317
5
+ hyper_connections-0.1.0.dist-info/METADATA,sha256=LNCK8n3-8qGa3twXYGXBg3scYDzw8KGCQDXiRQrV14Q,5314
6
+ hyper_connections-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ hyper_connections-0.1.0.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
8
+ hyper_connections-0.1.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
2
- hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
3
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
4
- hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=wma0du8wO6Q5690aYk7TwIza8cT6Csf5c-8EtiIaIVI,11058
5
- hyper_connections-0.0.23.dist-info/METADATA,sha256=6Uhck5q8NtgJZYnYiTcPjPQyD1nd-m1crZOXyDEZ_ZU,5315
6
- hyper_connections-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- hyper_connections-0.0.23.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
8
- hyper_connections-0.0.23.dist-info/RECORD,,