hyper-connections 0.0.23__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections_with_multi_input_streams.py +6 -2
- {hyper_connections-0.0.23.dist-info → hyper_connections-0.0.24.dist-info}/METADATA +1 -1
- {hyper_connections-0.0.23.dist-info → hyper_connections-0.0.24.dist-info}/RECORD +5 -5
- {hyper_connections-0.0.23.dist-info → hyper_connections-0.0.24.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.23.dist-info → hyper_connections-0.0.24.dist-info}/licenses/LICENSE +0 -0
|
@@ -196,6 +196,8 @@ class HyperConnections(Module):
|
|
|
196
196
|
additional_input_paths = default(additional_input_paths, [])
|
|
197
197
|
additional_input_paths = [one_path if isinstance(one_path, tuple) else (one_path, dim) for one_path in additional_input_paths]
|
|
198
198
|
|
|
199
|
+
assert all([isinstance(path, str) or path > 0 for (path, _) in additional_input_paths])
|
|
200
|
+
|
|
199
201
|
self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
|
|
200
202
|
self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
|
|
201
203
|
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
|
|
@@ -247,12 +249,14 @@ class HyperConnections(Module):
|
|
|
247
249
|
|
|
248
250
|
# take care of additional inputs
|
|
249
251
|
|
|
252
|
+
branch_args = list(branch_args)
|
|
253
|
+
|
|
250
254
|
for (path, *_), norm, proj, learned_static in zip(self.additional_input_paths, self.additional_norms, self.additional_to_dynamic_input, self.additional_static_input):
|
|
251
255
|
|
|
252
256
|
# get the residual streams from additional arguments
|
|
253
257
|
|
|
254
258
|
if isinstance(path, int):
|
|
255
|
-
additional_residuals = branch_args[path]
|
|
259
|
+
additional_residuals = branch_args[path - 1]
|
|
256
260
|
elif isinstance(path, str):
|
|
257
261
|
additional_residuals = branch_kwargs[path]
|
|
258
262
|
|
|
@@ -280,7 +284,7 @@ class HyperConnections(Module):
|
|
|
280
284
|
# set back transformed residual
|
|
281
285
|
|
|
282
286
|
if isinstance(path, int):
|
|
283
|
-
branch_args[path] = additional_residuals
|
|
287
|
+
branch_args[path - 1] = additional_residuals
|
|
284
288
|
elif isinstance(path, str):
|
|
285
289
|
branch_kwargs[path] = additional_residuals
|
|
286
290
|
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
2
|
hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
|
|
3
3
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
|
|
4
|
-
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=
|
|
5
|
-
hyper_connections-0.0.
|
|
6
|
-
hyper_connections-0.0.
|
|
7
|
-
hyper_connections-0.0.
|
|
8
|
-
hyper_connections-0.0.
|
|
4
|
+
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=0a3fGZ8SjHL7uzXIVhNnjTzvN0WR41SG31iVcGdGVZ8,11204
|
|
5
|
+
hyper_connections-0.0.24.dist-info/METADATA,sha256=r5x-l4MtcKmP9tGX-0tbxSnstYm6ufinVaW0UpZP9cI,5315
|
|
6
|
+
hyper_connections-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
hyper_connections-0.0.24.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
8
|
+
hyper_connections-0.0.24.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|