hyper-connections 0.4.3-py3-none-any.whl → 0.4.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyper_connections/mHCv2.py +21 -8
- hyper_connections/manifold_constrained_hyper_connections.py +2 -2
- hyper_connections/triton_sinkhorn.py +160 -0
- {hyper_connections-0.4.3.dist-info → hyper_connections-0.4.5.dist-info}/METADATA +1 -1
- {hyper_connections-0.4.3.dist-info → hyper_connections-0.4.5.dist-info}/RECORD +7 -6
- {hyper_connections-0.4.3.dist-info → hyper_connections-0.4.5.dist-info}/WHEEL +0 -0
- {hyper_connections-0.4.3.dist-info → hyper_connections-0.4.5.dist-info}/licenses/LICENSE +0 -0
hyper_connections/mHCv2.py
CHANGED
@@ -69,8 +69,8 @@ def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
     log_alpha = log_alpha.float()

     for _ in range(iters):
-        log_alpha = log_alpha
-        log_alpha = log_alpha
+        log_alpha = F.log_softmax(log_alpha, dim = -2)
+        log_alpha = F.log_softmax(log_alpha, dim = -1)

     return log_alpha.exp().to(dtype)

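The removed loop body (shown truncated above as bare `log_alpha = log_alpha` assignments) is replaced by an explicit alternating column/row log-softmax, which is what makes the output approximately doubly stochastic. For reference, the patched function reassembles to the sketch below; everything is taken from the hunk and its visible context except the `dtype = log_alpha.dtype` capture, which is assumed from the trailing `.to(dtype)`:

import torch
import torch.nn.functional as F

def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
    dtype = log_alpha.dtype  # assumed: captured before the float() cast, per the .to(dtype) below
    log_alpha = log_alpha.float()

    for _ in range(iters):
        # each log_softmax subtracts a logsumexp, normalizing columns (dim = -2)
        # and then rows (dim = -1) in log space
        log_alpha = F.log_softmax(log_alpha, dim = -2)
        log_alpha = F.log_softmax(log_alpha, dim = -1)

    return log_alpha.exp().to(dtype)

# rows sum to exactly 1 (the last pass normalized dim = -1);
# column sums converge toward 1 as iters grows
alpha = log_domain_sinkhorn_knopps(torch.randn(4, 4))
print(alpha.sum(dim = -1))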
@@ -109,6 +109,7 @@ get_init_and_expand_reduce_stream_functions(
     add_attn_pool_reduce_stream = False,
     disable = None,
     sinkhorn_iters = 20,
+    use_triton_sinkhorn = False,
     **kwargs
 ):
     disable = default(disable, num_streams == 1 and num_fracs == 1)

@@ -116,7 +117,7 @@ get_init_and_expand_reduce_stream_functions(
     hyper_conn_klass = ManifoldConstrainedHyperConnections if not disable else Residual

     kwargs.pop('add_attn_pool_reduce_stream', None)
-    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, **kwargs)
+    init_hyper_conn_fn = partial(hyper_conn_klass, num_streams, num_fracs = num_fracs, sinkhorn_iters = sinkhorn_iters, use_triton_sinkhorn = use_triton_sinkhorn, **kwargs)
     expand_reduce_fns = get_expand_reduce_stream_functions(
         num_streams,
         add_stream_embed = add_stream_embed,

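With the plumbing above, opting in is a single keyword argument. A hypothetical usage sketch follows; the function and argument names come from this diff, while the three-value return is assumed from the library's analogous stream helpers:

from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions

init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(
    4,                           # num_streams (positional, as in the partial above)
    sinkhorn_iters = 20,
    use_triton_sinkhorn = True,  # request the fused Triton kernel when available
)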
@@ -231,7 +232,7 @@ class ManifoldConstrainedHyperConnections(Module):
         residual_mix_constraint_fn: Callable | None = None,
         forward_method_names: tuple[str, ...] = (),
         num_dynamic_alpha_proposals = 1,
-
+        use_triton_sinkhorn = False,
     ):
         """
         Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606

@@ -306,10 +307,22 @@ class ManifoldConstrainedHyperConnections(Module):
         # Hres constraint related
         # by default is sinkhorn

-
-
-
-
+        use_triton_sinkhorn_and_available = False
+
+        if use_triton_sinkhorn:
+            try:
+                from hyper_connections.triton_sinkhorn import triton_sinkhorn, is_triton_available
+                use_triton_sinkhorn_and_available = is_triton_available()
+            except ImportError:
+                use_triton_sinkhorn_and_available = False
+
+        if use_triton_sinkhorn_and_available:
+            self.residual_mix_constraint_fn = partial(triton_sinkhorn, iters = sinkhorn_iters)
+        else:
+            self.residual_mix_constraint_fn = default(
+                residual_mix_constraint_fn,
+                partial(sinkhorn_knopps if not log_domain_sinkhorn else log_domain_sinkhorn_knopps, iters = sinkhorn_iters)
+            )

         # dropouts

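Note the selection order in the constructor above: the Triton path wins whenever the import succeeds and `is_triton_available()` returns True, and a user-supplied `residual_mix_constraint_fn` is only consulted on the fallback branch. A quick, hypothetical way to check which branch a given machine will take:

from hyper_connections.triton_sinkhorn import is_triton_available

# False on CPU-only machines or when triton is not installed, in which case
# use_triton_sinkhorn = True silently keeps the pure-PyTorch sinkhorn path
print(is_triton_available())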
hyper_connections/manifold_constrained_hyper_connections.py
CHANGED

@@ -65,8 +65,8 @@ def log_domain_sinkhorn_knopps(log_alpha, iters = 20):
     log_alpha = log_alpha.float()

     for _ in range(iters):
-        log_alpha = log_alpha
-        log_alpha = log_alpha
+        log_alpha = F.log_softmax(log_alpha, dim = -2)
+        log_alpha = F.log_softmax(log_alpha, dim = -1)

     return log_alpha.exp().to(dtype)

hyper_connections/triton_sinkhorn.py
ADDED

@@ -0,0 +1,160 @@
+import torch
+import triton
+import triton.language as tl
+from torch.autograd import Function
+
+@triton.jit
+def sinkhorn_kernel_forward_log(
+    input_ptr,
+    output_ptr,
+    M, N,
+    stride_b, stride_m, stride_n,
+    iters: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_b = tl.program_id(0)
+    offs_m = tl.arange(0, BLOCK_SIZE)
+    offs_n = tl.arange(0, BLOCK_SIZE)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    curr_input_ptr = input_ptr + pid_b * stride_b
+    # Use a large negative value for log-space padding to avoid interference
+    log_alpha = tl.load(curr_input_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=-1e10)
+
+    # Use static_range to force unrolling and avoid compiler bugs with dynamic loops in this environment
+    for _ in tl.static_range(iters):
+        # Column-wise Log-Softmax (dim=-2)
+        col_max = tl.max(tl.where(mask, log_alpha, -1e10), axis=0)
+        exp_weights_col = tl.exp(log_alpha - col_max[None, :])
+        exp_weights_col = tl.where(mask, exp_weights_col, 0.0)
+        col_lse = col_max + tl.log(tl.sum(exp_weights_col, axis=0))
+        log_alpha = log_alpha - col_lse[None, :]
+        log_alpha = tl.where(mask, log_alpha, -1e10)
+
+        # Row-wise Log-Softmax (dim=-1)
+        row_max = tl.max(tl.where(mask, log_alpha, -1e10), axis=1)
+        exp_weights_row = tl.exp(log_alpha - row_max[:, None])
+        exp_weights_row = tl.where(mask, exp_weights_row, 0.0)
+        row_lse = row_max + tl.log(tl.sum(exp_weights_row, axis=1))
+        log_alpha = log_alpha - row_lse[:, None]
+        log_alpha = tl.where(mask, log_alpha, -1e10)
+
+    result_alpha = tl.exp(log_alpha)
+    result_alpha = tl.where(mask, result_alpha, 0.0)
+
+    curr_output_ptr = output_ptr + pid_b * stride_b
+    tl.store(curr_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, result_alpha, mask=mask)
+
+@triton.jit
+def sinkhorn_kernel_backward_log(
+    grad_output_ptr,
+    output_ptr,
+    grad_input_ptr,
+    M, N,
+    stride_b, stride_m, stride_n,
+    iters: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_b = tl.program_id(0)
+    offs_m = tl.arange(0, BLOCK_SIZE)
+    offs_n = tl.arange(0, BLOCK_SIZE)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+
+    curr_output_ptr = output_ptr + pid_b * stride_b
+    curr_grad_output_ptr = grad_output_ptr + pid_b * stride_b
+
+    alpha = tl.load(curr_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=0.0)
+    grad_alpha = tl.load(curr_grad_output_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, mask=mask, other=0.0)
+
+    # Ensure they are truly zeroed in padded areas for sum robustness
+    alpha = tl.where(mask, alpha, 0.0)
+    grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+    for _ in tl.static_range(iters):
+        # Backward of Row-wise Normalization
+        # Sum only over valid elements
+        row_sum_grad_alpha = tl.sum(tl.where(mask, grad_alpha * alpha, 0.0), axis=1)
+        grad_alpha = grad_alpha - row_sum_grad_alpha[:, None]
+        grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+        # Backward of Column-wise Normalization
+        col_sum_grad_alpha = tl.sum(tl.where(mask, grad_alpha * alpha, 0.0), axis=0)
+        grad_alpha = grad_alpha - col_sum_grad_alpha[None, :]
+        grad_alpha = tl.where(mask, grad_alpha, 0.0)
+
+    grad_input = alpha * grad_alpha
+
+    curr_grad_input_ptr = grad_input_ptr + pid_b * stride_b
+    tl.store(curr_grad_input_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n, grad_input, mask=mask)
+
+class TritonSinkhornFunction(Function):
+    @staticmethod
+    def forward(ctx, log_alpha, iters=20):
+        # Handle matrix size limits to avoid register spilling/SRAM overflow
+        M, N = log_alpha.shape[-2:]
+        if max(M, N) > 256:
+            from hyper_connections.mHCv2 import log_domain_sinkhorn_knopps
+            return log_domain_sinkhorn_knopps(log_alpha, iters)
+
+        batch_shape = log_alpha.shape[:-2]
+        log_alpha_flat = log_alpha.view(-1, M, N).contiguous()
+        B = log_alpha_flat.shape[0]
+
+        output = torch.empty_like(log_alpha_flat)
+        BLOCK_SIZE = max(32, triton.next_power_of_2(max(M, N)))
+
+        sinkhorn_kernel_forward_log[(B,)](
+            log_alpha_flat,
+            output,
+            M, N,
+            log_alpha_flat.stride(0), log_alpha_flat.stride(1), log_alpha_flat.stride(2),
+            iters=iters,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=4
+        )
+
+        ctx.save_for_backward(output)
+        ctx.iters = iters
+        return output.view(*batch_shape, M, N)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        output, = ctx.saved_tensors
+        iters = ctx.iters
+        B, M, N = output.shape
+        BLOCK_SIZE = max(32, triton.next_power_of_2(max(M, N)))
+
+        # Explicit contiguity for grad_output
+        grad_output = grad_output.contiguous()
+        grad_input = torch.empty_like(output)
+
+        sinkhorn_kernel_backward_log[(B,)](
+            grad_output.view(B, M, N),
+            output,
+            grad_input,
+            M, N,
+            grad_input.stride(0), grad_input.stride(1), grad_input.stride(2),
+            iters=iters,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=4
+        )
+
+        return grad_input.view_as(grad_output), None
+
+def triton_sinkhorn(log_alpha, iters=20):
+    if log_alpha.is_cuda:
+        try:
+            return TritonSinkhornFunction.apply(log_alpha, iters)
+        except Exception:
+            pass
+
+    # fallback
+    from hyper_connections.mHCv2 import sinkhorn_knopps
+    return sinkhorn_knopps(log_alpha, iters = iters)
+
+def is_triton_available():
+    try:
+        import triton
+        return torch.cuda.is_available()
+    except ImportError:
+        return False

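The backward kernel above does not replay the forward: at each iteration it subtracts the row-wise and then column-wise sums of `grad_alpha * alpha` (the Jacobian correction of a normalization step), reusing the saved final output `alpha` at every step, and scales by `alpha` at the end. A hypothetical sanity check against the pure-PyTorch reference; the device, shapes, and printed tolerances are illustrative only:

import torch
from hyper_connections.mHCv2 import log_domain_sinkhorn_knopps
from hyper_connections.triton_sinkhorn import triton_sinkhorn, is_triton_available

if is_triton_available():
    log_alpha = torch.randn(8, 16, 16, device = 'cuda')

    # forward parity: fused kernel vs. the log-domain reference
    out_kernel = triton_sinkhorn(log_alpha, iters = 20)
    out_ref = log_domain_sinkhorn_knopps(log_alpha, iters = 20)
    print((out_kernel - out_ref).abs().max())  # expect a small float32 discrepancy

    # gradient comparison, weighting the output so the gradient is non-trivial
    # (the plain sum is nearly constant once rows are normalized to 1)
    w = torch.randn_like(log_alpha)
    x = log_alpha.clone().requires_grad_()
    y = log_alpha.clone().requires_grad_()
    (triton_sinkhorn(x, iters = 20) * w).sum().backward()
    (log_domain_sinkhorn_knopps(y, iters = 20) * w).sum().backward()
    print((x.grad - y.grad).abs().max())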
{hyper_connections-0.4.3.dist-info → hyper_connections-0.4.5.dist-info}/RECORD
CHANGED

@@ -3,11 +3,12 @@ hyper_connections/hyper_connections.py,sha256=2F-104cGE82KCK0KeC07NSOJNPT-0PCtvX
 hyper_connections/hyper_connections_channel_first.py,sha256=5vAen4WXxNI9K07ndLBQJwdJv-OjoXznta5EIQTaQNA,6512
 hyper_connections/hyper_connections_with_multi_branch_inputs.py,sha256=yn2AlFB6qCYQeRhhhaMlCM3mxxLEdWCYwU2p9TsMwWI,7835
 hyper_connections/hyper_connections_with_multi_input_streams.py,sha256=iFPw5pgCRHTo16nBJ2PExKSrvTyCh7ba7Py14P1oSPE,11311
-hyper_connections/mHCv2.py,sha256=
-hyper_connections/manifold_constrained_hyper_connections.py,sha256=
+hyper_connections/mHCv2.py,sha256=wCtp87OFI3QfosdSL-1qwsiQN9f8gX32_0r8GQGO7P0,17411
+hyper_connections/manifold_constrained_hyper_connections.py,sha256=E4os-6q_SMjJO1JD0EG8rFTCXA7MQoy-aqUlM7KVS5Q,18269
 hyper_connections/residuals.py,sha256=JVSFJj_H7xQ3_Fd-pZH5Hdv9SveAQu29jQNvMyom5ek,921
+hyper_connections/triton_sinkhorn.py,sha256=n2WyQcUemtv5T5Sk2nljnSpV2hEED4I3HaPsIUy4638,5905
 hyper_connections/vit.py,sha256=BOWVfCAIzDQdnTq8OBzNUyiKGGILYZkIQ6mr1GKJVB0,5225
-hyper_connections-0.4.
-hyper_connections-0.4.
-hyper_connections-0.4.
-hyper_connections-0.4.
+hyper_connections-0.4.5.dist-info/METADATA,sha256=sWVb_-yVRmxL8AsAPsk0VdRXOa25uG9zKNc8S_oAXg8,6704
+hyper_connections-0.4.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+hyper_connections-0.4.5.dist-info/licenses/LICENSE,sha256=E7RGS7kpJIStk5za_-4DVhWEAamf65EU0CNML25mq4c,1066
+hyper_connections-0.4.5.dist-info/RECORD,,

{hyper_connections-0.4.3.dist-info → hyper_connections-0.4.5.dist-info}/WHEEL
File without changes

{hyper_connections-0.4.3.dist-info → hyper_connections-0.4.5.dist-info}/licenses/LICENSE
File without changes