hyper-connections 0.0.24__tar.gz → 0.1.1__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
@@ -0,0 +1,19 @@
+ name: Tests the examples in README
+ on: push
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+       - name: Install Python
+         uses: actions/setup-python@v4
+       - name: Install the latest version of rye
+         uses: eifinger/setup-rye@v2
+       - name: Use UV instead of pip
+         run: rye config --set-bool behavior.use-uv=true
+       - name: Install dependencies
+         run: |
+           rye sync
+       - name: Run pytest
+         run: rye run pytest tests/
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hyper-connections
- Version: 0.0.24
+ Version: 0.1.1
  Summary: Hyper-Connections
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -34,12 +34,9 @@ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3.9
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
- Requires-Dist: beartype
  Requires-Dist: einops>=0.8.0
  Requires-Dist: torch>=2.3
  Provides-Extra: examples
- Provides-Extra: test
- Requires-Dist: pytest; extra == 'test'
  Description-Content-Type: text/markdown

  <img src="./hyper-connections.png" width="450px"></img>
@@ -12,8 +12,6 @@ from torch.utils._pytree import tree_flatten, tree_unflatten

  from einops import rearrange, repeat, reduce, einsum

- from beartype import beartype
-
  """
  ein notation:
  b - batch
@@ -65,7 +63,6 @@ def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = Fals
  # hyper connection residual streams

  class HyperConnections(Module):
-     @beartype
      def __init__(
          self,
          num_residual_streams,
@@ -13,8 +13,6 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
  from einops import rearrange, repeat, reduce, einsum
  from einops.layers.torch import Rearrange

- from beartype import beartype
-
  """
  ein notation:
  b - batch
@@ -95,7 +93,6 @@ class ProjActScale(Module):
  # residual base class

  class Residual(Module):
-     @beartype
      def __init__(
          self,
          *args,
@@ -145,10 +142,9 @@ class Residual(Module):

  # hyper connection with multiple input streams

- InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int` + 1] and `str` points to **kwargs[`str`]
+ InputPathType = int | str # the path to the second residual stream, where `int` points to *args[`int`] and `str` points to **kwargs[`str`] - `int` needs to be > 0, as 0 is the default input residual stream

  class HyperConnections(Module):
-     @beartype
      def __init__(
          self,
          num_residual_streams,
@@ -185,7 +181,7 @@ class HyperConnections(Module):
          init_alpha0 = torch.zeros((num_residual_streams, 1))
          init_alpha0[init_residual_index, 0] = 1.

-         self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1)
+         self.dynamic_alpha_and_branch_input = ProjActScale(dim, num_residual_streams + 1, activation = act)
          self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

          self.dynamic_beta = ProjActScale(dim, 1, activation = act, squeeze_output = True)
@@ -200,7 +196,7 @@ class HyperConnections(Module):

          self.additional_norms = ModuleList([RMSNorm(dim) for _, dim in additional_input_paths])
          self.additional_to_dynamic_input = ModuleList([ProjActScale(dim, 1, activation = act, squeeze_output = True) for _ , dim in additional_input_paths])
-         self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0])])
+         self.additional_static_input = nn.ParameterList([nn.Parameter(init_alpha0[..., 0]) for _ in additional_input_paths])

          self.additional_input_paths = additional_input_paths
@@ -1,6 +1,6 @@
  [project]
  name = "hyper-connections"
- version = "0.0.24"
+ version = "0.1.1"
  description = "Hyper-Connections"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -23,7 +23,6 @@ classifiers=[
  ]

  dependencies = [
-     "beartype",
      "einops>=0.8.0",
      "torch>=2.3",
  ]
@@ -34,14 +33,9 @@ Repository = "https://github.com/lucidrains/hyper-connections"

  [project.optional-dependencies]
  examples = []
- test = [
-     "pytest"
- ]

  [tool.pytest.ini_options]
- pythonpath = [
-     "."
- ]
+ pythonpath = ["."]

  [build-system]
  requires = ["hatchling"]
@@ -49,7 +43,7 @@ build-backend = "hatchling.build"

  [tool.rye]
  managed = true
- dev-dependencies = []
+ dev-dependencies = ["pytest>=8.2.0"]

  [tool.hatch.metadata]
  allow-direct-references = true
@@ -0,0 +1,138 @@
+ import pytest
+
+ import torch
+ from torch import nn
+
+ @pytest.mark.parametrize('disable', (False, True))
+ def test_readme(disable):
+
+     # a single branch layer
+
+     branch = nn.Linear(512, 512)
+
+     # before
+
+     residual = torch.randn(2, 1024, 512)
+
+     residual = branch(residual) + residual
+
+     # after, say 4 streams in paper
+
+     from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+     init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4, disable = disable)
+
+     # 1. wrap your branch function
+
+     hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
+
+     # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
+
+     residual = expand_stream(residual)
+
+     # 3. forward your residual as usual into the wrapped branch function(s)
+
+     residual = hyper_conn_branch(residual)
+
+     # 4. reduce 4 streams with a summation, this has to be done after your for-loop trunk. for transformer, unsure whether to do before or after final norm
+
+     residual = reduce_stream(residual)
+
+     assert residual.shape == (2, 1024, 512)
+
+ def test_manual():
+     # a single branch layer
+
+     branch = nn.Linear(512, 512)
+
+     # before
+
+     residual = torch.randn(2, 1024, 512)
+
+     residual = branch(residual) + residual
+
+     # after, say 4 streams in paper
+
+     from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+     init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)
+
+     # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
+
+     hyper_conn = init_hyper_conn(dim = 512)
+
+     # 2. expand to 4 streams
+
+     residual = expand_stream(residual)
+
+     # 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
+
+     branch_input, add_residual = hyper_conn(residual)
+
+     branch_output = branch(branch_input)
+
+     residual = add_residual(branch_output)
+
+     # or you can do it in one line as so -> residual = hyper_conn.decorate_branch(branch)(residual)
+
+     # 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
+
+     residual = reduce_stream(residual)
+     assert residual.shape == (2, 1024, 512)
+
+ @pytest.mark.parametrize('disable', (False, True))
+ def test_multi_input_hyper_connections(disable):
+
+     # two branch layers
+
+     class CustomModule(nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.linear = nn.Linear(512, 512)
+             self.second_linear = nn.Linear(256, 512)
+             self.third_linear = nn.Linear(128, 512)
+
+         def forward(self, x, second, *, third):
+             return self.linear(x) + self.second_linear(second) + self.third_linear(third), 3.
+
+     branch = CustomModule()
+
+     # before
+
+     residual = torch.randn(3, 1024, 512)
+     second_residual = torch.randn(3, 1024, 256)
+     third_residual = torch.randn(3, 1024, 128)
+
+     # residual = branch1(residual) + branch2(residual) + residual
+
+     # after, say 4 streams in paper
+
+     from hyper_connections.hyper_connections_with_multi_input_streams import HyperConnections
+
+     init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = disable)
+
+     # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
+
+     hyper_conn = init_hyper_conn(
+         dim = 512,
+         branch = branch,
+         additional_input_paths = [
+             (1, 256), # points at second residual stream, first arg
+             ('third', 128) # points at third residual stream, keyword argument 'third'
+         ],
+         layer_index = 1,
+     )
+
+     # 2. expand to 4 streams
+
+     residual = expand_stream(residual)
+     second_residual = expand_stream(second_residual)
+     third_residual = expand_stream(third_residual)
+
+     # 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
+
+     residual, rest_output = hyper_conn(residual, second_residual, third = third_residual)
+
+     residual = reduce_stream(residual)
+
+     assert residual.shape == (3, 1024, 512)
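
For context on the comment in `test_manual` above ("or you can do it in one line as so"), here is a minimal sketch of that `decorate_branch` form, assuming the same API exercised by the tests in this diff; it is not part of the diff itself.

    import torch
    from torch import nn
    from hyper_connections import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(4)

    branch = nn.Linear(512, 512)
    hyper_conn = init_hyper_conn(dim = 512)

    residual = expand_stream(torch.randn(2, 1024, 512))

    # decorate_branch wraps the branch so that selecting the branch input and
    # adding back the learned residual happen in a single call, equivalent to the
    # branch_input / add_residual steps in test_manual
    residual = hyper_conn.decorate_branch(branch)(residual)

    residual = reduce_stream(residual)
    assert residual.shape == (2, 1024, 512)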