hyper-connections 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -69,6 +69,14 @@ class HyperConnections(Module):
69
69
 
70
70
  return expand_fn, reduce_fn
71
71
 
72
+ @classmethod
73
+ def get_init_and_expand_reduce_stream_functions(cls, num_streams):
74
+
75
+ init_hyper_conn_fn = partial(cls, num_streams)
76
+ expand_reduce_fns = cls.get_expand_reduce_stream_functions(num_streams)
77
+
78
+ return (init_hyper_conn_fn, *expand_reduce_fns)
79
+
72
80
  def width_connection(self, residuals):
73
81
  # width connection
74
82
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -73,11 +73,11 @@ residual = branch(residual) + residual
73
73
 
74
74
  from hyper_connections import HyperConnections
75
75
 
76
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
76
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
77
77
 
78
78
  # 1. wrap your branch function
79
79
 
80
- hyper_conn_branch = HyperConnections(4, dim = 512, branch = branch)
80
+ hyper_conn_branch = init_hyper_conn(dim = 512, branch = branch)
81
81
 
82
82
  # 2. expand to 4 streams, this must be done before your trunk, typically a for-loop with many branch functions
83
83
 
@@ -112,11 +112,11 @@ residual = branch(residual) + residual
112
112
 
113
113
  from hyper_connections import HyperConnections
114
114
 
115
- expand_stream, reduce_stream = HyperConnections.get_expand_reduce_stream_functions(4)
115
+ init_hyper_conn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(4)
116
116
 
117
- # 1. instantiate hyper connection with correct number of streams (4 in this case)
117
+ # 1. instantiate hyper connection with correct number of streams (4 in this case) - or use the init function above
118
118
 
119
- hyper_conn = HyperConnections(4, dim = 512)
119
+ hyper_conn = init_hyper_conn(dim = 512)
120
120
 
121
121
  # 2. expand to 4 streams
122
122
 
@@ -0,0 +1,6 @@
1
+ hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
+ hyper_connections/hyper_connections.py,sha256=2lZcPuW4hEKet3r8caN-sN-PzRaBNL1q-V3_uA1lVaM,4613
3
+ hyper_connections-0.0.5.dist-info/METADATA,sha256=bko4lEEBiulROYd0aOC_8rIk0hPinBg5TBQWf9DZe9M,4753
4
+ hyper_connections-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hyper_connections-0.0.5.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
+ hyper_connections-0.0.5.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hyper_connections/__init__.py,sha256=xXx2Mb-dS1__UPzT-5VR1XZmyqKSSkT1DU6bAcK8jR0,73
2
- hyper_connections/hyper_connections.py,sha256=HbSNev1_wsJeZir5pqa-5IQYW3Ovy3UJD7A59utSF9Q,4331
3
- hyper_connections-0.0.4.dist-info/METADATA,sha256=uRvNtEjISVXZkI_GVfwEk_1z_w2C4uzFYSP9m8cYpWs,4676
4
- hyper_connections-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hyper_connections-0.0.4.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
6
- hyper_connections-0.0.4.dist-info/RECORD,,