hyper-connections 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -48,7 +48,7 @@ def get_expand_reduce_stream_functions(
48
48
  dim = None,
49
49
  disable = False
50
50
  ):
51
- if num_streams == 1 or disable:
51
+ if disable:
52
52
  return (nn.Identity(), nn.Identity())
53
53
 
54
54
  if add_stream_embed:
@@ -41,7 +41,7 @@ def identity(t):
41
41
 
42
42
  def get_expand_reduce_stream_functions(num_streams, disable = False):
43
43
 
44
- if num_streams == 1 or disable:
44
+ if disable:
45
45
  return (nn.Identity(), nn.Identity())
46
46
 
47
47
  expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
@@ -42,7 +42,7 @@ def identity(t):
42
42
  # main functions
43
43
 
44
44
  def get_expand_reduce_stream_functions(cls, num_streams, disable = False):
45
- if num_streams == 1 or disable:
45
+ if disable:
46
46
  return (nn.Identity(), nn.Identity())
47
47
 
48
48
  expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
@@ -33,7 +33,7 @@ def default(v, d):
33
33
 
34
34
  def get_expand_reduce_stream_functions(num_streams, disable = False):
35
35
 
36
- if num_streams == 1 or disable:
36
+ if disable:
37
37
  return (nn.Identity(), nn.Identity())
38
38
 
39
39
  expand_fn = Reduce(pattern = 'b ... -> (b s) ...', reduction = 'repeat', s = num_streams)
@@ -0,0 +1,541 @@
1
+ from __future__ import annotations
2
+ from typing import Callable
3
+
4
+ from functools import partial
5
+ from random import randrange
6
+
7
+ import torch
8
+ from torch import nn, cat
9
+ import torch.nn.functional as F
10
+ from torch.nn import Module, Sequential
11
+ from torch.utils._pytree import tree_flatten, tree_unflatten
12
+
13
+ from einops import rearrange, repeat, reduce, einsum
14
+ from einops.layers.torch import Rearrange, Reduce
15
+
16
+ """
17
+ ein notation:
18
+ b - batch
19
+ d - feature dimension
20
+ s - residual streams
21
+ t - residual streams + num branch inputs
22
+ f - number of fractions (division of feature dimension space)
23
+ v - number of views for branch input
24
+ p - proposals
25
+ """
26
+
27
+ # helper functions
28
+
29
def exists(v):
    """Return True when `v` is anything other than None."""
    return not (v is None)
31
+
32
def divisible_by(num, den):
    """Return True when `num` is an exact multiple of `den` (remainder of `num % den` is zero)."""
    remainder = num % den
    return remainder == 0
34
+
35
def default(v, d):
    """Return `v` unless it is None, in which case return the fallback `d`."""
    if exists(v):
        return v
    return d
37
+
38
def identity(t):
    """No-op: hand the argument back unchanged (same object)."""
    result = t
    return result
40
+
41
def add(x, y):
    """Combine two operands with `+`; usable as a plain function where a callable is expected."""
    total = x + y
    return total
43
+
44
+ # sinkhorn
45
+
46
def l1norm(t, dim):
    """
    Divide `t` by its L1 norm along `dim` (via `F.normalize` with p = 1).

    The denominator is clamped by `F.normalize`'s default eps, so an all-zero slice stays zero.
    """
    return F.normalize(t, dim = dim, p = 1)
48
+
49
def sinkhorn_knopps(log_alpha, iters = 20):
    """
    Push square logit matrices (last two dims) towards doubly stochastic matrices.

    The logits are exponentiated in float32, then for `iters` rounds L1-normalized over
    dim -2 and then over dim -1. The result is cast back to the input dtype.
    """
    assert log_alpha.shape[-2] == log_alpha.shape[-1]

    in_dtype = log_alpha.dtype
    logits = log_alpha.float()

    # subtract the (non-differentiated) max over dim -2 before exp, to keep exp from overflowing
    shift = logits.amax(dim = -2, keepdim = True).detach()
    mat = (logits - shift).exp()

    for _ in range(iters):
        mat = l1norm(mat, dim = -2)
        mat = l1norm(mat, dim = -1)

    return mat.to(in_dtype)
