hyper-connections 0.0.4__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -73,11 +73,11 @@ residual = branch(residual) + residual
73
73
 
74
74
  from hyper_connections import HyperConnections
75
75
 
76
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
76
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
77
77
 
78
78
  # 1. wrap your branch function
79
79
 
80
- hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
80
+ hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
81
81
 
82
82
  # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
83
83
 
@@ -112,11 +112,11 @@ residual = branch(residual) + residual
112
112
 
113
113
  from hyper_connections import HyperConnections
114
114
 
115
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
115
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
116
116
 
117
- # 1. instantiate hyper connection with correct number of streams (4 in this case)
117
+ # 1. instantiate the hyper connection with the init function above, which already binds the number of streams (4 in this case) - equivalent to HyperConnections(4, dim = 512)
118
118
 
119
- hyper_conn = HyperConnections(4, dim = 512)
119
+ hyper_conn = init_hyper_conn(dim = 512)
120
120
 
121
121
  # 2. expand to 4 streams
122
122
 
@@ -30,11 +30,11 @@ residual = branch(residual) + residual
30
30
 
31
31
  from hyper_connections import HyperConnections
32
32
 
33
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
33
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
34
34
 
35
35
  # 1. wrap your branch function
36
36
 
37
- hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
37
+ hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
38
38
 
39
39
  # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
40
40
 
@@ -69,11 +69,11 @@ residual = branch(residual) + residual
69
69
 
70
70
  from hyper_connections import HyperConnections
71
71
 
72
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
72
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
73
73
 
74
- # 1. instantiate hyper connection with correct number of streams (4 in this case)
74
+ # 1. instantiate the hyper connection with the init function above, which already binds the number of streams (4 in this case) - equivalent to HyperConnections(4, dim = 512)
75
75
 
76
- hyper_conn = HyperConnections(4, dim = 512)
76
+ hyper_conn = init_hyper_conn(dim = 512)
77
77
 
78
78
  # 2. expand to 4 streams
79
79
 
@@ -69,6 +69,14 @@ class HyperConnections(Module):
69
69
 
70
70
  return expand_fn, reduce_fn
71
71
 
72
@classmethod
def get_init_and_expand_reduce_stream_functions(cls, num_streams):
    """Bundle a stream-count-bound constructor with the expand/reduce helpers.

    Args:
        num_streams: number of residual streams. It is bound as the first
            positional argument of ``cls`` and also forwarded to
            ``cls.get_expand_reduce_stream_functions``.

    Returns:
        A tuple whose first element is ``partial(cls, num_streams)`` (call it
        with the remaining constructor arguments), followed by every item
        returned by ``cls.get_expand_reduce_stream_functions(num_streams)``,
        in the same order.
    """
    # Binding num_streams once here keeps the constructor and the returned
    # expand/reduce helpers configured with the same stream count.
    init_fn = partial(cls, num_streams)
    stream_fns = cls.get_expand_reduce_stream_functions(num_streams)
    return (init_fn,) + tuple(stream_fns)
79
+
72
80
  def width_connection(self, residuals):
73
81
  # width connection
74
82
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hyper-connections"
3
- version = "0.0.4"
3
+ version = "0.0.5"
4
4
  description = "Hyper-Connections"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }