mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,424 @@
+# @nolint # fbcode
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+class IndexFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, indices):
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = other_shape.numel()
+        return torch.gather(
+            rearrange(input, "b ... -> b (...)"),
+            0,
+            repeat(indices, "z -> z d", d=second_dim),
+        ).reshape(-1, *other_shape)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        grad_output = rearrange(grad_output, "b ... -> b (...)")
+        grad_input = torch.zeros(
+            [ctx.first_axis_dim, grad_output.shape[1]],
+            device=grad_output.device,
+            dtype=grad_output.dtype,
+        )
+        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+index_first_axis = IndexFirstAxis.apply
+
+
+class IndexPutFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, values, indices, first_axis_dim):
+        ctx.save_for_backward(indices)
+        assert indices.ndim == 1
+        assert values.ndim >= 2
+        output = torch.zeros(
+            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
+        )
+        output[indices] = values
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        grad_values = grad_output[indices]
+        return grad_values, None, None
+
+
+index_put_first_axis = IndexPutFirstAxis.apply
+
+
+def unpad_input(hidden_states, attention_mask, unused_mask=None):
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
+
+
+def pad_input(hidden_states, indices, batch, seqlen):
+    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
+
+
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
+    assert mode in ["full", "random", "third"]
+    if mode == "full":
+        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
+    elif mode == "random":
+        lengths = torch.randint(
+            max(0 if zero_lengths else 1, max_seqlen - 20),
+            max_seqlen + 1,
+            (batch_size, 1),
+            device=device,
+        )
+    else:
+        lengths = torch.randint(
+            max(0 if zero_lengths else 1, max_seqlen // 3),
+            max_seqlen + 1,
+            (batch_size, 1),
+            device=device,
+        )
+
+    if zero_lengths:
+        for i in range(batch_size):
+            if i % 5 == 0:
+                lengths[i] = 0
+        lengths[-1] = 0
+    padding_mask = (
+        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
+    )
+    return padding_mask
+
+
+def generate_qkv(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    qv=None,
+    kvpacked=False,
+    qkvpacked=False,
+    query_unused_mask=None,
+    key_unused_mask=None,
+):
+    assert not (kvpacked and qkvpacked)
+    batch_size, seqlen_q, nheads, d = q.shape
+    d_v = v.shape[-1]
+    _, seqlen_k, nheads_k, _ = k.shape
+    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
+    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
+    if query_unused_mask is not None or key_unused_mask is not None:
+        assert not kvpacked
+        assert not qkvpacked
+
+    if query_padding_mask is not None:
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
+            q, query_padding_mask, query_unused_mask
+        )
+        output_pad_fn = lambda output_unpad: pad_input(
+            output_unpad, indices_q, batch_size, seqlen_q
+        )
+        qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
+    else:
+        q_unpad = rearrange(q, "b s h d -> (b s) h d")
+        cu_seqlens_q = torch.arange(
+            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
+        )
+        seqused_q = None
+        max_seqlen_q = seqlen_q
+        output_pad_fn = lambda output_unpad: rearrange(
+            output_unpad, "(b s) h d -> b s h d", b=batch_size
+        )
+        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None
+
+    if key_padding_mask is not None:
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
+            k, key_padding_mask, key_unused_mask
+        )
+        v_unpad, *_ = unpad_input(v, key_padding_mask, key_unused_mask)
+    else:
+        k_unpad = rearrange(k, "b s h d -> (b s) h d")
+        v_unpad = rearrange(v, "b s h d -> (b s) h d")
+        cu_seqlens_k = torch.arange(
+            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
+        )
+        seqused_k = None
+        max_seqlen_k = seqlen_k
+
+    if qkvpacked:
+        assert (query_padding_mask == key_padding_mask).all()
+        assert nheads == nheads_k
+        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
+        qkv = torch.stack([q, k, v], dim=2)
+        if query_padding_mask is not None:
+            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
+        else:
+            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
+                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            qkv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            max_seqlen_q,
+            qkv.detach().requires_grad_(),
+            output_pad_fn,
+            dqkv_pad_fn,
+        )
+    elif kvpacked:
+        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
+        kv = torch.stack([k, v], dim=2)
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            dkv_pad_fn = lambda dkv_unpad: rearrange(
+                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            q_unpad.detach().requires_grad_(),
+            kv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            kv.detach().requires_grad_(),
+            output_pad_fn,
+            dq_pad_fn,
+            dkv_pad_fn,
+        )
+    else:
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
+        return (
+            q_unpad.detach().requires_grad_(),
+            k_unpad.detach().requires_grad_(),
+            v_unpad.detach().requires_grad_(),
+            qv_unpad.detach() if qv is not None else None,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            seqused_q,
+            seqused_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            k.detach().requires_grad_(),
+            v.detach().requires_grad_(),
+            qv.detach() if qv is not None else None,
+            output_pad_fn,
+            dq_pad_fn,
+            dk_pad_fn,
+        )
+
+
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(None, None),
+    sink_token_length=0,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    key_leftpad=None,
+    device=None,
+):
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    if window_size[0] is None:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+        if window_size[1] is None:
+            local_mask_left = col_idx > sk
+        else:
+            local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk)
+        return torch.logical_or(
+            local_mask_left,
+            torch.logical_and(
+                col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length
+            ),
+        )
+
+
+def construct_chunk_mask(
+    seqlen_q,
+    seqlen_k,
+    attention_chunk,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    key_leftpad=None,
+    device=None,
+):
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk
+    return torch.logical_or(
+        col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk
+    )
+
+
+def attention_ref(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    key_leftpad=None,
+    attn_bias=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    qv=None,
+    q_descale=None,
+    k_descale=None,
+    v_descale=None,
+    window_size=(None, None),
+    attention_chunk=0,
+    sink_token_length=0,
+    learnable_sink: Optional[torch.Tensor] = None,
+    softcap=0.0,
+    upcast=True,
+    reorder_ops=False,
+    intermediate_dtype=None,
+):
+    if causal:
+        window_size = (window_size[0], 0)
+    dtype_og = q.dtype
+    if upcast:
+        q, k, v = q.float(), k.float(), v.float()
+        qv = qv.float() if qv is not None else None
+    if q_descale is not None:
+        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
+        q = (q.float() * q_descale).to(q.dtype)
+        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
+    if k_descale is not None:
+        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
+    if v_descale is not None:
+        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
+    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    d = q.shape[-1]
+    dv = v.shape[-1]
+    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
+    if not reorder_ops:
+        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
+    else:
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+    if qv is not None:
+        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
+    if softcap > 0:
+        scores = torch.tanh(scores / softcap) * softcap
+    if key_padding_mask is not None:
+        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+    local_mask = None
+    if window_size[0] is not None or window_size[1] is not None:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            sink_token_length,
+            query_padding_mask,
+            key_padding_mask,
+            key_leftpad=key_leftpad,
+            device=q.device,
+        )
+    if attention_chunk > 0:
+        chunk_mask = construct_chunk_mask(
+            seqlen_q,
+            seqlen_k,
+            attention_chunk,
+            query_padding_mask,
+            key_padding_mask,
+            key_leftpad=key_leftpad,
+            device=q.device,
+        )
+        local_mask = (
+            torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
+        )
+    if local_mask is not None:
+        scores.masked_fill_(local_mask, float("-inf"))
+    if attn_bias is not None:
+        scores = scores + attn_bias
+    if learnable_sink is None:
+        attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    else:
+        scores_fp32 = scores.to(torch.float32)
+        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
+        learnable_sink = rearrange(learnable_sink, "h -> h 1 1")
+        logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)
+        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
+        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
+            learnable_sink - logits_or_sinks_max
+        )
+        attention = (unnormalized_scores / normalizer).to(v.dtype)
+    if query_padding_mask is not None:
+        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+    if key_padding_mask is not None:
+        attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
+    if local_mask is not None:
+        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
+    dropout_scaling = 1.0 / (1 - dropout_p)
+    if dropout_mask is not None:
+        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
+    else:
+        attention_drop = attention
+    if intermediate_dtype is not None:
+        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
+    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
+    if query_padding_mask is not None:
+        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
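
The module shown above is a pure PyTorch/einops reference layer: attention_ref computes padded multi-head attention with optional causal, local-window, or chunked masking plus softcap, descales, and learnable sinks, while generate_qkv, unpad_input, and pad_input convert between padded (b, s, h, d) tensors and the variable-length layout (total_tokens, h, d) with cu_seqlens metadata. A minimal usage sketch follows; it is not part of the wheel, it assumes these helpers are importable as mslk.attention.flash_attn.testing (the +424-line testing module in the listing above), and the shapes are purely illustrative.

# Minimal sketch (assumption: the diffed module ships as mslk.attention.flash_attn.testing).
import torch

from mslk.attention.flash_attn.testing import (
    attention_ref,
    generate_qkv,
    generate_random_padding_mask,
)

batch, seqlen, nheads, d = 2, 128, 4, 64
device, dtype = "cpu", torch.float16  # the reference path is pure PyTorch, so CPU suffices here
q, k, v = (torch.randn(batch, seqlen, nheads, d, device=device, dtype=dtype) for _ in range(3))
query_padding_mask = generate_random_padding_mask(seqlen, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen, batch, device, mode="random")

# Padded reference output: computed in fp32 internally, causal masking applied,
# padded query rows zeroed, result cast back to the input dtype.
out_ref, _ = attention_ref(
    q, k, v,
    query_padding_mask=query_padding_mask,
    key_padding_mask=key_padding_mask,
    causal=True,
)

# Variable-length view of the same problem: unpadded tensors plus cu_seqlens /
# seqused / max_seqlen metadata, as a varlen flash-attention kernel would consume them.
(
    q_unpad, k_unpad, v_unpad, _qv_unpad,
    cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k,
    max_seqlen_q, max_seqlen_k,
    q_ref, k_ref, v_ref, _qv,
    output_pad_fn, dq_pad_fn, dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask)

# A kernel's unpadded output of shape (total_q, nheads, d) would then be re-padded with
# output_pad_fn(out_unpad) and compared against out_ref, e.g. via torch.testing.assert_close.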