64
+
65
def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
    """
    Sinkhorn-Knopp normalization carried out in log space.

    Works in float32 on the last two (square) dims. Each of the `iters` rounds subtracts
    the logsumexp over dim -2 and then over dim -1. Returns `exp` of the result cast back
    to the input dtype.
    """
    n_rows, n_cols = log_alpha.shape[-2], log_alpha.shape[-1]
    assert n_rows == n_cols

    in_dtype = log_alpha.dtype
    logits = log_alpha.float()

    for _ in range(iters):
        for axis in (-2, -1):
            logits = logits - logits.logsumexp(dim = axis, keepdim = True)

    return logits.exp().to(in_dtype)
76
+
77
+ # main functions
78
+
79
def get_expand_reduce_stream_functions(
    num_streams,
    add_stream_embed = False,
    add_attn_pool_reduce_stream = False,
    dim = None,
    disable = False
):
    """
    Build the (expand_fn, reduce_fn) module pair that moves a tensor into and out of the
    residual-stream layout.

    - expand_fn: '... d -> ... s d' by repetition, or a StreamEmbed that repeats and adds a
      learned per-stream embedding when `add_stream_embed` is set
    - reduce_fn: '... s d -> ... d' by summation, or AttentionPoolReduceStream when
      `add_attn_pool_reduce_stream` is set

    Both embedding options need `dim`. With `disable`, both returned modules are nn.Identity.
    """
    if disable:
        return (nn.Identity(), nn.Identity())

    needs_dim = add_stream_embed or add_attn_pool_reduce_stream

    if needs_dim:
        assert exists(dim), '`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added'

    expand_fn = (
        StreamEmbed(num_streams, dim, expand_to_streams = True)
        if add_stream_embed
        else Reduce('... d -> ... s d', 'repeat', s = num_streams)
    )

    reduce_fn = (
        AttentionPoolReduceStream(dim)
        if add_attn_pool_reduce_stream
        else Reduce('... s d -> ... d', 'sum')
    )

    return expand_fn, reduce_fn
103
+
104
def get_init_and_expand_reduce_stream_functions(
    num_streams,
    num_fracs = 1,
    dim = None,
    add_stream_embed = False,
    add_attn_pool_reduce_stream = False,
    disable = None,
    sinkhorn_iters = 20,
    **kwargs
):
    """
    Convenience factory returning `(init_hyper_conn_fn, expand_fn, reduce_fn)`.

    - init_hyper_conn_fn: `functools.partial` of ManifoldConstrainedHyperConnections, or of
      Residual when disabled. It has `num_streams` (positional), `num_fracs`, `sinkhorn_iters`,
      any extra `**kwargs`, and `dim` (only if given) bound.
    - expand_fn, reduce_fn: from `get_expand_reduce_stream_functions` with the same settings.

    When `disable` is None it becomes True only for a single stream with a single fraction.
    """
    if disable is None:
        disable = num_streams == 1 and num_fracs == 1

    hyper_conn_klass = Residual if disable else ManifoldConstrainedHyperConnections

    # note: `add_attn_pool_reduce_stream` is a named parameter, so it never lands in `kwargs`

    init_hyper_conn_fn = partial(
        hyper_conn_klass,
        num_streams,
        num_fracs = num_fracs,
        sinkhorn_iters = sinkhorn_iters,
        **kwargs
    )

    expand_fn, reduce_fn = get_expand_reduce_stream_functions(
        num_streams,
        add_stream_embed = add_stream_embed,
        add_attn_pool_reduce_stream = add_attn_pool_reduce_stream,
        dim = dim,
        disable = disable
    )

    if exists(dim):
        init_hyper_conn_fn = partial(init_hyper_conn_fn, dim = dim)

    return init_hyper_conn_fn, expand_fn, reduce_fn
132
+
133
+ # norms
134
+
135
class RMSNorm(Module):
    """
    Normalization over the last dim: L2-normalize, rescale by sqrt(dim), then multiply by a
    learned per-feature gain stored as an offset from 1 (`gamma` starts at zero).
    """
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        gain = self.gamma + 1
        return unit * self.scale * gain
