hyper-connections-0.4.4.tar.gz → hyper-connections-0.4.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/PKG-INFO +1 -1
  2. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/mHCv2.py +27 -6
  3. hyper_connections-0.4.6/hyper_connections/triton_sinkhorn.py +160 -0
  4. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/pyproject.toml +1 -1
  5. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/tests/test_hyper_connections.py +35 -0
  6. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/.github/workflows/python-publish.yml +0 -0
  7. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/.github/workflows/test.yml +0 -0
  8. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/.gitignore +0 -0
  9. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/LICENSE +0 -0
  10. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/README.md +0 -0
  11. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper-connections.png +0 -0
  12. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/__init__.py +0 -0
  13. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/hyper_connections.py +0 -0
  14. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/hyper_connections_channel_first.py +0 -0
  15. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
  16. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
  17. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/manifold_constrained_hyper_connections.py +0 -0
  18. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/residuals.py +0 -0
  19. {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/vit.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyper-connections
-Version: 0.4.4
+Version: 0.4.6
 Summary: Hyper-Connections
 Project-URL: Homepage, https://pypi.org/project/hyper-connections/
 Project-URL: Repository, https://github.com/lucidrains/hyper-connections
hyper_connections/mHCv2.py
@@ -47,6 +47,10 @@ def l1norm(t, dim):
     return F.normalize(t, p = 1, dim = dim)
 
 def sinkhorn_knopps(log_alpha, iters = 20):
+
+    if iters <= 0:
+        return log_alpha
+
     assert log_alpha.shape[-2] == log_alpha.shape[-1]
 
     dtype = log_alpha.dtype
@@ -63,6 +67,10 @@ def sinkhorn_knopps(log_alpha, iters = 20):
     return alpha.to(dtype)
 
 def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
+
+    if iters <= 0:
+        return log_alpha
+
     assert log_alpha.shape[-2] == log_alpha.shape[-1]
 
     dtype = log_alpha.dtype
@@ -109,6 +117,7 @@ def get_init_and_expand_reduce_stream_functions(
     add_attn_pool_reduce_stream = False,
     disable = None,
     sinkhorn_iters = 20,
+    use_triton_sinkhorn = False,
     **kwargs
 ):
     disable = default(disable, num_streams == 1 and num_fracs == 1)
@@ -116,7 +125,7 @@ def get_init_and_expand_reduce_stream_functions(
     hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual
 
     kwargs.pop('add_attn_pool_reduce_stream', None)
-    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, **kwargs)
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, use_triton_sinkhorn = use_triton_sinkhorn, **kwargs)
     expand_reduce_fns = get_expand_reduce_stream_functions(
         num_streams,
         add_stream_embed = add_stream_embed,
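
A minimal usage sketch of the new flag (not part of this diff; argument values are illustrative, and the three-value return is assumed to match the other variants in this package):

    from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions

    # opt in to the fused Triton kernel; it silently falls back to the PyTorch
    # sinkhorn if Triton or CUDA is unavailable
    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(
        4,                            # num_streams
        sinkhorn_iters = 20,
        use_triton_sinkhorn = True
    )
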
@@ -231,7 +240,7 @@ class ManifoldConstrainedHyperConnections(Module):
         residual_mix_constraint_fn: Callable | None = None,
         forward_method_names: tuple[str, ...] = (),
         num_dynamic_alpha_proposals = 1,
-
+        use_triton_sinkhorn = False,
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -306,10 +315,22 @@
         # Hres constraint related
         # by default is sinkhorn
 
-        self.residual_mix_constraint_fn = default(
-            residual_mix_constraint_fn,
-            partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
-        )
+        use_triton_sinkhorn_and_available = False
+
+        if use_triton_sinkhorn:
+            try:
+                from hyper_connections.triton_sinkhorn import triton_sinkhorn, is_triton_available
+                use_triton_sinkhorn_and_available = is_triton_available()
+            except ImportError:
+                use_triton_sinkhorn_and_available = False
+
+        if use_triton_sinkhorn_and_available:
+            self.residual_mix_constraint_fn = partial(triton_sinkhorn, iters = sinkhorn_iters)
+        else:
+            self.residual_mix_constraint_fn = default(
+                residual_mix_constraint_fn,
+                partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
+            )
 
         # dropouts
 
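Note on the selection above: the Triton kernel is only used when use_triton_sinkhorn is set and hyper_connections.triton_sinkhorn imports cleanly with is_triton_available() returning True, in which case it takes precedence over a user-supplied residual_mix_constraint_fn. Otherwise behavior is unchanged from 0.4.4: the constraint defaults to sinkhorn_knopps, or log_domain_sinkhorn_knopps when log_domain_sinkhorn is set, with the configured sinkhorn_iters.
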
hyper_connections/triton_sinkhorn.py (new file)
@@ -0,0 +1,160 @@
+import torch
+import triton
+import triton.language as tl
+from torch.autograd import Function
+
+@triton.jit
+def sinkhorn_kernel_forward_log(
+    input_ptr,
+    output_ptr,
+    M, N,
+    stride_b, stride_m, stride_n,
+    iters: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_b = tl.program_id(0)
+    offs_m = tl.arange(0, BLOCK_SIZE)
+    offs_n = tl.arange(0, BLOCK_SIZE)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    curr_input_ptr = input_ptr + pid_b * stride_b
+    # Use a large negative value for log-space padding to avoid interference
+    log_alpha = tl.load(curr_input_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=-1e10)
+
+    # Use static_range to force unrolling and avoid compiler bugs with dynamic loops in this environment
+    for _ in tl.static_range(iters):
+        # Column-wise Log-Softmax (dim=-2)
+        col_max = tl.max(tl.where(mask, log_alpha, -1e10), axis=0)
+        exp_weights_col = tl.exp(log_alpha - col_max[None, :])
+        exp_weights_col = tl.where(mask, exp_weights_col, 0.0)
+        col_lse = col_max + tl.log(tl.sum(exp_weights_col, axis=0))
+        log_alpha = log_alpha - col_lse[None, :]
+        log_alpha = tl.where(mask, log_alpha, -1e10)
+
+        # Row-wise Log-Softmax (dim=-1)
+        row_max = tl.max(tl.where(mask, log_alpha, -1e10), axis=1)
+        exp_weights_row = tl.exp(log_alpha - row_max[:, None])
+        exp_weights_row = tl.where(mask, exp_weights_row, 0.0)
+        row_lse = row_max + tl.log(tl.sum(exp_weights_row, axis=1))
+        log_alpha = log_alpha - row_lse[:, None]
+        log_alpha = tl.where(mask, log_alpha, -1e10)
+
+    result_alpha = tl.exp(log_alpha)
+    result_alpha = tl.where(mask, result_alpha, 0.0)
+
+    curr_output_ptr = output_ptr + pid_b * stride_b
+    tl.store(curr_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, result_alpha, mask=mask)
+
+@triton.jit
+def sinkhorn_kernel_backward_log(
+    grad_output_ptr,
+    output_ptr,
+    grad_input_ptr,
+    M, N,
+    stride_b, stride_m, stride_n,
+    iters: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_b = tl.program_id(0)
+    offs_m = tl.arange(0, BLOCK_SIZE)
+    offs_n = tl.arange(0, BLOCK_SIZE)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    curr_output_ptr = output_ptr + pid_b * stride_b
+    curr_grad_output_ptr = grad_output_ptr + pid_b * stride_b
+
+    alpha = tl.load(curr_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=0.0)
+    grad_alpha = tl.load(curr_grad_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=0.0)
+
+    # Ensure they are truly zeroed in padded areas for sum robustness
+    alpha = tl.where(mask, alpha, 0.0)
+    grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+    for _ in tl.static_range(iters):
+        # Backward of Row-wise Normalization
+        # Sum only over valid elements
+        row_sum_grad_alpha = tl.sum(tl.where(mask, grad_alpha * alpha, 0.0), axis=1)
+        grad_alpha = grad_alpha - row_sum_grad_alpha[:, None]
+        grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+        # Backward of Column-wise Normalization
+        col_sum_grad_alpha = tl.sum(tl.where(mask, grad_alpha * alpha, 0.0), axis=0)
+        grad_alpha = grad_alpha - col_sum_grad_alpha[None, :]
+        grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+    grad_input = alpha * grad_alpha
+
+    curr_grad_input_ptr = grad_input_ptr + pid_b * stride_b
+    tl.store(curr_grad_input_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, grad_input, mask=mask)
+
+class TritonSinkhornFunction(Function):
+    @staticmethod
+    def forward(ctx, log_alpha, iters=20):
+        # Handle matrix size limits to avoid register spilling/SRAM overflow
+        M, N = log_alpha.shape[-2:]
+        if max(M, N) > 256:
+            from hyper_connections.mHCv2 import log_domain_sinkhorn_knopps
+            return log_domain_sinkhorn_knopps(log_alpha, iters)
+
+        batch_shape = log_alpha.shape[:-2]
+        log_alpha_flat = log_alpha.view(-1, M, N).contiguous()
+        B = log_alpha_flat.shape[0]
+
+        output = torch.empty_like(log_alpha_flat)
+        BLOCK_SIZE = max(32, triton.next_power_of_2(max(M, N)))
+
+        sinkhorn_kernel_forward_log[(B,)](
+            log_alpha_flat,
+            output,
+            M, N,
+            log_alpha_flat.stride(0), log_alpha_flat.stride(1), log_alpha_flat.stride(2),
+            iters=iters,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=4
+        )
+
+        ctx.save_for_backward(output)
+        ctx.iters = iters
+        return output.view(*batch_shape, M, N)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        output, = ctx.saved_tensors
+        iters = ctx.iters
+        B, M, N = output.shape
+        BLOCK_SIZE = max(32, triton.next_power_of_2(max(M, N)))
+
+        # Explicit contiguity for grad_output
+        grad_output = grad_output.contiguous()
+        grad_input = torch.empty_like(output)
+
+        sinkhorn_kernel_backward_log[(B,)](
+            grad_output.view(B, M, N),
+            output,
+            grad_input,
+            M, N,
+            grad_input.stride(0), grad_input.stride(1), grad_input.stride(2),
+            iters=iters,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=4
+        )
+
+        return grad_input.view_as(grad_output), None
+
+def triton_sinkhorn(log_alpha, iters=20):
+    if log_alpha.is_cuda:
+        try:
+            return TritonSinkhornFunction.apply(log_alpha, iters)
+        except Exception:
+            pass
+
+    # fallback
+    from hyper_connections.mHCv2 import sinkhorn_knopps
+    return sinkhorn_knopps(log_alpha, iters = iters)
+
+def is_triton_available():
+    try:
+        import triton
+        return torch.cuda.is_available()
+    except ImportError:
+        return False
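
For reference (not part of the package), the forward kernel above fuses the same alternating column-wise / row-wise log-softmax normalization that the existing log_domain_sinkhorn_knopps performs, one Triton program per matrix in the batch; the backward kernel then derives gradients from the saved doubly-stochastic output instead of replaying intermediate iterates, and triton_sinkhorn falls back to the pure PyTorch sinkhorn_knopps for CPU tensors or if the kernel launch fails. A minimal PyTorch sketch of that forward iteration, with a hypothetical helper name:

    import torch

    def reference_log_sinkhorn(log_alpha, iters = 20):
        # alternate column-wise (dim = -2) and row-wise (dim = -1) normalization
        # in log space, then exponentiate to get an approximately doubly stochastic matrix
        for _ in range(iters):
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim = -2, keepdim = True)
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim = -1, keepdim = True)
        return log_alpha.exp()
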
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "hyper-connections"
-version = "0.4.4"
+version = "0.4.6"
 description = "Hyper-Connections"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_hyper_connections.py
@@ -280,3 +280,38 @@ def test_mhcv2(
     residual = reduce_stream(residual)
 
     assert residual.shape == (2, 1024, 512)
+
+def test_triton_sinkhorn():
+    import torch
+    if not torch.cuda.is_available():
+        pytest.skip('CUDA not available')
+
+    from hyper_connections.triton_sinkhorn import triton_sinkhorn
+    from hyper_connections.mHCv2 import sinkhorn_knopps, log_domain_sinkhorn_knopps
+
+    B, M, N = 2, 16, 16
+    log_alpha = torch.randn(B, M, N, device = 'cuda', requires_grad = True, dtype = torch.float32)
+    iters = 20
+
+    # 1. Forward equivalence with sinkhorn_knopps
+    out_triton = triton_sinkhorn(log_alpha, iters = iters)
+    out_torch = sinkhorn_knopps(log_alpha, iters = iters)
+    torch.testing.assert_close(out_triton, out_torch, atol = 1e-4, rtol = 1e-4)
+
+    # 2. Forward equivalence with log_domain_sinkhorn_knopps
+    out_log_torch = log_domain_sinkhorn_knopps(log_alpha, iters = iters)
+    torch.testing.assert_close(out_triton, out_log_torch, atol = 1e-4, rtol = 1e-4)
+
+    # 3. Backward parity check
+    out_triton.backward(torch.ones_like(out_triton))
+    grad_triton = log_alpha.grad.clone()
+
+    log_alpha.grad.zero_()
+    out_torch.backward(torch.ones_like(out_torch))
+    grad_torch = log_alpha.grad.clone()
+
+    torch.testing.assert_close(grad_triton, grad_torch, atol = 1e-3, rtol = 1e-3)
+
+    log_alpha_double = torch.randn(1, 4, 4, device = 'cuda', requires_grad = True, dtype = torch.float64)
+
+    torch.autograd.gradcheck(triton_sinkhorn, (log_alpha_double, 10), eps = 1e-6, atol = 1e-5)
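
The new test requires a CUDA-capable GPU (it skips otherwise) and checks the Triton forward and backward against both PyTorch sinkhorn implementations, finishing with a float64 gradcheck. Assuming the repository's standard pytest setup, it can be run in isolation with something like:

    pytest tests/test_hyper_connections.py -k test_triton_sinkhorn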