hyper-connections 0.0.24__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/hyper_connections_with_multi_input_streams.py +3 -3
- {hyper_connections-0.0.24.dist-info → hyper_connections-0.1.0.dist-info}/METADATA +1 -1
- hyper_connections-0.1.0.dist-info/RECORD +8 -0
- hyper_connections-0.0.24.dist-info/RECORD +0 -8
- {hyper_connections-0.0.24.dist-info → hyper_connections-0.1.0.dist-info}/WHEEL +0 -0
- {hyper_connections-0.0.24.dist-info → hyper_connections-0.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -145,7 +145,7 @@ class Residual(Module):
|
|
|
145
145
|
|
|
146
146
|
# hyper connection with multiple input streams
|
|
147
147
|
|
|
148
|
-
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`
|
|
148
|
+
InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`] and `str` points to **kwargs[`str`] - `int` needs to be > 0, as 0 is the default input residual stream
|
|
149
149
|
|
|
150
150
|
class HyperConnections(Module):
|
|
151
151
|
@beartype
|
|
@@ -185,7 +185,7 @@ class HyperConnections(Module):
|
|
|
185
185
|
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
|
186
186
|
init_alpha0[init_residual_index, 0] = 1.
|
|
187
187
|
|
|
188
|
-
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
|
|
188
|
+
self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
|
|
189
189
|
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
|
190
190
|
|
|
191
191
|
self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
|
|
@@ -200,7 +200,7 @@ class HyperConnections(Module):
|
|
|
200
200
|
|
|
201
201
|
self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
|
|
202
202
|
self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
|
|
203
|
-
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
|
|
203
|
+
self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])
|
|
204
204
|
|
|
205
205
|
self.additional_input_paths = additional_input_paths
|
|
206
206
|
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
+
hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
|
|
3
|
+
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
|
|
4
|
+
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=rSKxtJReg0H-eSCFxdruf2HKU-0-iZ4_x5Fcla3sa3Y,11317
|
|
5
|
+
hyper_connections-0.1.0.dist-info/METADATA,sha256=LNCK8n3-8qGa3twXYGXBg3scYDzw8KGCQDXiRQrV14Q,5314
|
|
6
|
+
hyper_connections-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
hyper_connections-0.1.0.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
8
|
+
hyper_connections-0.1.0.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
hyper_connections/__init__.py,sha256=d2zNTka0Gp9vINu4U-RhgTJFBhsVrs1fne_15Zl0oOs,224
|
|
2
|
-
hyper_connections/hyper_connections.py,sha256=HyMz-jmICBC6L8QT-LA3EdY8djqG5XkOV7mi-i420mI,9993
|
|
3
|
-
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=x4it5IGllpZGYank8PBHCRzFeozgZxUY7UYo6YkSkcg,7778
|
|
4
|
-
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=0a3fGZ8SjHL7uzXIVhNnjTzvN0WR41SG31iVcGdGVZ8,11204
|
|
5
|
-
hyper_connections-0.0.24.dist-info/METADATA,sha256=r5x-l4MtcKmP9tGX-0tbxSnstYm6ufinVaW0UpZP9cI,5315
|
|
6
|
-
hyper_connections-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
hyper_connections-0.0.24.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
8
|
-
hyper_connections-0.0.24.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|