143
+
144
+ # main classes
145
+
146
+ # residual base class
147
+
148
class Residual(Module):
    """
    Plain residual connection exposing the same width/depth/forward interface as
    ManifoldConstrainedHyperConnections, so it can stand in when hyper connections are disabled.
    """
    def __init__(
        self,
        *args,
        branch: Module | None = None,
        residual_transform: Module | None = None,
        **kwargs
    ):
        super().__init__()
        # extra positional / keyword arguments are accepted and ignored, so the same
        # constructor call used for the hyper connection class also works here
        self.branch = branch
        self.residual_transform = default(residual_transform, nn.Identity())

    def width_connection(
        self,
        residuals
    ):
        # identity width connection: the branch sees the residual itself, no extra kwargs for depth
        no_extra_kwargs = dict()
        return residuals, residuals, no_extra_kwargs

    def depth_connection(
        self,
        branch_output,
        residuals
    ):
        shortcut = self.residual_transform(residuals)
        return branch_output + shortcut

    def decorate_branch(
        self,
        branch: Callable
    ):
        """Wrap `branch` so that calling the wrapper runs it inside this residual connection."""
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)
            return add_residual(branch(branch_input, *args, **kwargs))

        return forward_and_add_residual

    def forward(
        self,
        residuals,
        *branch_args,
        **branch_kwargs
    ):
        """
        Without a stored branch: return `(branch_input, add_residual_fn)`.
        With a stored branch: run it and return the branch output with the residual added.

        Only the first leaf of a (possibly nested) branch output receives the residual;
        the remaining leaves are passed through unchanged.
        """
        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            leaves, tree_spec = tree_flatten(branch_out)
            head, *rest = leaves
            head = self.depth_connection(head, residuals, **residual_kwargs)
            return tree_unflatten((head, *rest), tree_spec)

        if exists(self.branch):
            return add_residual_fn(self.branch(branch_input, *branch_args, **branch_kwargs))

        return branch_input, add_residual_fn
