hyper-connections 0.3.6__tar.gz → 0.3.7__tar.gz
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/PKG-INFO +1 -1
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/manifold_constrained_hyper_connections.py +26 -11
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/pyproject.toml +1 -1
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/tests/test_hyper_connections.py +21 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/.gitignore +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/LICENSE +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/README.md +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper-connections.png +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.7}/hyper_connections/residuals.py +0 -0
|
@@ -46,6 +46,9 @@ def l1norm(t, dim):
|
|
|
46
46
|
return F.normalize(t, p = 1, dim = dim)
|
|
47
47
|
|
|
48
48
|
def sinkhorn_knopps(log_alpha, iters = 20):
|
|
49
|
+
dtype = log_alpha.dtype
|
|
50
|
+
log_alpha = log_alpha.float()
|
|
51
|
+
|
|
49
52
|
log_alpha = log_alpha - log_alpha.amax(dim = -2, keepdim = True).detach()
|
|
50
53
|
|
|
51
54
|
alpha = log_alpha.exp()
|
|
@@ -54,7 +57,7 @@ def sinkhorn_knopps(log_alpha, iters = 20):
|
|
|
54
57
|
alpha = l1norm(alpha, dim = -2)
|
|
55
58
|
alpha = l1norm(alpha, dim = -1)
|
|
56
59
|
|
|
57
|
-
return alpha
|
|
60
|
+
return alpha.to(dtype)
|
|
58
61
|
|
|
59
62
|
# main functions
|
|
60
63
|
|
|
@@ -316,17 +319,21 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
316
319
|
|
|
317
320
|
# alpha for weighted sum of residuals going into branch
|
|
318
321
|
|
|
319
|
-
|
|
322
|
+
dtype = residuals.dtype
|
|
323
|
+
|
|
324
|
+
normed = normed.float()
|
|
325
|
+
|
|
326
|
+
wc_weight = normed @ self.dynamic_alpha_fn.float()
|
|
320
327
|
|
|
321
|
-
pre_branch_scale = repeat(self.pre_branch_scale, '1 -> s', s = self.num_fracs)
|
|
322
|
-
residual_scale = repeat(self.residual_scale, '1 -> s', s = self.num_fracs * streams)
|
|
328
|
+
pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
|
|
329
|
+
residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
|
|
323
330
|
alpha_scale = cat((pre_branch_scale, residual_scale))
|
|
324
331
|
|
|
325
332
|
alpha_scale = repeat(alpha_scale, 'n -> (v n)', v = self.num_input_views)
|
|
326
333
|
|
|
327
334
|
dynamic_alpha = wc_weight * alpha_scale
|
|
328
335
|
|
|
329
|
-
static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
|
|
336
|
+
static_alpha = rearrange(self.static_alpha.float(), '(f s) d -> f s d', s = streams)
|
|
330
337
|
|
|
331
338
|
alpha = dynamic_alpha + static_alpha
|
|
332
339
|
|
|
@@ -351,20 +358,20 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
351
358
|
beta = None
|
|
352
359
|
|
|
353
360
|
if self.add_branch_out_to_residual:
|
|
354
|
-
dc_weight = normed @ self.dynamic_beta_fn
|
|
361
|
+
dc_weight = normed @ self.dynamic_beta_fn.float()
|
|
355
362
|
|
|
356
363
|
dc_weight = dc_weight.sigmoid() * 2 # sigmoid * 2 for "H_post", corresponding to dc weight in original paper
|
|
357
364
|
|
|
358
365
|
if not self.has_fracs:
|
|
359
366
|
dc_weight = rearrange(dc_weight, '... -> ... 1')
|
|
360
367
|
|
|
361
|
-
dynamic_beta = dc_weight * self.h_post_scale
|
|
368
|
+
dynamic_beta = dc_weight * self.h_post_scale.float()
|
|
362
369
|
|
|
363
|
-
static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
|
|
370
|
+
static_beta = rearrange(self.static_beta.float(), '... (s f) -> ... s f', s = streams)
|
|
364
371
|
|
|
365
372
|
beta = dynamic_beta + static_beta
|
|
366
373
|
|
|
367
|
-
mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
374
|
+
mix_h = einsum(alpha, residuals.float(), '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
368
375
|
|
|
369
376
|
if self.num_input_views == 1:
|
|
370
377
|
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
@@ -379,6 +386,12 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
379
386
|
|
|
380
387
|
branch_input = self.merge_fracs(branch_input)
|
|
381
388
|
|
|
389
|
+
branch_input = branch_input.to(dtype)
|
|
390
|
+
residuals = residuals.to(dtype)
|
|
391
|
+
|
|
392
|
+
if exists(beta):
|
|
393
|
+
beta = beta.to(dtype)
|
|
394
|
+
|
|
382
395
|
return branch_input, maybe_transformed_residuals, dict(beta = beta)
|
|
383
396
|
|
|
384
397
|
def depth_connection(
|
|
@@ -399,7 +412,9 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
399
412
|
if self.channel_first:
|
|
400
413
|
branch_output = rearrange(branch_output, 'b d ... -> b ... d')
|
|
401
414
|
|
|
402
|
-
|
|
415
|
+
dtype = residuals.dtype
|
|
416
|
+
|
|
417
|
+
output = einsum(branch_output.float(), beta.float(), 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
|
|
403
418
|
|
|
404
419
|
output = rearrange(output, 'b ... s d -> (b s) ... d')
|
|
405
420
|
|
|
@@ -412,7 +427,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
412
427
|
if self.channel_first:
|
|
413
428
|
output = rearrange(output, 'b ... d -> b d ...')
|
|
414
429
|
|
|
415
|
-
residuals = self.depth_residual_fn(output, residuals)
|
|
430
|
+
residuals = self.depth_residual_fn(output.to(dtype), residuals)
|
|
416
431
|
|
|
417
432
|
return self.dropout(residuals)
|
|
418
433
|
|
|
@@ -231,3 +231,24 @@ def test_channel_first_hyper_connection(disable):
|
|
|
231
231
|
after_residual = reduce_stream(residual)
|
|
232
232
|
|
|
233
233
|
assert before_residual.shape == after_residual.shape
|
|
234
|
+
|
|
235
|
+
def test_mhc_dtype_restoration():
|
|
236
|
+
from hyper_connections.manifold_constrained_hyper_connections import ManifoldConstrainedHyperConnections
|
|
237
|
+
|
|
238
|
+
mhc = ManifoldConstrainedHyperConnections(
|
|
239
|
+
num_residual_streams = 4,
|
|
240
|
+
dim = 64,
|
|
241
|
+
add_branch_out_to_residual = True
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
residual = torch.randn(4, 1, 64).half()
|
|
245
|
+
|
|
246
|
+
branch_input, _, residual_kwargs = mhc.width_connection(residual)
|
|
247
|
+
|
|
248
|
+
assert branch_input.dtype == torch.half
|
|
249
|
+
assert residual_kwargs['beta'].dtype == torch.half
|
|
250
|
+
|
|
251
|
+
branch_output = torch.randn_like(branch_input).half()
|
|
252
|
+
residual = mhc.depth_connection(branch_output, residual, **residual_kwargs)
|
|
253
|
+
|
|
254
|
+
assert residual.dtype == torch.half
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|