hyper-connections 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -69,8 +69,8 @@ def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
     log_alpha = log_alpha.float()
 
     for _ in range(iters):
-        log_alpha = log_alpha - log_alpha.logsumexp(dim = -2, keepdim = True)
-        log_alpha = log_alpha - log_alpha.logsumexp(dim = -1, keepdim = True)
+        log_alpha = F.log_softmax(log_alpha, dim = -2)
+        log_alpha = F.log_softmax(log_alpha, dim = -1)
 
     return log_alpha.exp().to(dtype)
 
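Note: subtracting the logsumexp over a dimension is mathematically the same as applying log_softmax over that dimension, so this change keeps the log-domain Sinkhorn-Knopp iteration identical while relying on the fused, numerically stable softmax kernel. A minimal check of the equivalence (assumes only that torch is installed):

    import torch
    import torch.nn.functional as F

    log_alpha = torch.randn(4, 4)

    # subtracting the row-wise logsumexp is exactly a row-wise log_softmax
    manual = log_alpha - log_alpha.logsumexp(dim = -1, keepdim = True)
    assert torch.allclose(manual, F.log_softmax(log_alpha, dim = -1), atol = 1e-6)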
@@ -109,6 +109,7 @@ def get_init_and_expand_reduce_stream_functions(
     add_attn_pool_reduce_stream = False,
     disable = None,
     sinkhorn_iters = 20,
+    use_triton_sinkhorn = False,
     **kwargs
 ):
     disable = default(disable, num_streams == 1 and num_fracs == 1)
@@ -116,7 +117,7 @@ get_init_and_expand_reduce_stream_functions(
     hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual
 
     kwargs.pop('add_attn_pool_reduce_stream', None)
-    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, **kwargs)
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, use_triton_sinkhorn = use_triton_sinkhorn, **kwargs)
     expand_reduce_fns = get_expand_reduce_stream_functions(
         num_streams,
         add_stream_embed = add_stream_embed,
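The new use_triton_sinkhorn flag is simply threaded from the factory into the hyper-connection constructor. A hypothetical usage sketch; the import path and the three returned functions are assumptions based on the factory's existing API, which is not shown in this diff:

    # import path is an assumption; use whichever module defines this factory
    from hyper_connections.manifold_constrained_hyper_connections import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(
        4,                            # num_streams, first positional as in the partial above
        num_fracs = 1,
        use_triton_sinkhorn = True,   # new in 0.4.5; ignored when Triton or CUDA is unavailable
    )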
@@ -231,7 +232,7 @@ class ManifoldConstrainedHyperConnections(Module):
         residual_mix_constraint_fn: Callable | None = None,
         forward_method_names: tuple[str, ...] = (),
         num_dynamic_alpha_proposals = 1,
-
+        use_triton_sinkhorn = False,
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -306,10 +307,22 @@ class ManifoldConstrainedHyperConnections(Module):
         # Hres constraint related
         # by default is sinkhorn
 
-        self.residual_mix_constraint_fn = default(
-            residual_mix_constraint_fn,
-            partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
-        )
+        use_triton_sinkhorn_and_available = False
+
+        if use_triton_sinkhorn:
+            try:
+                from hyper_connections.triton_sinkhorn import triton_sinkhorn, is_triton_available
+                use_triton_sinkhorn_and_available = is_triton_available()
+            except ImportError:
+                use_triton_sinkhorn_and_available = False
+
+        if use_triton_sinkhorn_and_available:
+            self.residual_mix_constraint_fn = partial(triton_sinkhorn, iters = sinkhorn_iters)
+        else:
+            self.residual_mix_constraint_fn = default(
+                residual_mix_constraint_fn,
+                partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
+            )
 
         # dropouts
 
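When the Triton path is off (flag unset, Triton not importable, or no CUDA device), the constraint function resolves exactly as before: the caller-supplied residual_mix_constraint_fn if given, otherwise a partial of sinkhorn_knopps or log_domain_sinkhorn_knopps. A small sketch of what that default produces, assuming only torch and the package's mHCv2 module (the same module the new Triton file falls back to):

    import torch
    from functools import partial
    from hyper_connections.mHCv2 import log_domain_sinkhorn_knopps

    # mirrors the partial built in the else-branch above
    constraint_fn = partial(log_domain_sinkhorn_knopps, iters = 20)

    alpha = constraint_fn(torch.randn(4, 4))

    # Sinkhorn-Knopp iterations push the residual mix matrix toward doubly stochastic:
    # rows and columns each sum to (approximately) one
    print(alpha.sum(dim = -1))  # ~[1., 1., 1., 1.]
    print(alpha.sum(dim = -2))  # ~[1., 1., 1., 1.]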
@@ -65,8 +65,8 @@ def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
     log_alpha = log_alpha.float()
 
     for _ in range(iters):
-        log_alpha = log_alpha - log_alpha.logsumexp(dim = -2, keepdim = True)
-        log_alpha = log_alpha - log_alpha.logsumexp(dim = -1, keepdim = True)
+        log_alpha = F.log_softmax(log_alpha, dim = -2)
+        log_alpha = F.log_softmax(log_alpha, dim = -1)
 
     return log_alpha.exp().to(dtype)
 
@@ -0,0 +1,160 @@
+import torch
+import triton
+import triton.language as tl
+from torch.autograd import Function
+
+@triton.jit
+def sinkhorn_kernel_forward_log(
+    input_ptr,
+    output_ptr,
+    M, N,
+    stride_b, stride_m, stride_n,
+    iters: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_b = tl.program_id(0)
+    offs_m = tl.arange(0, BLOCK_SIZE)
+    offs_n = tl.arange(0, BLOCK_SIZE)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    curr_input_ptr = input_ptr + pid_b * stride_b
+    # Use a large negative value for log-space padding to avoid interference
+    log_alpha = tl.load(curr_input_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=-1e10)
+
+    # Use static_range to force unrolling and avoid compiler bugs with dynamic loops in this environment
+    for _ in tl.static_range(iters):
+        # Column-wise Log-Softmax (dim=-2)
+        col_max = tl.max(tl.where(mask, log_alpha, -1e10), axis=0)
+        exp_weights_col = tl.exp(log_alpha - col_max[None, :])
+        exp_weights_col = tl.where(mask, exp_weights_col, 0.0)
+        col_lse = col_max + tl.log(tl.sum(exp_weights_col, axis=0))
+        log_alpha = log_alpha - col_lse[None, :]
+        log_alpha = tl.where(mask, log_alpha, -1e10)
+
+        # Row-wise Log-Softmax (dim=-1)
+        row_max = tl.max(tl.where(mask, log_alpha, -1e10), axis=1)
+        exp_weights_row = tl.exp(log_alpha - row_max[:, None])
+        exp_weights_row = tl.where(mask, exp_weights_row, 0.0)
+        row_lse = row_max + tl.log(tl.sum(exp_weights_row, axis=1))
+        log_alpha = log_alpha - row_lse[:, None]
+        log_alpha = tl.where(mask, log_alpha, -1e10)
+
+    result_alpha = tl.exp(log_alpha)
+    result_alpha = tl.where(mask, result_alpha, 0.0)
+
+    curr_output_ptr = output_ptr + pid_b * stride_b
+    tl.store(curr_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, result_alpha, mask=mask)
+
+@triton.jit
+def sinkhorn_kernel_backward_log(
+    grad_output_ptr,
+    output_ptr,
+    grad_input_ptr,
+    M, N,
+    stride_b, stride_m, stride_n,
+    iters: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_b = tl.program_id(0)
+    offs_m = tl.arange(0, BLOCK_SIZE)
+    offs_n = tl.arange(0, BLOCK_SIZE)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    curr_output_ptr = output_ptr + pid_b * stride_b
+    curr_grad_output_ptr = grad_output_ptr + pid_b * stride_b
+
+    alpha = tl.load(curr_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=0.0)
+    grad_alpha = tl.load(curr_grad_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=0.0)
+
+    # Ensure they are truly zeroed in padded areas for sum robustness
+    alpha = tl.where(mask, alpha, 0.0)
+    grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+    for _ in tl.static_range(iters):
+        # Backward of Row-wise Normalization
+        # Sum only over valid elements
+        row_sum_grad_alpha = tl.sum(tl.where(mask, grad_alpha * alpha, 0.0), axis=1)
+        grad_alpha = grad_alpha - row_sum_grad_alpha[:, None]
+        grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+        # Backward of Column-wise Normalization
+        col_sum_grad_alpha = tl.sum(tl.where(mask, grad_alpha * alpha, 0.0), axis=0)
+        grad_alpha = grad_alpha - col_sum_grad_alpha[None, :]
+        grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+    grad_input = alpha * grad_alpha
+
+    curr_grad_input_ptr = grad_input_ptr + pid_b * stride_b
+    tl.store(curr_grad_input_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, grad_input, mask=mask)
+
+class TritonSinkhornFunction(Function):
+    @staticmethod
+    def forward(ctx, log_alpha, iters=20):
+        # Handle matrix size limits to avoid register spilling/SRAM overflow
+        M, N = log_alpha.shape[-2:]
+        if max(M, N) > 256:
+            from hyper_connections.mHCv2 import log_domain_sinkhorn_knopps
+            return log_domain_sinkhorn_knopps(log_alpha, iters)
+
+        batch_shape = log_alpha.shape[:-2]
+        log_alpha_flat = log_alpha.view(-1, M, N).contiguous()
+        B = log_alpha_flat.shape[0]
+
+        output = torch.empty_like(log_alpha_flat)
+        BLOCK_SIZE = max(32, triton.next_power_of_2(max(M, N)))
+
+        sinkhorn_kernel_forward_log[(B,)](
+            log_alpha_flat,
+            output,
+            M, N,
+            log_alpha_flat.stride(0), log_alpha_flat.stride(1), log_alpha_flat.stride(2),
+            iters=iters,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=4
+        )
+
+        ctx.save_for_backward(output)
+        ctx.iters = iters
+        return output.view(*batch_shape, M, N)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        output, = ctx.saved_tensors
+        iters = ctx.iters
+        B, M, N = output.shape
+        BLOCK_SIZE = max(32, triton.next_power_of_2(max(M, N)))
+
+        # Explicit contiguity for grad_output
+        grad_output = grad_output.contiguous()
+        grad_input = torch.empty_like(output)
+
+        sinkhorn_kernel_backward_log[(B,)](
+            grad_output.view(B, M, N),
+            output,
+            grad_input,
+            M, N,
+            grad_input.stride(0), grad_input.stride(1), grad_input.stride(2),
+            iters=iters,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=4
+        )
+
+        return grad_input.view_as(grad_output), None
+
+def triton_sinkhorn(log_alpha, iters=20):
+    if log_alpha.is_cuda:
+        try:
+            return TritonSinkhornFunction.apply(log_alpha, iters)
+        except Exception:
+            pass
+
+    # fallback
+    from hyper_connections.mHCv2 import sinkhorn_knopps
+    return sinkhorn_knopps(log_alpha, iters = iters)
+
+def is_triton_available():
+    try:
+        import triton
+        return torch.cuda.is_available()
+    except ImportError:
+        return False
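The kernel above implements the same log-domain iteration (column-wise then row-wise log-softmax per step) as the PyTorch reference, so a quick parity check is the easiest way to validate it on a given machine. A sketch, assuming a CUDA device with Triton installed; the tolerance is illustrative, not a value asserted anywhere in the package:

    import torch
    from hyper_connections.triton_sinkhorn import triton_sinkhorn, is_triton_available
    from hyper_connections.mHCv2 import log_domain_sinkhorn_knopps

    if is_triton_available():
        log_alpha = torch.randn(8, 16, 16, device = 'cuda')

        # forward parity between the Triton kernel and the pure-PyTorch reference
        out_triton = triton_sinkhorn(log_alpha, iters = 20)
        out_ref = log_domain_sinkhorn_knopps(log_alpha, iters = 20)
        assert torch.allclose(out_triton, out_ref, atol = 1e-4)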
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyper-connections
-Version: 0.4.3
+Version: 0.4.5
 Summary: Hyper-Connections
 Project-URL: Homepage, https://pypi.org/project/hyper-connections/
 Project-URL: Repository, https://github.com/lucidrains/hyper-connections
@@ -3,11 +3,12 @@ hyper_connections/hyper_connections.py,sha256=2F-104cGE82KCK0KeC07NSOJNPT-0PCtvX
 hyper_connections/hyper_connections_channel_first.py,sha256=5vAen4WXxNI9K07ndLBQJwdJv-OjoXznta5EIQTaQNA,6512
 hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=yn2AlFB6qCYQeRhhhaMlCM3mxxLEdWCYwU2p9TsMwWI,7835
 hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=iFPw5pgCRHTo16nBJ2PExKSrvTyCh7ba7Py14P1oSPE,11311
-hyper_connections/mHCv2.py,sha256=j3A4XbisBXzqdW9vYCrPRrK2M6iPAqMOjxGCj3lsQ-g,16810
-hyper_connections/manifold_constrained_hyper_connections.py,sha256=rQzAIkP84adzEVyrMasqMuZV76-6LAioUbwKnABcBto,18315
+hyper_connections/mHCv2.py,sha256=wCtp87OFI3QfosdSL-1qwsiQN9f8gX32_0r8GQGO7P0,17411
+hyper_connections/manifold_constrained_hyper_connections.py,sha256=E4os-6q_SMjJO1JD0EG8rFTCXA7MQoy-aqUlM7KVS5Q,18269
 hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
+hyper_connections/triton_sinkhorn.py,sha256=n2WyQcUemtv5T5Sk2nljnSpV2hEED4I3HaPsIUy4638,5905
 hyper_connections/vit.py,sha256=BOWVfCAIzDQdnTq8OBzNUyiKGGILYZkIQ6mr1GKJVB0,5225
-hyper_connections-0.4.3.dist-info/METADATA,sha256=h_zeG-qAgyg-vDktRMaPpGuYzmA-kxrcUmPvVQ4CYvs,6704
-hyper_connections-0.4.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-hyper_connections-0.4.3.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
-hyper_connections-0.4.3.dist-info/RECORD,,
+hyper_connections-0.4.5.dist-info/METADATA,sha256=sWVb_-yVRmxL8AsAPsk0VdRXOa25uG9zKNc8S_oAXg8,6704
+hyper_connections-0.4.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+hyper_connections-0.4.5.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
+hyper_connections-0.4.5.dist-info/RECORD,,