212
+
213
+ # hyper connection residual streams
214
+
215
class ManifoldConstrainedHyperConnections(Module):
    """
    Manifold-constrained hyper connections over `num_residual_streams` residual streams,
    optionally combined with frac-connections (`num_fracs > 1`).

    `residuals` must carry the stream axis right before the feature axis ('... s d'), as
    produced by the expand function of `get_expand_reduce_stream_functions`.
    """
    def __init__(
        self,
        num_residual_streams,
        *,
        dim,
        branch: Module | None = None,
        layer_index = None,
        dropout = 0.,
        residual_transform: Module | None = None, # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
        add_branch_out_to_residual = True, # will disable depth connections (weighted residual sum with beta) if set False
        num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
        depth_residual_fn = add,
        num_fracs = 1, # https://arxiv.org/abs/2503.14125
        sinkhorn_iters = 20,
        log_domain_sinkhorn = False,
        residual_mix_constraint_fn: Callable | None = None,
        forward_method_names: tuple[str, ...] = (),
        num_dynamic_alpha_proposals = 1,

    ):
        """
        Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606

        Args:
            num_residual_streams: number of residual streams `s` (> 0).
            dim: full feature dimension; must be divisible by `num_fracs`.
            branch: optional module. If given, `forward` runs it and returns the updated
                residuals; otherwise `forward` returns `(branch_input, add_residual_fn)`.
            layer_index: picks (modulo stream count) the row of the static alpha initialised
                to feed the branch; a random one is drawn when None.
            dropout: dropout applied to the residuals after the depth connection.
            residual_transform: applied to the incoming residuals before mixing.
            add_branch_out_to_residual: when False, `add_residual_fn` returns the branch
                output untouched and no beta parameters are created.
            num_input_views: number of branch inputs; when > 1 the view axis is placed first.
            depth_residual_fn: combines the weighted branch output with the residuals.
            num_fracs: number of feature fractions `f` (frac-connections).
            sinkhorn_iters, log_domain_sinkhorn: settings of the default Sinkhorn constraint
                applied to the residual mixing matrix.
            residual_mix_constraint_fn: replaces the default Sinkhorn constraint.
            forward_method_names: methods of `branch` to re-expose on this module.
            num_dynamic_alpha_proposals: number of dynamic alpha proposals `p` (>= 1),
                averaged after their constraints are applied.
        """
        super().__init__()

        self.branch = branch

        # frac-connections paper - num_fracs > 1 will be the `m` in their paper https://arxiv.org/abs/2503.14125

        assert num_fracs >= 1

        self.num_fracs = num_fracs
        self.has_fracs = num_fracs > 1

        self.split_fracs = Rearrange('b ... (f d) -> b ... f d', f = num_fracs)
        self.merge_fracs = Rearrange('b ... f d -> b ... (f d)')

        assert divisible_by(dim, num_fracs), f'feature dimension ({dim}) must be divisible by the `num_fracs` ({num_fracs})'

        dim //= num_fracs # effective dim handled in dimension is feature dimension divided by num fractions

        # they used layernorm in paper, but rmsnorm is fine given what we know now

        self.norm = RMSNorm(dim)

        assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'

        self.num_residual_streams = num_residual_streams
        init_residual_index = default(layer_index, randrange(num_residual_streams)) % num_residual_streams # just choose one random residual stream if layer index not given

        # handle the parameter dimensions, which may require (num_residuals x num_fractions) - generalizing hyper + frac connections

        num_residual_streams_fracs = num_residual_streams * num_fracs
        num_input_views_fracs = num_input_views * num_fracs

        # width num residual streams

        assert num_input_views >= 1
        self.num_input_views = num_input_views

        # number of dynamic alpha proposals, for averaging Hres across proposals

        assert num_dynamic_alpha_proposals >= 1, '`num_dynamic_alpha_proposals` must be at least 1'

        self.has_dynamic_alpha_proposals = num_dynamic_alpha_proposals > 1
        self.num_dynamic_alpha_proposals = num_dynamic_alpha_proposals

        # width connection
        # columns of alpha are [ (views * fracs) pre-branch columns | (streams * fracs) residual columns ]

        init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
        init_alpha0[init_residual_index, :] = 1.

        self.static_alpha = nn.Parameter(cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim = 1))

        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(num_dynamic_alpha_proposals, dim, num_residual_streams_fracs + num_input_views_fracs))

        self.pre_branch_scale = nn.Parameter(torch.ones(1) * 1e-2)
        self.residual_scale = nn.Parameter(torch.ones(1) * 1e-2)

        # depth connection related (beta)

        self.add_branch_out_to_residual = add_branch_out_to_residual

        if add_branch_out_to_residual:
            self.static_beta = nn.Parameter(torch.ones(num_residual_streams, num_fracs, 1))

            self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim, num_fracs))

            self.h_post_scale = nn.Parameter(torch.ones(()) * 1e-2)

        # Hres constraint related
        # by default is sinkhorn

        self.residual_mix_constraint_fn = default(
            residual_mix_constraint_fn,
            partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
        )

        # dropouts

        self.dropout = nn.Dropout(dropout)

        # maybe residual transform

        self.residual_transform = default(residual_transform, nn.Identity())

        # maybe custom depth connection residual function
        # this is to prepare for gating the addition of the branch outputs to the residual streams
        # needed for memory lanes a la RMT / LMM

        self.depth_residual_fn = depth_residual_fn

        # forwarding method names

        self.forward_method_names = forward_method_names

        for forward_method_name in self.forward_method_names:
            assert not hasattr(self, forward_method_name)

            fn = getattr(self.branch, forward_method_name)
            setattr(self, forward_method_name, fn)

    def width_connection(
        self,
        residuals
    ):
        """
        Mix the residual streams into the branch input(s) and the carried-over residuals.

        Returns `(branch_input, residuals, dict(beta = beta))`; `beta` is None when
        `add_branch_out_to_residual` is False.
        """
        streams, fracs, views = self.num_residual_streams, self.num_fracs, self.num_input_views

        residuals = self.residual_transform(residuals)

        # width connection

        # split out fractions

        residuals = self.split_fracs(residuals)

        # norm

        normed = self.norm(residuals)

        # alpha for weighted sum of residuals going into branch

        dtype = residuals.dtype

        normed = normed.float()

        wc_weight = einsum(normed, self.dynamic_alpha_fn.float(), '... d, p d mix -> p ... mix')
        wc_weight = rearrange(wc_weight, '... s1 f2 mix -> ... (s1 f2) mix')

        # the last axis of alpha is [ (views * fracs) pre-branch columns | (streams * fracs) residual columns ]
        # (same layout as `static_alpha`), so each scale is repeated to exactly its own section width.
        # previously the concatenated scales were tiled `views` times, which only matched for a single view

        num_pre_columns = views * fracs

        pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> n', n = num_pre_columns)
        residual_scale = repeat(self.residual_scale.float(), '1 -> n', n = streams * fracs)
        alpha_scale = cat((pre_branch_scale, residual_scale))

        dynamic_alpha = wc_weight * alpha_scale

        alpha = dynamic_alpha + self.static_alpha.float()

        # the alpha is now split and "manifold constrained" with sinkhorn and sigmoid

        alpha_pre, alpha_residual = alpha[..., :num_pre_columns], alpha[..., num_pre_columns:]

        alpha_pre = alpha_pre.sigmoid()

        alpha_residual = self.residual_mix_constraint_fn(alpha_residual)

        alpha = cat((alpha_pre, alpha_residual), dim = -1)

        # average over proposals after each has been constrained

        if self.has_dynamic_alpha_proposals:
            alpha = reduce(alpha, 'p ... -> ...', 'mean')
        else:
            alpha = rearrange(alpha, '1 ... -> ...')

        alpha = rearrange(alpha, '... (s f) t -> ... s f t', s = streams) # (..., streams, fracs, views * fracs + streams * fracs)

        # beta for weights from branch output back to residual streams

        beta = None

        if self.add_branch_out_to_residual:
            dc_weight = normed @ self.dynamic_beta_fn.float()

            dynamic_beta = dc_weight * self.h_post_scale.float()

            beta = dynamic_beta + self.static_beta.float()

            beta = beta.sigmoid() * 2 # for "H_post" manifold constraint

        mix_h = einsum(alpha, residuals.float(), '... s f tf, ... s f d -> ... tf d')

        mix_h = rearrange(mix_h, '... (t f) d -> ... t f d', f = fracs)

        if views == 1:
            branch_input, residuals = mix_h[..., 0, :, :], mix_h[..., 1:, :, :]
        else:
            branch_input, residuals = mix_h[..., :views, :, :], mix_h[..., views:, :, :]
            branch_input = rearrange(branch_input, 'b ... v f d -> v b ... f d')

        # maybe merge fractions back

        branch_input = self.merge_fracs(branch_input)

        residuals = rearrange(residuals, 'b ... s f d -> b ... s (f d)')

        branch_input, residuals = tuple(t.to(dtype) for t in (branch_input, residuals))

        if exists(beta):
            beta = beta.to(dtype)

        return branch_input, residuals, dict(beta = beta)

    def depth_connection(
        self,
        branch_output,
        residuals,
        *,
        beta
    ):
        """
        Distribute the branch output onto every residual stream, weighted by `beta`.
        The result is combined with `residuals` through `depth_residual_fn` and then
        passed through dropout.
        """
        assert self.add_branch_out_to_residual

        # maybe split fractions

        branch_output = self.split_fracs(branch_output)

        # 'depth' connection

        dtype = residuals.dtype

        output = einsum(branch_output.float(), beta.float(), 'b ... f1 d, b ... s f1 f2 -> b ... s f2 d')

        # merge back fractions

        output = self.merge_fracs(output)

        # combine with the carried-over residual streams

        residuals = self.depth_residual_fn(output.to(dtype), residuals)

        return self.dropout(residuals)

    def decorate_branch(
        self,
        branch: Callable
    ):
        """Wrap `branch` so that calling the wrapper runs it inside this hyper connection."""
        assert not exists(self.branch), 'branch was already wrapped on init'

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

    def forward(
        self,
        residuals,
        *branch_args,
        **branch_kwargs
    ):
        """
        Without a stored branch: return `(branch_input, add_residual_fn)`.
        With a stored branch: run it and return the updated residual streams.

        Only the first leaf of a (possibly nested) branch output is routed through the depth
        connection; the other leaves are passed through unchanged.
        """

        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):

            if not self.add_branch_out_to_residual:
                return branch_out

            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        if not exists(self.branch):
            return branch_input, add_residual_fn

        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)
