hyper-connections 0.3.5__tar.gz → 0.3.7__tar.gz
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/PKG-INFO +1 -1
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/manifold_constrained_hyper_connections.py +27 -16
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/pyproject.toml +1 -1
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/tests/test_hyper_connections.py +21 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/.gitignore +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/LICENSE +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/README.md +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper-connections.png +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.3.5 → hyper_connections-0.3.7}/hyper_connections/residuals.py +0 -0
|
@@ -46,6 +46,9 @@ def l1norm(t, dim):
|
|
|
46
46
|
return F.normalize(t, p = 1, dim = dim)
|
|
47
47
|
|
|
48
48
|
def sinkhorn_knopps(log_alpha, iters = 20):
|
|
49
|
+
dtype = log_alpha.dtype
|
|
50
|
+
log_alpha = log_alpha.float()
|
|
51
|
+
|
|
49
52
|
log_alpha = log_alpha - log_alpha.amax(dim = -2, keepdim = True).detach()
|
|
50
53
|
|
|
51
54
|
alpha = log_alpha.exp()
|
|
@@ -54,7 +57,7 @@ def sinkhorn_knopps(log_alpha, iters = 20):
|
|
|
54
57
|
alpha = l1norm(alpha, dim = -2)
|
|
55
58
|
alpha = l1norm(alpha, dim = -1)
|
|
56
59
|
|
|
57
|
-
return alpha
|
|
60
|
+
return alpha.to(dtype)
|
|
58
61
|
|
|
59
62
|
# main functions
|
|
60
63
|
|
|
@@ -222,7 +225,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
222
225
|
|
|
223
226
|
# they used layernorm in paper, but rmsnorm is fine given what we know now
|
|
224
227
|
|
|
225
|
-
self.norm = RMSNorm(dim
|
|
228
|
+
self.norm = RMSNorm(dim)
|
|
226
229
|
|
|
227
230
|
assert num_residual_streams > 0, '`num_residual_streams` must be greater than 0'
|
|
228
231
|
|
|
@@ -312,25 +315,25 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
312
315
|
|
|
313
316
|
# norm
|
|
314
317
|
|
|
315
|
-
normed =
|
|
318
|
+
normed = self.norm(residuals)
|
|
316
319
|
|
|
317
|
-
|
|
320
|
+
# alpha for weighted sum of residuals going into branch
|
|
318
321
|
|
|
319
|
-
|
|
322
|
+
dtype = residuals.dtype
|
|
320
323
|
|
|
321
|
-
|
|
324
|
+
normed = normed.float()
|
|
322
325
|
|
|
323
|
-
wc_weight = normed @ self.dynamic_alpha_fn
|
|
326
|
+
wc_weight = normed @ self.dynamic_alpha_fn.float()
|
|
324
327
|
|
|
325
|
-
pre_branch_scale = repeat(self.pre_branch_scale, '1 -> s', s = self.num_fracs)
|
|
326
|
-
residual_scale = repeat(self.residual_scale, '1 -> s', s = self.num_fracs * streams)
|
|
328
|
+
pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
|
|
329
|
+
residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
|
|
327
330
|
alpha_scale = cat((pre_branch_scale, residual_scale))
|
|
328
331
|
|
|
329
332
|
alpha_scale = repeat(alpha_scale, 'n -> (v n)', v = self.num_input_views)
|
|
330
333
|
|
|
331
334
|
dynamic_alpha = wc_weight * alpha_scale
|
|
332
335
|
|
|
333
|
-
static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
|
|
336
|
+
static_alpha = rearrange(self.static_alpha.float(), '(f s) d -> f s d', s = streams)
|
|
334
337
|
|
|
335
338
|
alpha = dynamic_alpha + static_alpha
|
|
336
339
|
|
|
@@ -355,20 +358,20 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
355
358
|
beta = None
|
|
356
359
|
|
|
357
360
|
if self.add_branch_out_to_residual:
|
|
358
|
-
dc_weight = normed @ self.dynamic_beta_fn
|
|
361
|
+
dc_weight = normed @ self.dynamic_beta_fn.float()
|
|
359
362
|
|
|
360
363
|
dc_weight = dc_weight.sigmoid() * 2 # sigmoid * 2 for "H_post", corresponding to dc weight in original paper
|
|
361
364
|
|
|
362
365
|
if not self.has_fracs:
|
|
363
366
|
dc_weight = rearrange(dc_weight, '... -> ... 1')
|
|
364
367
|
|
|
365
|
-
dynamic_beta = dc_weight * self.h_post_scale
|
|
368
|
+
dynamic_beta = dc_weight * self.h_post_scale.float()
|
|
366
369
|
|
|
367
|
-
static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
|
|
370
|
+
static_beta = rearrange(self.static_beta.float(), '... (s f) -> ... s f', s = streams)
|
|
368
371
|
|
|
369
372
|
beta = dynamic_beta + static_beta
|
|
370
373
|
|
|
371
|
-
mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
374
|
+
mix_h = einsum(alpha, residuals.float(), '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
372
375
|
|
|
373
376
|
if self.num_input_views == 1:
|
|
374
377
|
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
@@ -383,6 +386,12 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
383
386
|
|
|
384
387
|
branch_input = self.merge_fracs(branch_input)
|
|
385
388
|
|
|
389
|
+
branch_input = branch_input.to(dtype)
|
|
390
|
+
residuals = residuals.to(dtype)
|
|
391
|
+
|
|
392
|
+
if exists(beta):
|
|
393
|
+
beta = beta.to(dtype)
|
|
394
|
+
|
|
386
395
|
return branch_input, maybe_transformed_residuals, dict(beta = beta)
|
|
387
396
|
|
|
388
397
|
def depth_connection(
|
|
@@ -403,7 +412,9 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
403
412
|
if self.channel_first:
|
|
404
413
|
branch_output = rearrange(branch_output, 'b d ... -> b ... d')
|
|
405
414
|
|
|
406
|
-
|
|
415
|
+
dtype = residuals.dtype
|
|
416
|
+
|
|
417
|
+
output = einsum(branch_output.float(), beta.float(), 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
|
|
407
418
|
|
|
408
419
|
output = rearrange(output, 'b ... s d -> (b s) ... d')
|
|
409
420
|
|
|
@@ -416,7 +427,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
416
427
|
if self.channel_first:
|
|
417
428
|
output = rearrange(output, 'b ... d -> b d ...')
|
|
418
429
|
|
|
419
|
-
residuals = self.depth_residual_fn(output, residuals)
|
|
430
|
+
residuals = self.depth_residual_fn(output.to(dtype), residuals)
|
|
420
431
|
|
|
421
432
|
return self.dropout(residuals)
|
|
422
433
|
|
|
@@ -231,3 +231,24 @@ def test_channel_first_hyper_connection(disable):
|
|
|
231
231
|
after_residual = reduce_stream(residual)
|
|
232
232
|
|
|
233
233
|
assert before_residual.shape == after_residual.shape
|
|
234
|
+
|
|
235
|
+
def test_mhc_dtype_restoration():
|
|
236
|
+
from hyper_connections.manifold_constrained_hyper_connections import ManifoldConstrainedHyperConnections
|
|
237
|
+
|
|
238
|
+
mhc = ManifoldConstrainedHyperConnections(
|
|
239
|
+
num_residual_streams = 4,
|
|
240
|
+
dim = 64,
|
|
241
|
+
add_branch_out_to_residual = True
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
residual = torch.randn(4, 1, 64).half()
|
|
245
|
+
|
|
246
|
+
branch_input, _, residual_kwargs = mhc.width_connection(residual)
|
|
247
|
+
|
|
248
|
+
assert branch_input.dtype == torch.half
|
|
249
|
+
assert residual_kwargs['beta'].dtype == torch.half
|
|
250
|
+
|
|
251
|
+
branch_output = torch.randn_like(branch_input).half()
|
|
252
|
+
residual = mhc.depth_connection(branch_output, residual, **residual_kwargs)
|
|
253
|
+
|
|
254
|
+
assert residual.dtype == torch.half
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|