hyper-connections 0.4.4.tar.gz → 0.4.6.tar.gz
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/PKG-INFO +1 -1
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/mHCv2.py +27 -6
- hyper_connections-0.4.6/hyper_connections/triton_sinkhorn.py +160 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/pyproject.toml +1 -1
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/tests/test_hyper_connections.py +35 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/.github/workflows/python-publish.yml +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/.github/workflows/test.yml +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/.gitignore +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/LICENSE +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/README.md +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper-connections.png +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/__init__.py +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/hyper_connections.py +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/hyper_connections_channel_first.py +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/hyper_connections_with_multi_branch_inputs.py +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/hyper_connections_with_multi_input_streams.py +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/manifold_constrained_hyper_connections.py +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/residuals.py +0 -0
- {hyper_connections-0.4.4 → hyper_connections-0.4.6}/hyper_connections/vit.py +0 -0

hyper_connections/mHCv2.py

@@ -47,6 +47,10 @@ def l1norm(t, dim):
     return F.normalize(t, p = 1, dim = dim)
 
 def sinkhorn_knopps(log_alpha, iters = 20):
+
+    if iters <= 0:
+        return log_alpha
+
     assert log_alpha.shape[-2] == log_alpha.shape[-1]
 
     dtype = log_alpha.dtype
@@ -63,6 +67,10 @@ def sinkhorn_knopps(log_alpha, iters = 20):
     return alpha.to(dtype)
 
 def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
+
+    if iters <= 0:
+        return log_alpha
+
    assert log_alpha.shape[-2] == log_alpha.shape[-1]
 
    dtype = log_alpha.dtype
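Both routines now short-circuit before the square-matrix assert, so a non-positive iteration count simply hands back the input tensor. A minimal sketch of the new behavior, with an illustrative 4x4 input (not part of the package):

import torch
from hyper_connections.mHCv2 import sinkhorn_knopps, log_domain_sinkhorn_knopps

log_alpha = torch.randn(4, 4)

# iters = 0 is now a no-op: the raw input comes back untouched
assert torch.equal(sinkhorn_knopps(log_alpha, iters = 0), log_alpha)
assert torch.equal(log_domain_sinkhorn_knopps(log_alpha, iters = 0), log_alpha)

# with a positive iteration count the usual (approximately) doubly stochastic matrix is produced
alpha = sinkhorn_knopps(log_alpha, iters = 20)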
@@ -109,6 +117,7 @@ def get_init_and_expand_reduce_stream_functions(
     add_attn_pool_reduce_stream = False,
     disable = None,
     sinkhorn_iters = 20,
+    use_triton_sinkhorn = False,
     **kwargs
 ):
     disable = default(disable, num_streams == 1 and num_fracs == 1)
@@ -116,7 +125,7 @@ def get_init_and_expand_reduce_stream_functions(
     hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual
 
     kwargs.pop('add_attn_pool_reduce_stream', None)
-    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, **kwargs)
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, use_triton_sinkhorn = use_triton_sinkhorn, **kwargs)
     expand_reduce_fns = get_expand_reduce_stream_functions(
         num_streams,
         add_stream_embed = add_stream_embed,
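Because the factory forwards the flag through the partial above, opting in only requires one extra keyword. A usage sketch, assuming the factory returns the usual (init, expand, reduce) triple and takes the stream count as its first positional argument; the values shown are illustrative:

from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions

init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(
    4,                            # number of residual streams
    sinkhorn_iters = 20,
    use_triton_sinkhorn = True    # request the Triton kernel; falls back if Triton / CUDA is missing
)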
@@ -231,7 +240,7 @@ class ManifoldConstrainedHyperConnections(Module):
         residual_mix_constraint_fn: Callable | None = None,
         forward_method_names: tuple[str, ...] = (),
         num_dynamic_alpha_proposals = 1,
-
+        use_triton_sinkhorn = False,
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606
@@ -306,10 +315,22 @@ class ManifoldConstrainedHyperConnections(Module):
         # Hres constraint related
         # by default is sinkhorn
 
-
-
-
-
+        use_triton_sinkhorn_and_available = False
+
+        if use_triton_sinkhorn:
+            try:
+                from hyper_connections.triton_sinkhorn import triton_sinkhorn, is_triton_available
+                use_triton_sinkhorn_and_available = is_triton_available()
+            except ImportError:
+                use_triton_sinkhorn_and_available = False
+
+        if use_triton_sinkhorn_and_available:
+            self.residual_mix_constraint_fn = partial(triton_sinkhorn, iters = sinkhorn_iters)
+        else:
+            self.residual_mix_constraint_fn = default(
+                residual_mix_constraint_fn,
+                partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
+            )
 
         # dropouts
 
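Note the precedence: when the Triton path is requested and available it wins, and a user-supplied residual_mix_constraint_fn is only consulted on the else branch. A sketch of keeping a custom constraint active, assuming keyword arguments passed to the factory reach this constructor through the partial shown earlier; the constraint itself is a hypothetical stand-in:

import torch.nn.functional as F
from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions

# hypothetical constraint: row-normalize the mixing logits instead of running Sinkhorn
custom_constraint = lambda log_alpha: F.softmax(log_alpha, dim = -1)

init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(
    4,
    use_triton_sinkhorn = False,                    # leave the Triton path off ...
    residual_mix_constraint_fn = custom_constraint  # ... so this override is actually used
)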
hyper_connections/triton_sinkhorn.py (new file)

@@ -0,0 +1,160 @@
+import torch
+import triton
+import triton.language as tl
+from torch.autograd import Function
+
+@triton.jit
+def sinkhorn_kernel_forward_log(
+    input_ptr,
+    output_ptr,
+    M, N,
+    stride_b, stride_m, stride_n,
+    iters: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_b = tl.program_id(0)
+    offs_m = tl.arange(0, BLOCK_SIZE)
+    offs_n = tl.arange(0, BLOCK_SIZE)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    curr_input_ptr = input_ptr + pid_b * stride_b
+    # Use a large negative value for log-space padding to avoid interference
+    log_alpha = tl.load(curr_input_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=-1e10)
+
+    # Use static_range to force unrolling and avoid compiler bugs with dynamic loops in this environment
+    for _ in tl.static_range(iters):
+        # Column-wise Log-Softmax (dim=-2)
+        col_max = tl.max(tl.where(mask, log_alpha, -1e10), axis=0)
+        exp_weights_col = tl.exp(log_alpha - col_max[None, :])
+        exp_weights_col = tl.where(mask, exp_weights_col, 0.0)
+        col_lse = col_max + tl.log(tl.sum(exp_weights_col, axis=0))
+        log_alpha = log_alpha - col_lse[None, :]
+        log_alpha = tl.where(mask, log_alpha, -1e10)
+
+        # Row-wise Log-Softmax (dim=-1)
+        row_max = tl.max(tl.where(mask, log_alpha, -1e10), axis=1)
+        exp_weights_row = tl.exp(log_alpha - row_max[:, None])
+        exp_weights_row = tl.where(mask, exp_weights_row, 0.0)
+        row_lse = row_max + tl.log(tl.sum(exp_weights_row, axis=1))
+        log_alpha = log_alpha - row_lse[:, None]
+        log_alpha = tl.where(mask, log_alpha, -1e10)
+
+    result_alpha = tl.exp(log_alpha)
+    result_alpha = tl.where(mask, result_alpha, 0.0)
+
+    curr_output_ptr = output_ptr + pid_b * stride_b
+    tl.store(curr_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, result_alpha, mask=mask)
+
+@triton.jit
+def sinkhorn_kernel_backward_log(
+    grad_output_ptr,
+    output_ptr,
+    grad_input_ptr,
+    M, N,
+    stride_b, stride_m, stride_n,
+    iters: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_b = tl.program_id(0)
+    offs_m = tl.arange(0, BLOCK_SIZE)
+    offs_n = tl.arange(0, BLOCK_SIZE)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    curr_output_ptr = output_ptr + pid_b * stride_b
+    curr_grad_output_ptr = grad_output_ptr + pid_b * stride_b
+
+    alpha = tl.load(curr_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=0.0)
+    grad_alpha = tl.load(curr_grad_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=0.0)
+
+    # Ensure they are truly zeroed in padded areas for sum robustness
+    alpha = tl.where(mask, alpha, 0.0)
+    grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+    for _ in tl.static_range(iters):
+        # Backward of Row-wise Normalization
+        # Sum only over valid elements
+        row_sum_grad_alpha = tl.sum(tl.where(mask, grad_alpha * alpha, 0.0), axis=1)
+        grad_alpha = grad_alpha - row_sum_grad_alpha[:, None]
+        grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+        # Backward of Column-wise Normalization
+        col_sum_grad_alpha = tl.sum(tl.where(mask, grad_alpha * alpha, 0.0), axis=0)
+        grad_alpha = grad_alpha - col_sum_grad_alpha[None, :]
+        grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+    grad_input = alpha * grad_alpha
+
+    curr_grad_input_ptr = grad_input_ptr + pid_b * stride_b
+    tl.store(curr_grad_input_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, grad_input, mask=mask)
+
+class TritonSinkhornFunction(Function):
+    @staticmethod
+    def forward(ctx, log_alpha, iters=20):
+        # Handle matrix size limits to avoid register spilling/SRAM overflow
+        M, N = log_alpha.shape[-2:]
+        if max(M, N) > 256:
+            from hyper_connections.mHCv2 import log_domain_sinkhorn_knopps
+            return log_domain_sinkhorn_knopps(log_alpha, iters)
+
+        batch_shape = log_alpha.shape[:-2]
+        log_alpha_flat = log_alpha.view(-1, M, N).contiguous()
+        B = log_alpha_flat.shape[0]
+
+        output = torch.empty_like(log_alpha_flat)
+        BLOCK_SIZE = max(32, triton.next_power_of_2(max(M, N)))
+
+        sinkhorn_kernel_forward_log[(B,)](
+            log_alpha_flat,
+            output,
+            M, N,
+            log_alpha_flat.stride(0), log_alpha_flat.stride(1), log_alpha_flat.stride(2),
+            iters=iters,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=4
+        )
+
+        ctx.save_for_backward(output)
+        ctx.iters = iters
+        return output.view(*batch_shape, M, N)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        output, = ctx.saved_tensors
+        iters = ctx.iters
+        B, M, N = output.shape
+        BLOCK_SIZE = max(32, triton.next_power_of_2(max(M, N)))
+
+        # Explicit contiguity for grad_output
+        grad_output = grad_output.contiguous()
+        grad_input = torch.empty_like(output)
+
+        sinkhorn_kernel_backward_log[(B,)](
+            grad_output.view(B, M, N),
+            output,
+            grad_input,
+            M, N,
+            grad_input.stride(0), grad_input.stride(1), grad_input.stride(2),
+            iters=iters,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=4
+        )
+
+        return grad_input.view_as(grad_output), None
+
+def triton_sinkhorn(log_alpha, iters=20):
+    if log_alpha.is_cuda:
+        try:
+            return TritonSinkhornFunction.apply(log_alpha, iters)
+        except Exception:
+            pass
+
+    # fallback
+    from hyper_connections.mHCv2 import sinkhorn_knopps
+    return sinkhorn_knopps(log_alpha, iters = iters)
+
+def is_triton_available():
+    try:
+        import triton
+        return torch.cuda.is_available()
+    except ImportError:
+        return False
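The module's two public entry points are triton_sinkhorn and is_triton_available. triton_sinkhorn falls back to the pure-PyTorch sinkhorn_knopps for CPU tensors or if the kernel launch fails, and to log_domain_sinkhorn_knopps for matrices larger than 256 on a side. A direct-call sketch, assuming triton itself is installed (the module imports it at the top level); shapes are illustrative:

import torch
from hyper_connections.triton_sinkhorn import triton_sinkhorn, is_triton_available

log_alpha = torch.randn(8, 16, 16)

if is_triton_available():
    log_alpha = log_alpha.cuda()   # the Triton kernels only run on CUDA tensors

# rows and columns of alpha each sum to (approximately) one after 20 Sinkhorn iterations
alpha = triton_sinkhorn(log_alpha, iters = 20)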
tests/test_hyper_connections.py

@@ -280,3 +280,38 @@ def test_mhcv2(
     residual = reduce_stream(residual)
 
     assert residual.shape == (2, 1024, 512)
+
+def test_triton_sinkhorn():
+    import torch
+    if not torch.cuda.is_available():
+        pytest.skip('CUDA not available')
+
+    from hyper_connections.triton_sinkhorn import triton_sinkhorn
+    from hyper_connections.mHCv2 import sinkhorn_knopps, log_domain_sinkhorn_knopps
+
+    B, M, N = 2, 16, 16
+    log_alpha = torch.randn(B, M, N, device = 'cuda', requires_grad = True, dtype = torch.float32)
+    iters = 20
+
+    # 1. Forward equivalence with sinkhorn_knopps
+    out_triton = triton_sinkhorn(log_alpha, iters = iters)
+    out_torch = sinkhorn_knopps(log_alpha, iters = iters)
+    torch.testing.assert_close(out_triton, out_torch, atol = 1e-4, rtol = 1e-4)
+
+    # 2. Forward equivalence with log_domain_sinkhorn_knopps
+    out_log_torch = log_domain_sinkhorn_knopps(log_alpha, iters = iters)
+    torch.testing.assert_close(out_triton, out_log_torch, atol = 1e-4, rtol = 1e-4)
+
+    # 3. Backward parity check
+    out_triton.backward(torch.ones_like(out_triton))
+    grad_triton = log_alpha.grad.clone()
+
+    log_alpha.grad.zero_()
+    out_torch.backward(torch.ones_like(out_torch))
+    grad_torch = log_alpha.grad.clone()
+
+    torch.testing.assert_close(grad_triton, grad_torch, atol = 1e-3, rtol = 1e-3)
+
+    log_alpha_double = torch.randn(1, 4, 4, device = 'cuda', requires_grad = True, dtype = torch.float64)
+
+    torch.autograd.gradcheck(triton_sinkhorn, (log_alpha_double, 10), eps = 1e-6, atol = 1e-5)