hyper-connections 0.0.24__tar.gz → 0.1.0__tar.gz

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,19 @@
+ name: Tests the examples in README
+ on: push
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+       - name: Install Python
+         uses: actions/setup-python@v4
+       - name: Install the latest version of rye
+         uses: eifinger/setup-rye@v2
+       - name: Use UV instead of pip
+         run: rye config --set-bool behavior.use-uv=true
+       - name: Install dependencies
+         run: |
+           rye sync
+       - name: Run pytest
+         run: rye run pytest tests/
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hyper-connections
- Version: 0.0.24
+ Version: 0.1.0
  Summary: Hyper-Connections
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -145,7 +145,7 @@ class Residual(Module):

  # hyper connection with multiple input streams

- InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int` + 1] and `str` points to **kwargs[`str`]
+ InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`] and `str` points to **kwargs[`str`] - `int` needs to be > 0, as 0 is the default input residual stream

  class HyperConnections(Module):
      @beartype
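For illustration, a minimal sketch of the path convention described in the updated `InputPathType` comment, modeled on the new tests added in this release (the `Branch` module and its layer sizes below are hypothetical): an entry `(1, 256)` in `additional_input_paths` feeds the extra stream as the branch's second positional argument (`*args[1]`, since index 0 is reserved for the main residual), while `('third', 128)` feeds it as the keyword argument `third`.

from torch import nn
from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections

class Branch(nn.Module):
    # hypothetical branch taking the main residual plus two extra inputs
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)
        self.second_proj = nn.Linear(256, 512)
        self.third_proj = nn.Linear(128, 512)

    def forward(self, x, second, *, third):
        # x      <- main residual stream (path 0, implicit)
        # second <- additional path (1, 256), i.e. *args[1]
        # third  <- additional path ('third', 128), i.e. **kwargs['third']
        return self.proj(x) + self.second_proj(second) + self.third_proj(third)

init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)

hyper_conn = init_hyper_conn(
    dim = 512,
    branch = Branch(),
    additional_input_paths = [(1, 256), ('third', 128)],
    layer_index = 1
)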
@@ -185,7 +185,7 @@ class HyperConnections(Module):
          init_alpha0 = torch.zeros((num_residual_streams, 1))
          init_alpha0[init_residual_index, 0] = 1.

-         self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
+         self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
          self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

          self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
@@ -200,7 +200,7 @@ class HyperConnections(Module):

          self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
          self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
-         self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
+         self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])

          self.additional_input_paths = additional_input_paths

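The `additional_static_input` change above fixes a latent bug: the ParameterList previously held a single parameter regardless of how many additional input paths were configured, whereas it now holds one parameter per entry of `additional_input_paths`, each initialized from `init_alpha0[..., 0]`. A minimal standalone sketch of the difference (values here are illustrative, not taken from the module):

import torch
from torch import nn

num_residual_streams = 4
init_residual_index = 1  # illustrative value; this index selects which stream starts at 1.

init_alpha0 = torch.zeros((num_residual_streams, 1))
init_alpha0[init_residual_index, 0] = 1.

additional_input_paths = [(1, 256), ('third', 128)]  # two extra streams, as in the new test

# 0.0.24: always a single static-input parameter, even with two paths
old = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
assert len(old) == 1

# 0.1.0: one static-input parameter per additional path
new = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])
assert len(new) == len(additional_input_paths)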
@@ -1,6 +1,6 @@
  [project]
  name = "hyper-connections"
- version = "0.0.24"
+ version = "0.1.0"
  description = "Hyper-Connections"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,138 @@
+ import pytest
+
+ import torch
+ from torch import nn
+
+ @pytest.mark.parametrize('disable', (False, True))
+ def test_readme(disable):
+
+     # a single branch layer
+
+     branch = nn.Linear(512, 512)
+
+     # before
+
+     residual = torch.randn(2, 1024, 512)
+
+     residual = branch(residual) + residual
+
+     # after, say 4 streams in paper
+
+     from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+     init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
+
+     # 1. wrap your branch function
+
+     hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
+
+     # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
+
+     residual = expand_stream(residual)
+
+     # 3. forward your residual as usual into the wrapped branch function(s)
+
+     residual = hyper_conn_branch(residual)
+
+     # 4. reduce 4 streams with a summation, this has to be done after your for-loop trunk. for transformer, unsure whether to do before or after final norm
+
+     residual = reduce_stream(residual)
+
+     assert residual.shape == (2, 1024, 512)
+
+ def test_manual():
+     # a single branch layer
+
+     branch = nn.Linear(512, 512)
+
+     # before
+
+     residual = torch.randn(2, 1024, 512)
+
+     residual = branch(residual) + residual
+
+     # after, say 4 streams in paper
+
+     from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+     init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
+
+     # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
+
+     hyper_conn = init_hyper_conn(dim = 512)
+
+     # 2. expand to 4 streams
+
+     residual = expand_stream(residual)
+
+     # 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
+
+     branch_input, add_residual = hyper_conn(residual)
+
+     branch_output = branch(branch_input)
+
+     residual = add_residual(branch_output)
+
+     # or you can do it in one line as so -> residual = hyper_conn.decorate_branch(branch)(residual)
+
+     # 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
+
+     residual = reduce_stream(residual)
+     assert residual.shape == (2, 1024, 512)
+
+ @pytest.mark.parametrize('disable', (False, True))
+ def test_multi_input_hyper_connections(disable):
+
+     # two branch layers
+
+     class CustomModule(nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.linear = nn.Linear(512, 512)
+             self.second_linear = nn.Linear(256, 512)
+             self.third_linear = nn.Linear(128, 512)
+
+         def forward(self, x, second, *, third):
+             return self.linear(x) + self.second_linear(second) + self.third_linear(third), 3.
+
+     branch = CustomModule()
+
+     # before
+
+     residual = torch.randn(3, 1024, 512)
+     second_residual = torch.randn(3, 1024, 256)
+     third_residual = torch.randn(3, 1024, 128)
+
+     # residual = branch1(residual) + branch2(residual) + residual
+
+     # after, say 4 streams in paper
+
+     from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections
+
+     init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = disable)
+
+     # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
+
+     hyper_conn = init_hyper_conn(
+         dim = 512,
+         branch = branch,
+         additional_input_paths = [
+             (1, 256),      # points at second residual stream, first arg
+             ('third', 128) # points at third residual stream, keyword argument 'third'
+         ],
+         layer_index = 1,
+     )
+
+     # 2. expand to 4 streams
+
+     residual = expand_stream(residual)
+     second_residual = expand_stream(second_residual)
+     third_residual = expand_stream(third_residual)
+
+     # 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
+
+     residual, rest_output = hyper_conn(residual, second_residual, third = third_residual)
+
+     residual = reduce_stream(residual)
+
+     assert residual.shape == (3, 1024, 512)