hyper-connections 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,4 @@
1
1
  from hyper_connections.hyper_connections import (
2
- HyperConnections
2
+ HyperConnections,
3
+ Residual
3
4
  )
@@ -18,7 +18,46 @@ def exists(v):
18
18
  def default(v, d):
19
19
  return v if exists(v) else d
20
20
 
21
- # main class
21
+ def identity(t):
22
+ return t
23
+
24
+ # main classes
25
+
26
+ # residual base class
27
+
28
+ class Residual(Module):
29
+ def __init__(
30
+ self,
31
+ *args,
32
+ branch = None,
33
+ **kwargs
34
+ ):
35
+ super().__init__()
36
+ self.branch = branch
37
+
38
+ def width_connection(self, residuals):
39
+ return residuals, residuals, dict()
40
+
41
+ def depth_connection(self, branch_output, residuals):
42
+ return branch_output + residuals
43
+
44
+ def forward(self, residuals, *branch_args, **branch_kwargs):
45
+
46
+ branch_input, residuals, residual_kwargs = self.width_connection(residuals)
47
+
48
+ def add_residual_fn(branch_out):
49
+ return self.depth_connection(branch_out, residuals, **residual_kwargs)
50
+
51
+ if not exists(self.branch):
52
+ return branch_input, add_residual_fn
53
+
54
+ branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
55
+
56
+ (branch_output, *rest), tree_spec = tree_flatten(branch_output)
57
+
58
+ branch_output = add_residual_fn(branch_output)
59
+
60
+ return tree_unflatten((branch_output, *rest), tree_spec)
22
61
 
23
62
  # hyper connection residual streams
24
63
 
@@ -69,6 +108,16 @@ class HyperConnections(Module):
69
108
 
70
109
  return expand_fn, reduce_fn
71
110
 
111
+ @classmethod
112
+ def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
113
+
114
+ hyper_conn_klass = cls if not disable else Residual
115
+
116
+ init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
117
+ expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams) if not disable else (identity, identity)
118
+
119
+ return (init_hyper_conn_fn, *expand_reduce_fns)
120
+
72
121
  def width_connection(self, residuals):
73
122
  # width connection
74
123
 
@@ -98,9 +147,9 @@ class HyperConnections(Module):
98
147
  if self.channel_first:
99
148
  branch_input = rearrange(branch_input, 'b ... d -> b d ...')
100
149
 
101
- return branch_input, residuals, beta
150
+ return branch_input, residuals, dict(beta = beta)
102
151
 
103
- def depth_connection(self, branch_output, residuals, beta):
152
+ def depth_connection(self, branch_output, residuals, *, beta):
104
153
  # 'depth' connection
105
154
 
106
155
  if self.channel_first:
@@ -116,10 +165,10 @@ class HyperConnections(Module):
116
165
 
117
166
  def forward(self, residuals, *branch_args, **branch_kwargs):
118
167
 
119
- branch_input, residuals, beta = self.width_connection(residuals)
168
+ branch_input, residuals, residual_kwargs = self.width_connection(residuals)
120
169
 
121
170
  def add_residual_fn(branch_out):
122
- return self.depth_connection(branch_out, residuals, beta)
171
+ return self.depth_connection(branch_out, residuals, **residual_kwargs)
123
172
 
124
173
  if not exists(self.branch):
125
174
  return branch_input, add_residual_fn
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -73,11 +73,11 @@ residual = branch(residual) + residual
73
73
 
74
74
  from hyper_connections import HyperConnections
75
75
 
76
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
76
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
77
77
 
78
78
  # 1. wrap your branch function
79
79
 
80
- hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
80
+ hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
81
81
 
82
82
  # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
83
83
 
@@ -112,11 +112,11 @@ residual = branch(residual) + residual
112
112
 
113
113
  from hyper_connections import HyperConnections
114
114
 
115
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
115
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
116
116
 
117
- # 1. instantiate hyper connection with correct number of streams (4 in this case)
117
+ # 1. instantiate hyper connection - the init function above already has the number of streams (4 in this case) bound, so only pass the remaining arguments
118
118
 
119
- hyper_conn = HyperConnections(4, dim = 512)
119
+ hyper_conn = init_hyper_conn(dim = 512)
120
120
 
121
121
  # 2. expand to 4 streams
122
122
 
@@ -135,6 +135,12 @@ residual = add_residual(branch_output)
135
135
  residual = reduce_stream(residual)
136
136
  ```
137
137
 
138
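+ To compare hyper connections against a plain residual without changing any other code, pass `disable = True` when fetching the functions. The init function then builds a plain `Residual`, and the expand / reduce functions become identity functions.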
139
+
140
+ ```python
141
+ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disabled = True)
142
+ ```
143
+
138
144
  ## Citation
139
145
 
140
146
  ```bibtex
@@ -0,0 +1,6 @@
1
+ hyper_connections/__init__.py,sha256=yEc-yNlGq084y0pR0_VVGLr-sH4ye-eVX0RNz7sTPCo,87
2
+ hyper_connections/hyper_connections.py,sha256=YibHh3ocMhkCWhEu8EF554HtGm7i4SPH5ChSMEyFPlI,5843
3
+ hyper_connections-0.0.6.dist-info/METADATA,sha256=c_2oOz7OtvUeLt71AHeO5AXj8d_oIOj0Qo1fcBvCN1A,4979
4
+ hyper_connections-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hyper_connections-0.0.6.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
+ hyper_connections-0.0.6.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
- hyper_connections/hyper_connections.py,sha256=HbSNev1_wsJeZir5pqa-5IQYW3Ovy3UJD7A59utSF9Q,4331
3
- hyper_connections-0.0.4.dist-info/METADATA,sha256=uRvNtEjISVXZkI_GVfwEk_1z_w2C4uzFYSP9m8cYpWs,4676
4
- hyper_connections-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hyper_connections-0.0.4.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
- hyper_connections-0.0.4.dist-info/RECORD,,