hyper-connections 0.0.19__tar.gz → 0.0.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/PKG-INFO +1 -1
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/hyper_connections/hyper_connections.py +10 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +10 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/pyproject.toml +1 -1
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/.gitignore +0 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/LICENSE +0 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/README.md +0 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/hyper-connections.png +0 -0
- {hyper_connections-0.0.19 → hyper_connections-0.0.20}/hyper_connections/__init__.py +0 -0
{hyper_connections-0.0.19 → hyper_connections-0.0.20}/hyper_connections/hyper_connections.py
RENAMED
|
@@ -12,6 +12,14 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
|
|
15
|
+
"""
|
|
16
|
+
ein notation:
|
|
17
|
+
b - batch
|
|
18
|
+
d - feature dimension
|
|
19
|
+
s - residual streams
|
|
20
|
+
t - residual streams + num branch inputs
|
|
21
|
+
"""
|
|
22
|
+
|
|
15
23
|
# helper functions
|
|
16
24
|
|
|
17
25
|
def exists(v):
|
|
@@ -102,6 +110,8 @@ class HyperConnections(Module):
|
|
|
102
110
|
|
|
103
111
|
self.norm = nn.RMSNorm(dim) # they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
104
112
|
|
|
113
|
+
assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
|
|
114
|
+
|
|
105
115
|
self.num_residual_streams = num_residual_streams
|
|
106
116
|
init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
|
|
107
117
|
|
|
@@ -12,6 +12,16 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
12
12
|
|
|
13
13
|
from einops import rearrange, repeat, reduce, einsum
|
|
14
14
|
|
|
15
|
+
"""
|
|
16
|
+
ein notation:
|
|
17
|
+
b - batch
|
|
18
|
+
d - feature dimension
|
|
19
|
+
s - residual streams
|
|
20
|
+
i - branch inputs
|
|
21
|
+
br - branch functions
|
|
22
|
+
t - residual streams + num branch inputs
|
|
23
|
+
"""
|
|
24
|
+
|
|
15
25
|
from hyper_connections.hyper_connections import Residual, StreamEmbed
|
|
16
26
|
|
|
17
27
|
# helper functions
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|