hyper-connections 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -69,6 +69,14 @@ class HyperConnections(Module):
69
69
 
70
70
  return expand_fn, reduce_fn
71
71
 
72
+ @classmethod
73
+ def get_init_and_expand_reduce_stream_functions(cls, num_streams):
74
+
75
+ init_hyper_conn_fn = partial(cls, num_streams)
76
+ expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams)
77
+
78
+ return (init_hyper_conn_fn, *expand_reduce_fns)
79
+
72
80
  def width_connection(self, residuals):
73
81
  # width connection
74
82
 
@@ -114,7 +122,7 @@ class HyperConnections(Module):
114
122
 
115
123
  return output
116
124
 
117
- def forward(self, residuals, **branch_kwargs):
125
+ def forward(self, residuals, *branch_args, **branch_kwargs):
118
126
 
119
127
  branch_input, residuals, beta = self.width_connection(residuals)
120
128
 
@@ -124,7 +132,7 @@ class HyperConnections(Module):
124
132
  if not exists(self.branch):
125
133
  return branch_input, add_residual_fn
126
134
 
127
- branch_output = self.branch(branch_input, **branch_kwargs)
135
+ branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
128
136
 
129
137
  (branch_output, *rest), tree_spec = tree_flatten(branch_output)
130
138
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -73,11 +73,11 @@ residual = branch(residual) + residual
73
73
 
74
74
  from hyper_connections import HyperConnections
75
75
 
76
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
76
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
77
77
 
78
78
  # 1. wrap your branch function
79
79
 
80
- hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
80
+ hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
81
81
 
82
82
  # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
83
83
 
@@ -112,11 +112,11 @@ residual = branch(residual) + residual
112
112
 
113
113
  from hyper_connections import HyperConnections
114
114
 
115
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
115
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
116
116
 
117
- # 1. instantiate hyper connection with correct number of streams (4 in this case)
117
+ # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
118
118
 
119
- hyper_conn = HyperConnections(4, dim = 512)
119
+ hyper_conn = init_hyper_conn(dim = 512)
120
120
 
121
121
  # 2. expand to 4 streams
122
122
 
@@ -0,0 +1,6 @@
1
+ hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
+ hyper_connections/hyper_connections.py,sha256=2lZcPuW4hEKet3r8caN-sN-PzRaBNL1q-V3_uA1lVaM,4613
3
+ hyper_connections-0.0.5.dist-info/METADATA,sha256=bko4lEEBiulROYd0aOC_8rIk0hPinBg5TBQWf9DZe9M,4753
4
+ hyper_connections-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hyper_connections-0.0.5.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
+ hyper_connections-0.0.5.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
- hyper_connections/hyper_connections.py,sha256=91QtTtnpffmErIZvrnTtosSf4JgBqcyGvxftmka-EOw,4303
3
- hyper_connections-0.0.3.dist-info/METADATA,sha256=8XKTmC6Ys10uOyotPtvL17v4uZyepkWoeMRVM4B_TSQ,4676
4
- hyper_connections-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hyper_connections-0.0.3.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
- hyper_connections-0.0.3.dist-info/RECORD,,