hyper-connections 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
1
1
  from hyper_connections.hyper_connections import (
2
- HyperConnections
2
+ HyperConnections,
3
+ Residual,
4
+ StreamEmbed
3
5
  )
@@ -18,7 +18,46 @@ def exists(v):
18
18
  def default(v, d):
19
19
  return v if exists(v) else d
20
20
 
21
- # main class
21
def identity(t):
    """Return ``t`` unchanged.

    Serves as the no-op expand / reduce stream function when hyper
    connections are disabled.
    """
    return t
23
+
24
+ # main classes
25
+
26
+ # plain residual, swapped in for hyper connections when disabled
27
+
28
class Residual(Module):
    """Plain residual connection with the same call interface as ``HyperConnections``.

    ``HyperConnections.get_init_and_expand_reduce_stream_functions(..., disable = True)``
    hands this class out in place of ``HyperConnections``. Any extra positional or
    keyword constructor arguments (e.g. the ``num_streams`` pre-bound with
    ``partial``) are accepted and deliberately ignored, so construction sites need
    no changes.

    ``forward`` behaves in one of two ways:
    - If a ``branch`` module was given, it runs the branch and adds the residual
      to the first leaf of the branch output.
    - Otherwise it returns ``(branch_input, add_residual_fn)``, and the caller
      applies the branch and the residual addition itself.
    """

    def __init__(
        self,
        *args,
        branch = None,
        **kwargs
    ):
        super().__init__()
        # optional wrapped branch; *args / **kwargs are unused on purpose
        self.branch = branch

    def width_connection(self, residuals):
        # nothing to mix: the branch sees the residual as-is,
        # and depth_connection needs no extra keyword arguments
        branch_input = residuals
        return branch_input, residuals, {}

    def depth_connection(self, branch_output, residuals):
        # ordinary residual addition
        return branch_output + residuals

    def forward(self, residuals, *branch_args, **branch_kwargs):
        branch_input, residuals, depth_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            return self.depth_connection(branch_out, residuals, **depth_kwargs)

        # no wrapped branch: give the caller the input and a closure that adds the residual later
        if not exists(self.branch):
            return branch_input, add_residual_fn

        output = self.branch(branch_input, *branch_args, **branch_kwargs)

        # the branch may return a pytree (e.g. a tuple); only its first leaf receives the residual
        (main, *others), spec = tree_flatten(output)
        return tree_unflatten((add_residual_fn(main), *others), spec)
22
61
 
23
62
  # hyper connection residual streams
24
63
 
@@ -70,10 +109,12 @@ class HyperConnections(Module):
70
109
  return expand_fn, reduce_fn
71
110
 
72
111
  @classmethod
73
- def get_init_and_expand_reduce_stream_functions(cls, num_streams):
112
+ def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
113
+
114
+ hyper_conn_klass = cls if not disable else Residual
74
115
 
75
- init_hyper_conn_fn = partial(cls, num_streams)
76
- expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams)
116
+ init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
117
+ expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams) if not disable else (identity, identity)
77
118
 
78
119
  return (init_hyper_conn_fn, *expand_reduce_fns)
79
120
 
@@ -106,9 +147,9 @@ class HyperConnections(Module):
106
147
  if self.channel_first:
107
148
  branch_input = rearrange(branch_input, 'b ... d -> b d ...')
108
149
 
109
- return branch_input, residuals, beta
150
+ return branch_input, residuals, dict(beta = beta)
110
151
 
111
- def depth_connection(self, branch_output, residuals, beta):
152
+ def depth_connection(self, branch_output, residuals, *, beta):
112
153
  # 'depth' connection
113
154
 
114
155
  if self.channel_first:
@@ -124,10 +165,10 @@ class HyperConnections(Module):
124
165
 
125
166
  def forward(self, residuals, *branch_args, **branch_kwargs):
126
167
 
127
- branch_input, residuals, beta = self.width_connection(residuals)
168
+ branch_input, residuals, residual_kwargs = self.width_connection(residuals)
128
169
 
129
170
  def add_residual_fn(branch_out):
130
- return self.depth_connection(branch_out, residuals, beta)
171
+ return self.depth_connection(branch_out, residuals, **residual_kwargs)
131
172
 
132
173
  if not exists(self.branch):
133
174
  return branch_input, add_residual_fn
@@ -139,3 +180,34 @@ class HyperConnections(Module):
139
180
  branch_output = add_residual_fn(branch_output)
140
181
 
141
182
  return tree_unflatten((branch_output, *rest), tree_spec)
183
+
184
+ # stream embed
185
+
186
class StreamEmbed(Module):
    """Add a learned per-stream embedding to expanded residual streams.

    The input has its streams folded into the leading axis as ``(batch * num_streams)``,
    with batch outer and stream inner, matching the ``(b s)`` pattern below.
    The feature axis of size ``dim`` is axis 1 when ``channel_first`` is set and
    the last axis otherwise. Each stream gets its own ``dim``-sized vector added
    at every position, and the output keeps the input layout.
    """

    def __init__(
        self,
        num_streams,
        dim,
        channel_first = False
    ):
        super().__init__()
        self.channel_first = channel_first
        self.num_streams = num_streams

        # zero-initialised, so the module starts out adding nothing
        self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))

    def forward(self, residuals):
        num_streams = self.num_streams

        # pick the unfold / refold patterns for where the feature axis lives
        if self.channel_first:
            unfold = '(b s) d ... -> b ... s d'
            refold = 'b ... s d -> (b s) d ...'
        else:
            unfold = '(b s) ... d -> b ... s d'
            refold = 'b ... s d -> (b s) ... d'

        # move (stream, feature) to the last two axes so the (num_streams, dim) embedding broadcasts
        streams = rearrange(residuals, unfold, s = num_streams)
        streams = streams + self.stream_embed
        return rearrange(streams, refold, s = num_streams)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -135,6 +135,12 @@ residual = add_residual(branch_output)
135
135
  residual = reduce_stream(residual)
136
136
  ```
137
137
 
138
+ To compare hyper connections with a plain residual connection without changing any other code, pass `disable = True` when fetching the functions:
139
+
140
+ ```python
141
+ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
142
+ ```
143
+
138
144
  ## Citation
139
145
 
140
146
  ```bibtex
@@ -0,0 +1,6 @@
1
+ hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
+ hyper_connections/hyper_connections.py,sha256=RBm0qEhQwCSlvtqNXo_YIkRAkMynCNBrN7xXt4rsRBc,6756
3
+ hyper_connections-0.0.7.dist-info/METADATA,sha256=PTozroByBHtvwj8fFJhflo0H0GdwRUM__8aNP2LzuPY,4978
4
+ hyper_connections-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hyper_connections-0.0.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
+ hyper_connections-0.0.7.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
- hyper_connections/hyper_connections.py,sha256=2lZcPuW4hEKet3r8caN-sN-PzRaBNL1q-V3_uA1lVaM,4613
3
- hyper_connections-0.0.5.dist-info/METADATA,sha256=bko4lEEBiulROYd0aOC_8rIk0hPinBg5TBQWf9DZe9M,4753
4
- hyper_connections-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hyper_connections-0.0.5.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
- hyper_connections-0.0.5.dist-info/RECORD,,