hyper-connections 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,8 +6,8 @@ from random import randrange
6
6
 
7
7
  import torch
8
8
  from torch import nn
9
- from torch.nn import Module
10
9
  import torch.nn.functional as F
10
+ from torch.nn import Module, ModuleList
11
11
  from torch.utils._pytree import tree_flatten, tree_unflatten
12
12
 
13
13
  from einops import rearrange, repeat, reduce, einsum
@@ -22,6 +22,9 @@ def exists(v):
22
22
  def default(v, d):
23
23
  return v if exists(v) else d
24
24
 
25
+ def divisible_by(num, den):
26
+ return (num % den) == 0
27
+
25
28
  def identity(t):
26
29
  return t
27
30
 
@@ -35,7 +38,7 @@ class HyperConnections(Module):
35
38
  num_residual_streams,
36
39
  *,
37
40
  dim,
38
- branch: Module | None = None,
41
+ branch: Module | tuple[Module, ...] | list[Module] | None = None,
39
42
  layer_index = None,
40
43
  tanh = True,
41
44
  channel_first = False,
@@ -46,7 +49,15 @@ class HyperConnections(Module):
46
49
  """
47
50
  super().__init__()
48
51
 
49
- self.branch = branch
52
+ self.branches = None
53
+
54
+ if isinstance(branch, Module):
55
+ branch = [branch]
56
+
57
+ if exists(branch):
58
+ assert divisible_by(num_branch_inputs, len(branch))
59
+
60
+ self.branches = ModuleList(branch)
50
61
 
51
62
  # activation, seemingly results were wishy washy depending on using tanh or not
52
63
 
@@ -57,12 +68,19 @@ class HyperConnections(Module):
57
68
  self.num_residual_streams = num_residual_streams
58
69
  self.num_branch_inputs = num_branch_inputs
59
70
 
60
- init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given
61
-
62
71
  self.static_beta = nn.Parameter(torch.ones(num_residual_streams, num_branch_inputs))
63
72
 
64
- init_alpha0 = torch.zeros((num_residual_streams, num_branch_inputs))
65
- init_alpha0[init_residual_index, :] = 1.
73
+ # make sure each branch input receives from different residual stream on init
74
+
75
+ stream_branches = num_residual_streams * num_branch_inputs
76
+ layer_index = default(layer_index, randrange(stream_branches))
77
+ layer_offset = layer_index % stream_branches * num_branch_inputs
78
+
79
+ stream_seq = torch.arange(num_residual_streams)
80
+ branch_input_seq = torch.arange(num_branch_inputs)
81
+
82
+ init_alpha0 = rearrange(stream_seq, 's -> s 1') + rearrange(branch_input_seq, 'bi -> 1 bi') + layer_offset
83
+ init_alpha0 = ((init_alpha0 % num_residual_streams) == 0).float()
66
84
 
67
85
  self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
68
86
 
@@ -76,7 +94,7 @@ class HyperConnections(Module):
76
94
  self.channel_first = channel_first
77
95
 
78
96
  @classmethod
79
- def get_expand_reduce_stream_functions(cls, num_streams):
97
+ def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
80
98
  if disable:
81
99
  return (identity, identity)
82
100
 
@@ -148,13 +166,22 @@ class HyperConnections(Module):
148
166
 
149
167
  return output
150
168
 
151
- def decorate_branch(self, branch: Callable):
152
- assert not exists(self.branch), 'branch was already wrapped on init'
169
+ def decorate_branch(self, branch: Callable | tuple[Callable, ...] | list[Callable]):
170
+ assert not exists(self.branches), 'branch was already wrapped on init'
153
171
 
154
172
  def forward_and_add_residual(residual, *args, **kwargs):
155
173
  branch_input, add_residual = self.forward(residual)
156
174
 
157
- branch_output = branch(branch_input, *args, **kwargs)
175
+ if callable(branch):
176
+ branches = [branch]
177
+ else:
178
+ branches = branch
179
+
180
+ branch_inputs = rearrange(branch_input, '(br b) ... -> br b ...', br = len(branches))
181
+
182
+ branch_outputs = [fn(x, *args, **kwargs) for fn, x in zip(branches, branch_inputs)]
183
+
184
+ branch_output = torch.cat(branch_outputs)
158
185
 
159
186
  residual = add_residual(branch_output)
160
187
 
@@ -173,9 +200,13 @@ class HyperConnections(Module):
173
200
 
174
201
  return tree_unflatten((branch_out, *rest), tree_spec)
175
202
 
176
- if not exists(self.branch):
203
+ if not exists(self.branches):
177
204
  return branch_input, add_residual_fn
178
205
 
179
- branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)
206
+ branch_inputs = rearrange(branch_input, '(br b) ... -> br b ...', br = len(self.branches))
207
+
208
+ branch_outputs = [fn(x, *branch_args, **branch_kwargs) for fn, x in zip(self.branches, branch_inputs)]
209
+
210
+ branch_output = torch.cat(branch_outputs)
180
211
 
181
212
  return add_residual_fn(branch_output)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.0.15
3
+ Version: 0.0.17
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -0,0 +1,7 @@
1
+ hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
+ hyper_connections/hyper_connections.py,sha256=LoyiGGDH81MLUcF9WKT39lA9Y1CXUAd-SYM9tJGQ61A,8099
3
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=l8oKkeVouFWQZY_0IVNTtgQPSQTA9qZdgLrfotRS_5w,7297
4
+ hyper_connections-0.0.17.dist-info/METADATA,sha256=dB2ZvPv63hTNJB6FEbyLlqxer4L7epwFMOvEVNjNOGk,5076
5
+ hyper_connections-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ hyper_connections-0.0.17.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
7
+ hyper_connections-0.0.17.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- hyper_connections/__init__.py,sha256=mAy66IuHqXM4XOyOZGt5mo2B2hfHdUk8jW31YnWNQTg,104
2
- hyper_connections/hyper_connections.py,sha256=LoyiGGDH81MLUcF9WKT39lA9Y1CXUAd-SYM9tJGQ61A,8099
3
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6KHdqvjUZCyHwf7LVByMMuoQtnEHBxRxdt-4Bw9HPcA,6130
4
- hyper_connections-0.0.15.dist-info/METADATA,sha256=DnZongXit3VTaT7DxmpK0RRfUhXqfUTjM32a-ofbrIw,5076
5
- hyper_connections-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- hyper_connections-0.0.15.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
7
- hyper_connections-0.0.15.dist-info/RECORD,,