hyper-connections 0.1.9__tar.gz → 0.1.10__tar.gz
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the contents of the package versions as they appear in their respective public registries.
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/PKG-INFO +1 -1
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/hyper_connections/hyper_connections.py +13 -4
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/pyproject.toml +1 -1
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/.gitignore +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/LICENSE +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/README.md +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/hyper-connections.png +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.1.9 → hyper_connections-0.1.10}/tests/test_hyper_connections.py +0 -0
|
@@ -34,22 +34,31 @@ def identity(t):
|
|
|
34
34
|
|
|
35
35
|
# main functions
|
|
36
36
|
|
|
37
|
-
def get_expand_reduce_stream_functions(num_streams, disable = False):
|
|
37
|
+
def get_expand_reduce_stream_functions(num_streams, add_stream_embed = False, dim = None, disable = False):
|
|
38
38
|
|
|
39
39
|
if num_streams == 1 or disable:
|
|
40
40
|
return (nn.Identity(), nn.Identity())
|
|
41
41
|
|
|
42
|
-
|
|
42
|
+
if add_stream_embed:
|
|
43
|
+
assert exists(dim), '`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added'
|
|
44
|
+
|
|
45
|
+
expand_fn = StreamEmbed(num_streams, dim, expand_to_streams = True)
|
|
46
|
+
else:
|
|
47
|
+
expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
|
|
48
|
+
|
|
43
49
|
reduce_fn = Reduce(pattern = '(b s) ... -> b ...', reduction = 'sum', s = num_streams)
|
|
44
50
|
|
|
45
51
|
return expand_fn, reduce_fn
|
|
46
52
|
|
|
47
|
-
def get_init_and_expand_reduce_stream_functions(num_streams, disable = False):
|
|
53
|
+
def get_init_and_expand_reduce_stream_functions(num_streams, dim = None, add_stream_embed = False, disable = False):
|
|
48
54
|
|
|
49
55
|
hyper_conn_klass = HyperConnections if not disable else Residual
|
|
50
56
|
|
|
51
57
|
init_hyper_conn_fn = partial(hyper_conn_klass, num_streams)
|
|
52
|
-
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, disable = disable)
|
|
58
|
+
expand_reduce_fns = get_expand_reduce_stream_functions(num_streams, add_stream_embed = add_stream_embed, dim = dim, disable = disable)
|
|
59
|
+
|
|
60
|
+
if exists(dim):
|
|
61
|
+
init_hyper_conn_fn = partial(init_hyper_conn_fn, dim = dim)
|
|
53
62
|
|
|
54
63
|
return (init_hyper_conn_fn, *expand_reduce_fns)
|
|
55
64
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|