hyper-connections 0.3.6__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/manifold_constrained_hyper_connections.py +26 -11
- {hyper_connections-0.3.6.dist-info → hyper_connections-0.3.7.dist-info}/METADATA +1 -1
- {hyper_connections-0.3.6.dist-info → hyper_connections-0.3.7.dist-info}/RECORD +5 -5
- {hyper_connections-0.3.6.dist-info → hyper_connections-0.3.7.dist-info}/WHEEL +0 -0
- {hyper_connections-0.3.6.dist-info → hyper_connections-0.3.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -46,6 +46,9 @@ def l1norm(t, dim):
|
|
|
46
46
|
return F.normalize(t, p = 1, dim = dim)
|
|
47
47
|
|
|
48
48
|
def sinkhorn_knopps(log_alpha, iters = 20):
|
|
49
|
+
dtype = log_alpha.dtype
|
|
50
|
+
log_alpha = log_alpha.float()
|
|
51
|
+
|
|
49
52
|
log_alpha = log_alpha - log_alpha.amax(dim = -2, keepdim = True).detach()
|
|
50
53
|
|
|
51
54
|
alpha = log_alpha.exp()
|
|
@@ -54,7 +57,7 @@ def sinkhorn_knopps(log_alpha, iters = 20):
|
|
|
54
57
|
alpha = l1norm(alpha, dim = -2)
|
|
55
58
|
alpha = l1norm(alpha, dim = -1)
|
|
56
59
|
|
|
57
|
-
return alpha
|
|
60
|
+
return alpha.to(dtype)
|
|
58
61
|
|
|
59
62
|
# main functions
|
|
60
63
|
|
|
@@ -316,17 +319,21 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
316
319
|
|
|
317
320
|
# alpha for weighted sum of residuals going into branch
|
|
318
321
|
|
|
319
|
-
|
|
322
|
+
dtype = residuals.dtype
|
|
323
|
+
|
|
324
|
+
normed = normed.float()
|
|
325
|
+
|
|
326
|
+
wc_weight = normed @ self.dynamic_alpha_fn.float()
|
|
320
327
|
|
|
321
|
-
pre_branch_scale = repeat(self.pre_branch_scale, '1 -> s', s = self.num_fracs)
|
|
322
|
-
residual_scale = repeat(self.residual_scale, '1 -> s', s = self.num_fracs * streams)
|
|
328
|
+
pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
|
|
329
|
+
residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
|
|
323
330
|
alpha_scale = cat((pre_branch_scale, residual_scale))
|
|
324
331
|
|
|
325
332
|
alpha_scale = repeat(alpha_scale, 'n -> (v n)', v = self.num_input_views)
|
|
326
333
|
|
|
327
334
|
dynamic_alpha = wc_weight * alpha_scale
|
|
328
335
|
|
|
329
|
-
static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
|
|
336
|
+
static_alpha = rearrange(self.static_alpha.float(), '(f s) d -> f s d', s = streams)
|
|
330
337
|
|
|
331
338
|
alpha = dynamic_alpha + static_alpha
|
|
332
339
|
|
|
@@ -351,20 +358,20 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
351
358
|
beta = None
|
|
352
359
|
|
|
353
360
|
if self.add_branch_out_to_residual:
|
|
354
|
-
dc_weight = normed @ self.dynamic_beta_fn
|
|
361
|
+
dc_weight = normed @ self.dynamic_beta_fn.float()
|
|
355
362
|
|
|
356
363
|
dc_weight = dc_weight.sigmoid() * 2 # sigmoid * 2 for "H_post", corresponding to dc weight in original paper
|
|
357
364
|
|
|
358
365
|
if not self.has_fracs:
|
|
359
366
|
dc_weight = rearrange(dc_weight, '... -> ... 1')
|
|
360
367
|
|
|
361
|
-
dynamic_beta = dc_weight * self.h_post_scale
|
|
368
|
+
dynamic_beta = dc_weight * self.h_post_scale.float()
|
|
362
369
|
|
|
363
|
-
static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
|
|
370
|
+
static_beta = rearrange(self.static_beta.float(), '... (s f) -> ... s f', s = streams)
|
|
364
371
|
|
|
365
372
|
beta = dynamic_beta + static_beta
|
|
366
373
|
|
|
367
|
-
mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
374
|
+
mix_h = einsum(alpha, residuals.float(), '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
368
375
|
|
|
369
376
|
if self.num_input_views == 1:
|
|
370
377
|
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
@@ -379,6 +386,12 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
379
386
|
|
|
380
387
|
branch_input = self.merge_fracs(branch_input)
|
|
381
388
|
|
|
389
|
+
branch_input = branch_input.to(dtype)
|
|
390
|
+
residuals = residuals.to(dtype)
|
|
391
|
+
|
|
392
|
+
if exists(beta):
|
|
393
|
+
beta = beta.to(dtype)
|
|
394
|
+
|
|
382
395
|
return branch_input, maybe_transformed_residuals, dict(beta = beta)
|
|
383
396
|
|
|
384
397
|
def depth_connection(
|
|
@@ -399,7 +412,9 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
399
412
|
if self.channel_first:
|
|
400
413
|
branch_output = rearrange(branch_output, 'b d ... -> b ... d')
|
|
401
414
|
|
|
402
|
-
|
|
415
|
+
dtype = residuals.dtype
|
|
416
|
+
|
|
417
|
+
output = einsum(branch_output.float(), beta.float(), 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
|
|
403
418
|
|
|
404
419
|
output = rearrange(output, 'b ... s d -> (b s) ... d')
|
|
405
420
|
|
|
@@ -412,7 +427,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
412
427
|
if self.channel_first:
|
|
413
428
|
output = rearrange(output, 'b ... d -> b d ...')
|
|
414
429
|
|
|
415
|
-
residuals = self.depth_residual_fn(output, residuals)
|
|
430
|
+
residuals = self.depth_residual_fn(output.to(dtype), residuals)
|
|
416
431
|
|
|
417
432
|
return self.dropout(residuals)
|
|
418
433
|
|
|
@@ -3,9 +3,9 @@ hyper_connections/hyper_connections.py,sha256=UHxZhyRwx89GRgmQVt53Gv6JeNhX8UCjjE
|
|
|
3
3
|
hyper_connections/hyper_connections_channel_first.py,sha256=_1PM4LRcPpDqfCiHlBMc2nLV08sXM2nuyZGSKTiuqbE,6818
|
|
4
4
|
hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=6BXKdSwyx6wdQVseebKG2EQkhVaVLrrepOlL8lLnex4,7855
|
|
5
5
|
hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=ueT3CJPHrt5hRU7q1bFF0rANWJh_pXqclt6HiUu1gBY,11331
|
|
6
|
-
hyper_connections/manifold_constrained_hyper_connections.py,sha256=
|
|
6
|
+
hyper_connections/manifold_constrained_hyper_connections.py,sha256=gJzc9oHZjhC3S85HiXGQck5qBNcRWGCjVsDfMXwqPxo,16961
|
|
7
7
|
hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
|
|
8
|
-
hyper_connections-0.3.
|
|
9
|
-
hyper_connections-0.3.
|
|
10
|
-
hyper_connections-0.3.
|
|
11
|
-
hyper_connections-0.3.
|
|
8
|
+
hyper_connections-0.3.7.dist-info/METADATA,sha256=Pyj7qVEMj6Szb_PD80eYoN3O4fwv8DjsYSzFd8EY_bo,6704
|
|
9
|
+
hyper_connections-0.3.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
10
|
+
hyper_connections-0.3.7.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
|
|
11
|
+
hyper_connections-0.3.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|