hyper-connections 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ import torch
6
6
  from torch import nn
7
7
  from torch.nn import Module
8
8
  import torch.nn.functional as F
9
+ from torch.utils._pytree import tree_flatten, tree_unflatten
9
10
 
10
11
  from einops import rearrange, repeat, reduce, einsum
11
12
 
@@ -30,6 +31,7 @@ class HyperConnections(Module):
30
31
  branch: Module | None = None,
31
32
  layer_index = None,
32
33
  tanh = True,
34
+ channel_first = False
33
35
  ):
34
36
  """
35
37
  Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -56,6 +58,10 @@ class HyperConnections(Module):
56
58
  self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
57
59
  self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
58
60
 
61
+ # channel first option
62
+
63
+ self.channel_first = channel_first
64
+
59
65
  @classmethod
60
66
  def get_expand_reduce_stream_functions(cls, num_streams):
61
67
  expand_fn = partial(repeat, pattern = 'b ... -> (b s) ...', s = num_streams)
@@ -66,31 +72,47 @@ class HyperConnections(Module):
66
72
  def width_connection(self, residuals):
67
73
  # width connection
68
74
 
75
+ if self.channel_first:
76
+ residuals = rearrange(residuals, 'b d ... -> b ... d')
77
+
69
78
  residuals = rearrange(residuals, '(b s) ... d -> b ... s d', s = self.num_residual_streams)
70
79
 
71
80
  normed = self.norm(residuals)
72
81
 
82
+ # alpha for weighted sum of residuals going into branch
83
+
73
84
  wc_weight = self.act(normed @ self.dynamic_alpha_fn)
74
85
  dynamic_alpha = wc_weight * self.dynamic_alpha_scale
75
86
  alpha = dynamic_alpha + self.static_alpha
76
87
 
88
+ # beta for weights from branch output back to residual streams
89
+
77
90
  dc_weight = self.act(normed @ self.dynamic_beta_fn)
78
91
  dynamic_beta = dc_weight * self.dynamic_beta_scale
79
92
  beta = dynamic_beta + self.static_beta
80
93
 
81
- # width connection
82
-
83
94
  mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
84
95
 
85
96
  branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
86
97
 
98
+ if self.channel_first:
99
+ branch_input = rearrange(branch_input, 'b ... d -> b d ...')
100
+
87
101
  return branch_input, residuals, beta
88
102
 
89
103
  def depth_connection(self, branch_output, residuals, beta):
90
104
  # 'depth' connection
91
105
 
106
+ if self.channel_first:
107
+ branch_output = rearrange(branch_output, 'b d ... -> b ... d')
108
+
92
109
  residuals = einsum(branch_output, beta, 'b ... d, b ... s -> b ... s d') + residuals
93
- return rearrange(residuals, 'b ... s d -> (b s) ... d')
110
+ output = rearrange(residuals, 'b ... s d -> (b s) ... d')
111
+
112
+ if self.channel_first:
113
+ output = rearrange(output, 'b ... d -> b d ...')
114
+
115
+ return output
94
116
 
95
117
  def forward(self, residuals, **branch_kwargs):
96
118
 
@@ -104,4 +126,8 @@ class HyperConnections(Module):
104
126
 
105
127
  branch_output = self.branch(branch_input, **branch_kwargs)
106
128
 
107
- return add_residual_fn(branch_output)
129
+ (branch_output, *rest), tree_spec = tree_flatten(branch_output)
130
+
131
+ branch_output = add_residual_fn(branch_output)
132
+
133
+ return tree_unflatten((branch_output, *rest), tree_spec)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -45,7 +45,7 @@ Description-Content-Type: text/markdown
45
45
 
46
46
  ## Hyper Connections
47
47
 
48
- Attempt to make the multiple residual stream approach proposed by Hyper-Connections paper by Bytedance AI more accessible as a reusable library, and for following any new research in this direction.
48
+ Attempt to make multiple residual streams, proposed in the [Hyper-Connections paper](https://arxiv.org/abs/2409.19606) from ByteDance's AI lab, accessible as an easy-to-use library, as well as to follow any new research in this direction.
49
49
 
50
50
  ## Install
51
51
 
@@ -92,6 +92,49 @@ residual = hyper_conn_branch(residual)
92
92
  residual = reduce_stream(residual)
93
93
  ```
94
94
 
95
+ Or do it manually, as in the paper:
96
+
97
+ ```python
98
+ import torch
99
+ from torch import nn
100
+
101
+ # a single branch layer
102
+
103
+ branch = nn.Linear(512, 512)
104
+
105
+ # before
106
+
107
+ residual = torch.randn(2, 1024, 512)
108
+
109
+ residual = branch(residual) + residual
110
+
111
+ # after, say 4 streams in paper
112
+
113
+ from hyper_connections import HyperConnections
114
+
115
+ expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
116
+
117
+ # 1. instantiate hyper connection with correct number of streams (4 in this case)
118
+
119
+ hyper_conn = HyperConnections(4, dim = 512)
120
+
121
+ # 2. expand to 4 streams
122
+
123
+ residual = expand_stream(residual)
124
+
125
+ # 3. forward your residual into hyper connection for the branch input + add residual function (learned betas)
126
+
127
+ branch_input, add_residual = hyper_conn(residual)
128
+
129
+ branch_output = branch(branch_input)
130
+
131
+ residual = add_residual(branch_output)
132
+
133
+ # 4. reduce 4 streams with a summation, this has to be done after your for loop trunk
134
+
135
+ residual = reduce_stream(residual)
136
+ ```
137
+
95
138
  ## Citation
96
139
 
97
140
  ```bibtex
@@ -0,0 +1,6 @@
1
+ hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
+ hyper_connections/hyper_connections.py,sha256=91QtTtnpffmErIZvrnTtosSf4JgBqcyGvxftmka-EOw,4303
3
+ hyper_connections-0.0.3.dist-info/METADATA,sha256=8XKTmC6Ys10uOyotPtvL17v4uZyepkWoeMRVM4B_TSQ,4676
4
+ hyper_connections-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hyper_connections-0.0.3.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
+ hyper_connections-0.0.3.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
- hyper_connections/hyper_connections.py,sha256=Nbv7_OZ8FkdRG1WfmqdQHf46GZxWY0G_h_p4lT_JW38,3450
3
- hyper_connections-0.0.1.dist-info/METADATA,sha256=o_PVMP0Mm_sr7196WRYbh7O3wDp-nHVzAESgfWzi3FQ,3684
4
- hyper_connections-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hyper_connections-0.0.1.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
- hyper_connections-0.0.1.dist-info/RECORD,,