hyper-connections 0.3.6__tar.gz → 0.3.8__tar.gz
This diff shows the changes between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/PKG-INFO +1 -1
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/manifold_constrained_hyper_connections.py +38 -12
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/pyproject.toml +1 -1
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/tests/test_hyper_connections.py +21 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/.gitignore +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/LICENSE +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/README.md +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper-connections.png +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.3.6 → hyper_connections-0.3.8}/hyper_connections/residuals.py +0 -0
|
@@ -46,6 +46,9 @@ def l1norm(t, dim):
|
|
|
46
46
|
return F.normalize(t, p = 1, dim = dim)
|
|
47
47
|
|
|
48
48
|
def sinkhorn_knopps(log_alpha, iters = 20):
|
|
49
|
+
dtype = log_alpha.dtype
|
|
50
|
+
log_alpha = log_alpha.float()
|
|
51
|
+
|
|
49
52
|
log_alpha = log_alpha - log_alpha.amax(dim = -2, keepdim = True).detach()
|
|
50
53
|
|
|
51
54
|
alpha = log_alpha.exp()
|
|
@@ -54,7 +57,7 @@ def sinkhorn_knopps(log_alpha, iters = 20):
|
|
|
54
57
|
alpha = l1norm(alpha, dim = -2)
|
|
55
58
|
alpha = l1norm(alpha, dim = -1)
|
|
56
59
|
|
|
57
|
-
return alpha
|
|
60
|
+
return alpha.to(dtype)
|
|
58
61
|
|
|
59
62
|
# main functions
|
|
60
63
|
|
|
@@ -197,7 +200,8 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
197
200
|
num_input_views = 1, # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
|
|
198
201
|
depth_residual_fn = add,
|
|
199
202
|
num_fracs = 1, # https://arxiv.org/abs/2503.14125
|
|
200
|
-
sinkhorn_iters = 20
|
|
203
|
+
sinkhorn_iters = 20,
|
|
204
|
+
forward_method_names: tuple[str, ...] = (),
|
|
201
205
|
):
|
|
202
206
|
"""
|
|
203
207
|
Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
|
|
@@ -287,6 +291,16 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
287
291
|
|
|
288
292
|
self.depth_residual_fn = depth_residual_fn
|
|
289
293
|
|
|
294
|
+
# forwarding method names
|
|
295
|
+
|
|
296
|
+
self.forward_method_names = forward_method_names
|
|
297
|
+
|
|
298
|
+
for forward_method_name in self.forward_method_names:
|
|
299
|
+
assert not hasattr(self, forward_method_name)
|
|
300
|
+
|
|
301
|
+
fn = getattr(self.branch, forward_method_name)
|
|
302
|
+
setattr(self, forward_method_name, fn)
|
|
303
|
+
|
|
290
304
|
def width_connection(
|
|
291
305
|
self,
|
|
292
306
|
residuals
|
|
@@ -316,17 +330,21 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
316
330
|
|
|
317
331
|
# alpha for weighted sum of residuals going into branch
|
|
318
332
|
|
|
319
|
-
|
|
333
|
+
dtype = residuals.dtype
|
|
320
334
|
|
|
321
|
-
|
|
322
|
-
|
|
335
|
+
normed = normed.float()
|
|
336
|
+
|
|
337
|
+
wc_weight = normed @ self.dynamic_alpha_fn.float()
|
|
338
|
+
|
|
339
|
+
pre_branch_scale = repeat(self.pre_branch_scale.float(), '1 -> s', s = self.num_fracs)
|
|
340
|
+
residual_scale = repeat(self.residual_scale.float(), '1 -> s', s = self.num_fracs * streams)
|
|
323
341
|
alpha_scale = cat((pre_branch_scale, residual_scale))
|
|
324
342
|
|
|
325
343
|
alpha_scale = repeat(alpha_scale, 'n -> (v n)', v = self.num_input_views)
|
|
326
344
|
|
|
327
345
|
dynamic_alpha = wc_weight * alpha_scale
|
|
328
346
|
|
|
329
|
-
static_alpha = rearrange(self.static_alpha, '(f s) d -> f s d', s = streams)
|
|
347
|
+
static_alpha = rearrange(self.static_alpha.float(), '(f s) d -> f s d', s = streams)
|
|
330
348
|
|
|
331
349
|
alpha = dynamic_alpha + static_alpha
|
|
332
350
|
|
|
@@ -351,20 +369,20 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
351
369
|
beta = None
|
|
352
370
|
|
|
353
371
|
if self.add_branch_out_to_residual:
|
|
354
|
-
dc_weight = normed @ self.dynamic_beta_fn
|
|
372
|
+
dc_weight = normed @ self.dynamic_beta_fn.float()
|
|
355
373
|
|
|
356
374
|
dc_weight = dc_weight.sigmoid() * 2 # sigmoid * 2 for "H_post", corresponding to dc weight in original paper
|
|
357
375
|
|
|
358
376
|
if not self.has_fracs:
|
|
359
377
|
dc_weight = rearrange(dc_weight, '... -> ... 1')
|
|
360
378
|
|
|
361
|
-
dynamic_beta = dc_weight * self.h_post_scale
|
|
379
|
+
dynamic_beta = dc_weight * self.h_post_scale.float()
|
|
362
380
|
|
|
363
|
-
static_beta = rearrange(self.static_beta, '... (s f) -> ... s f', s = streams)
|
|
381
|
+
static_beta = rearrange(self.static_beta.float(), '... (s f) -> ... s f', s = streams)
|
|
364
382
|
|
|
365
383
|
beta = dynamic_beta + static_beta
|
|
366
384
|
|
|
367
|
-
mix_h = einsum(alpha, residuals, '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
385
|
+
mix_h = einsum(alpha, residuals.float(), '... f1 s f2 t, ... f1 s d -> ... f2 t d')
|
|
368
386
|
|
|
369
387
|
if self.num_input_views == 1:
|
|
370
388
|
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
|
@@ -379,6 +397,12 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
379
397
|
|
|
380
398
|
branch_input = self.merge_fracs(branch_input)
|
|
381
399
|
|
|
400
|
+
branch_input = branch_input.to(dtype)
|
|
401
|
+
residuals = residuals.to(dtype)
|
|
402
|
+
|
|
403
|
+
if exists(beta):
|
|
404
|
+
beta = beta.to(dtype)
|
|
405
|
+
|
|
382
406
|
return branch_input, maybe_transformed_residuals, dict(beta = beta)
|
|
383
407
|
|
|
384
408
|
def depth_connection(
|
|
@@ -399,7 +423,9 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
399
423
|
if self.channel_first:
|
|
400
424
|
branch_output = rearrange(branch_output, 'b d ... -> b ... d')
|
|
401
425
|
|
|
402
|
-
|
|
426
|
+
dtype = residuals.dtype
|
|
427
|
+
|
|
428
|
+
output = einsum(branch_output.float(), beta.float(), 'b ... f1 d, b ... f1 s f2 -> b ... f2 s d')
|
|
403
429
|
|
|
404
430
|
output = rearrange(output, 'b ... s d -> (b s) ... d')
|
|
405
431
|
|
|
@@ -412,7 +438,7 @@ class ManifoldConstrainedHyperConnections(Module):
|
|
|
412
438
|
if self.channel_first:
|
|
413
439
|
output = rearrange(output, 'b ... d -> b d ...')
|
|
414
440
|
|
|
415
|
-
residuals = self.depth_residual_fn(output, residuals)
|
|
441
|
+
residuals = self.depth_residual_fn(output.to(dtype), residuals)
|
|
416
442
|
|
|
417
443
|
return self.dropout(residuals)
|
|
418
444
|
|
|
@@ -231,3 +231,24 @@ def test_channel_first_hyper_connection(disable):
|
|
|
231
231
|
after_residual = reduce_stream(residual)
|
|
232
232
|
|
|
233
233
|
assert before_residual.shape == after_residual.shape
|
|
234
|
+
|
|
235
|
+
def test_mhc_dtype_restoration():
|
|
236
|
+
from hyper_connections.manifold_constrained_hyper_connections import ManifoldConstrainedHyperConnections
|
|
237
|
+
|
|
238
|
+
mhc = ManifoldConstrainedHyperConnections(
|
|
239
|
+
num_residual_streams = 4,
|
|
240
|
+
dim = 64,
|
|
241
|
+
add_branch_out_to_residual = True
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
residual = torch.randn(4, 1, 64).half()
|
|
245
|
+
|
|
246
|
+
branch_input, _, residual_kwargs = mhc.width_connection(residual)
|
|
247
|
+
|
|
248
|
+
assert branch_input.dtype == torch.half
|
|
249
|
+
assert residual_kwargs['beta'].dtype == torch.half
|
|
250
|
+
|
|
251
|
+
branch_output = torch.randn_like(branch_input).half()
|
|
252
|
+
residual = mhc.depth_connection(branch_output, residual, **residual_kwargs)
|
|
253
|
+
|
|
254
|
+
assert residual.dtype == torch.half
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|