500
+
501
# short alias for the class
mHC = ManifoldConstrainedHyperConnections

# expose the module-level factory functions on the class itself as static methods,
# e.g. ManifoldConstrainedHyperConnections.get_init_and_expand_reduce_stream_functions(...)
setattr(mHC, 'get_expand_reduce_stream_functions', staticmethod(get_expand_reduce_stream_functions))
setattr(mHC, 'get_init_and_expand_reduce_stream_functions', staticmethod(get_init_and_expand_reduce_stream_functions))
505
+
506
+ # stream embed
507
+
508
class StreamEmbed(Module):
    """
    Learned additive embedding per residual stream, shape (num_streams, dim), zero-initialised.

    With `expand_to_streams`, the input '... d' is first repeated to '... s d' before the
    embedding is added; otherwise the input must already carry the stream axis.
    """
    def __init__(
        self,
        num_streams,
        dim,
        expand_to_streams = False
    ):
        super().__init__()
        self.num_streams = num_streams

        self.expand_to_streams = expand_to_streams
        self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))

    def forward(self, residuals):
        if not self.expand_to_streams:
            return residuals + self.stream_embed

        expanded = repeat(residuals, '... d -> ... s d', s = self.num_streams)
        return expanded + self.stream_embed
527
+
528
+ # attention pool - taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
529
+
530
class AttentionPoolReduceStream(Module):
    """
    Collapse the stream axis ('... s d -> ... d') with a per-feature softmax attention
    over streams. The logit projection starts as the identity matrix.
    """
    def __init__(self, dim):
        super().__init__()
        self.to_attn_logits = nn.Linear(dim, dim, bias = False)

        with torch.no_grad():
            self.to_attn_logits.weight.copy_(torch.eye(dim))

    def forward(self, residuals):
        # softmax over dim -2, i.e. across streams, independently for every feature
        attn = self.to_attn_logits(residuals).softmax(dim = -2)
        return einsum(residuals, attn, '... s d, ... s d -> ... d')
