hyper-connections 0.0.4__tar.gz → 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -73,11 +73,11 @@ residual = branch(residual) + residual
73
73
 
74
74
  from hyper_connections import HyperConnections
75
75
 
76
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
76
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
77
77
 
78
78
  # 1. wrap your branch function
79
79
 
80
- hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
80
+ hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
81
81
 
82
82
  # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
83
83
 
@@ -112,11 +112,11 @@ residual = branch(residual) + residual
112
112
 
113
113
  from hyper_connections import HyperConnections
114
114
 
115
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
115
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
116
116
 
117
- # 1. instantiate hyper connection with correct number of streams (4 in this case)
117
+ # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
118
118
 
119
- hyper_conn = HyperConnections(4, dim = 512)
119
+ hyper_conn = init_hyper_conn(dim = 512)
120
120
 
121
121
  # 2. expand to 4 streams
122
122
 
@@ -135,6 +135,12 @@ residual = add_residual(branch_output)
135
135
  residual = reduce_stream(residual)
136
136
  ```
137
137
 
138
+ To compare hyper connections against plain residual connections without changing the rest of your code, pass `disable = True` when fetching the functions.
139
+
140
+ ```python
141
+ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
142
+ ```
143
+
138
144
  ## Citation
139
145
 
140
146
  ```bibtex
@@ -30,11 +30,11 @@ residual = branch(residual) + residual
30
30
 
31
31
  from hyper_connections import HyperConnections
32
32
 
33
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
33
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
34
34
 
35
35
  # 1. wrap your branch function
36
36
 
37
- hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
37
+ hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
38
38
 
39
39
  # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
40
40
 
@@ -69,11 +69,11 @@ residual = branch(residual) + residual
69
69
 
70
70
  from hyper_connections import HyperConnections
71
71
 
72
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
72
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
73
73
 
74
- # 1. instantiate hyper connection with correct number of streams (4 in this case)
74
+ # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
75
75
 
76
- hyper_conn = HyperConnections(4, dim = 512)
76
+ hyper_conn = init_hyper_conn(dim = 512)
77
77
 
78
78
  # 2. expand to 4 streams
79
79
 
@@ -92,6 +92,12 @@ residual = add_residual(branch_output)
92
92
  residual = reduce_stream(residual)
93
93
  ```
94
94
 
95
+ To compare hyper connections against plain residual connections without changing the rest of your code, pass `disable = True` when fetching the functions.
96
+
97
+ ```python
98
+ HyperConnections.get_init_and_expand_reduce_stream_functions(4, disable = True)
99
+ ```
100
+
95
101
  ## Citation
96
102
 
97
103
  ```bibtex
@@ -1,3 +1,4 @@
1
1
  from hyper_connections.hyper_connections import (
2
- HyperConnections
2
+ HyperConnections,
3
+ Residual
3
4
  )
@@ -18,7 +18,46 @@ def exists(v):
18
18
  def default(v, d):
19
19
  return v if exists(v) else d
20
20
 
21
- # main class
21
+ def identity(t):
22
+ return t
23
+
24
+ # main classes
25
+
26
+ # residual base class
27
+
28
+ class Residual(Module):
29
+ def __init__(
30
+ self,
31
+ *args,
32
+ branch = None,
33
+ **kwargs
34
+ ):
35
+ super().__init__()
36
+ self.branch = branch
37
+
38
+ def width_connection(self, residuals):
39
+ return residuals, residuals, dict()
40
+
41
+ def depth_connection(self, branch_output, residuals):
42
+ return branch_output + residuals
43
+
44
+ def forward(self, residuals, *branch_args, **branch_kwargs):
45
+
46
+ branch_input, residuals, residual_kwargs = self.width_connection(residuals)
47
+
48
+ def add_residual_fn(branch_out):
49
+ return self.depth_connection(branch_out, residuals, **residual_kwargs)
50
+
51
+ if not exists(self.branch):
52
+ return branch_input, add_residual_fn
53
+
54
+ branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
55
+
56
+ (branch_output, *rest), tree_spec = tree_flatten(branch_output)
57
+
58
+ branch_output = add_residual_fn(branch_output)
59
+
60
+ return tree_unflatten((branch_output, *rest), tree_spec)
22
61
 
23
62
  # hyper connection residual streams
24
63
 
@@ -69,6 +108,16 @@ class HyperConnections(Module):
69
108
 
70
109
  return expand_fn, reduce_fn
71
110
 
111
+ @classmethod
112
+ def get_init_and_expand_reduce_stream_functions(cls, num_streams, disable = False):
113
+
114
+ hyper_conn_klass = cls if not disable else Residual
115
+
116
+ init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
117
+ expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams) if not disable else (identity, identity)
118
+
119
+ return (init_hyper_conn_fn, *expand_reduce_fns)
120
+
72
121
  def width_connection(self, residuals):
73
122
  # width connection
74
123
 
@@ -98,9 +147,9 @@ class HyperConnections(Module):
98
147
  if self.channel_first:
99
148
  branch_input = rearrange(branch_input, 'b ... d -> b d ...')
100
149
 
101
- return branch_input, residuals, beta
150
+ return branch_input, residuals, dict(beta = beta)
102
151
 
103
- def depth_connection(self, branch_output, residuals, beta):
152
+ def depth_connection(self, branch_output, residuals, *, beta):
104
153
  # 'depth' connection
105
154
 
106
155
  if self.channel_first:
@@ -116,10 +165,10 @@ class HyperConnections(Module):
116
165
 
117
166
  def forward(self, residuals, *branch_args, **branch_kwargs):
118
167
 
119
- branch_input, residuals, beta = self.width_connection(residuals)
168
+ branch_input, residuals, residual_kwargs = self.width_connection(residuals)
120
169
 
121
170
  def add_residual_fn(branch_out):
122
- return self.depth_connection(branch_out, residuals, beta)
171
+ return self.depth_connection(branch_out, residuals, **residual_kwargs)
123
172
 
124
173
  if not exists(self.branch):
125
174
  return branch_input, add_residual_fn
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.4"
3
+ version = "0.0.6"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }