hyper-connections 0.0.23__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections_with_multi_input_streams.py +9 -5
- {hyper_connections-0.0.23.dist-info → hyper_connections-0.1.0.dist-info}/METADATA +1 -1
- hyper_connections-0.1.0.dist-info/RECORD +8 -0
- hyper_connections-0.0.23.dist-info/RECORD +0 -8
- {hyper_connections-0.0.23.dist-info → hyper_connections-0.1.0.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.23.dist-info → hyper_connections-0.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -145,7 +145,7 @@ class Residual(Module):
|
|
|
145
145
|
|
|
146
146
|
# hyper connection with multiple input streams
|
|
147
147
|
|
|
148
|
-
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`
|
|
148
|
+
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`] and `str` points to **kwargs[`str`] - `int` needs to be > 0, as 0 is the default input residual stream
|
|
149
149
|
|
|
150
150
|
class HyperConnections(Module):
|
|
151
151
|
@beartype
|
|
@@ -185,7 +185,7 @@ class HyperConnections(Module):
|
|
|
185
185
|
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
|
186
186
|
init_alpha0[init_residual_index, 0] = 1.
|
|
187
187
|
|
|
188
|
-
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
|
|
188
|
+
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
|
|
189
189
|
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
|
190
190
|
|
|
191
191
|
self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
|
|
@@ -196,9 +196,11 @@ class HyperConnections(Module):
|
|
|
196
196
|
additional_input_paths = default(additional_input_paths, [])
|
|
197
197
|
additional_input_paths = [one_path if isinstance(one_path, tuple) else (one_path, dim) for one_path in additional_input_paths]
|
|
198
198
|
|
|
199
|
+
assert all([isinstance(path, str) or path > 0 for (path, _) in additional_input_paths])
|
|
200
|
+
|
|
199
201
|
self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
|
|
200
202
|
self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
|
|
201
|
-
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
|
|
203
|
+
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])
|
|
202
204
|
|
|
203
205
|
self.additional_input_paths = additional_input_paths
|
|
204
206
|
|
|
@@ -247,12 +249,14 @@ class HyperConnections(Module):
|
|
|
247
249
|
|
|
248
250
|
# take care of additional inputs
|
|
249
251
|
|
|
252
|
+
branch_args = list(branch_args)
|
|
253
|
+
|
|
250
254
|
for (path, *_), norm, proj, learned_static in zip(self.additional_input_paths, self.additional_norms, self.additional_to_dynamic_input, self.additional_static_input):
|
|
251
255
|
|
|
252
256
|
# get the residual streams from additional arguments
|
|
253
257
|
|
|
254
258
|
if isinstance(path, int):
|
|
255
|
-
additional_residuals = branch_args[path]
|
|
259
|
+
additional_residuals = branch_args[path - 1]
|
|
256
260
|
elif isinstance(path, str):
|
|
257
261
|
additional_residuals = branch_kwargs[path]
|
|
258
262
|
|
|
@@ -280,7 +284,7 @@ class HyperConnections(Module):
|
|
|
280
284
|
# set back transformed residual
|
|
281
285
|
|
|
282
286
|
if isinstance(path, int):
|
|
283
|
-
branch_args[path] = additional_residuals
|
|
287
|
+
branch_args[path - 1] = additional_residuals
|
|
284
288
|
elif isinstance(path, str):
|
|
285
289
|
branch_kwargs[path] = additional_residuals
|
|
286
290
|
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
|
|
3
|
+
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
|
|
4
|
+
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=rSKxtJReg0H-eSCFxdruf2HKU-0-iZ4_x5Fcla3sa3Y,11317
|
|
5
|
+
hyper_connections-0.1.0.dist-info/METADATA,sha256=LNCK8n3-8qGa3twXYGXBg3scYDzw8KGCQDXiRQrV14Q,5314
|
|
6
|
+
hyper_connections-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
hyper_connections-0.1.0.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
8
|
+
hyper_connections-0.1.0.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
|
|
3
|
-
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
|
|
4
|
-
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=wma0du8wO6Q5690aYk7TwIza8cT6Csf5c-8EtiIaIVI,11058
|
|
5
|
-
hyper_connections-0.0.23.dist-info/METADATA,sha256=6Uhck5q8NtgJZYnYiTcPjPQyD1nd-m1crZOXyDEZ_ZU,5315
|
|
6
|
-
hyper_connections-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
hyper_connections-0.0.23.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
8
|
-
hyper_connections-0.0.23.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|