@@ -78,7 +78,7 @@ def get_expand_reduce_stream_functions(
78
78
  dim = None,
79
79
  disable = False
80
80
  ):
81
- if num_streams == 1 or disable:
81
+ if disable:
82
82
  return (nn.Identity(), nn.Identity())
83
83
 
84
84
  if add_stream_embed:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hyper-connections
3
- Version: 0.4.2
3
+ Version: 0.4.3
4
4
  Summary: Hyper-Connections
5
5
  Project-URL: Homepage, https://pypi.org/project/hyper-connections/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -34,8 +34,8 @@ Classifier: License :: OSI Approved :: MIT License
34
34
  Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
- Requires-Dist: einops>=0.8.0
38
- Requires-Dist: torch>=2.3
37
+ Requires-Dist: einops>=0.8.1
38
+ Requires-Dist: torch>=2.5
39
39
  Provides-Extra: examples
40
40
  Description-Content-Type: text/markdown
41
41
 
@@ -0,0 +1,13 @@
1
+ hyper_connections/__init__.py,sha256=BAGwi53ozXcnfPJAGur0RHA4vcolF1ORBhbZ9a8SkrE,602
2
+ hyper_connections/hyper_connections.py,sha256=2F-104cGE82KCK0KeC07NSOJNPT-0PCtvX3xKzAF40E,14884
3
+ hyper_connections/hyper_connections_channel_first.py,sha256=5vAen4WXxNI9K07ndLBQJwdJv-OjoXznta5EIQTaQNA,6512
4
+ hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=yn2AlFB6qCYQeRhhhaMlCM3mxxLEdWCYwU2p9TsMwWI,7835
5
+ hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=iFPw5pgCRHTo16nBJ2PExKSrvTyCh7ba7Py14P1oSPE,11311
6
+ hyper_connections/mHCv2.py,sha256=j3A4XbisBXzqdW9vYCrPRrK2M6iPAqMOjxGCj3lsQ-g,16810
7
+ hyper_connections/manifold_constrained_hyper_connections.py,sha256=rQzAIkP84adzEVyrMasqMuZV76-6LAioUbwKnABcBto,18315
8
+ hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
9
+ hyper_connections/vit.py,sha256=BOWVfCAIzDQdnTq8OBzNUyiKGGILYZkIQ6mr1GKJVB0,5225
10
+ hyper_connections-0.4.3.dist-info/METADATA,sha256=h_zeG-qAgyg-vDktRMaPpGuYzmA-kxrcUmPvVQ4CYvs,6704
11
+ hyper_connections-0.4.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
12
+ hyper_connections-0.4.3.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
13
+ hyper_connections-0.4.3.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- hyper_connections/__init__.py,sha256=BAGwi53ozXcnfPJAGur0RHA4vcolF1ORBhbZ9a8SkrE,602
2
- hyper_connections/hyper_connections.py,sha256=rqFJj3U0LF3uDKNKNPBpRrmf0oa2BGWVbD6S-xdZdLo,14904
3
- hyper_connections/hyper_connections_channel_first.py,sha256=Mh_hzhTi96ZoOPmhSKwUaF4TbHpNqhs83wNe5hNuL7o,6532
4
- hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
5
- hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
6
- hyper_connections/manifold_constrained_hyper_connections.py,sha256=Q6KPBL7XDwfuQtk6INOpvFVNJ663WhQEhsDY_vZGhws,18335
7
- hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
8
- hyper_connections/vit.py,sha256=BOWVfCAIzDQdnTq8OBzNUyiKGGILYZkIQ6mr1GKJVB0,5225
9
- hyper_connections-0.4.2.dist-info/METADATA,sha256=VdKA7lKvg8fuXL5WCyG5R57puM3Q-QZho3FZXt1hHXI,6704
10
- hyper_connections-0.4.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- hyper_connections-0.4.2.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
12
- hyper_connections-0.4.2.dist-info/RECORD,,