mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/testing.py
@@ -0,0 +1,424 @@
+ # @nolint # fbcode
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+
+
+ class IndexFirstAxis(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, input, indices):
+         ctx.save_for_backward(indices)
+         assert input.ndim >= 2
+         ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+         second_dim = other_shape.numel()
+         return torch.gather(
+             rearrange(input, "b ... -> b (...)"),
+             0,
+             repeat(indices, "z -> z d", d=second_dim),
+         ).reshape(-1, *other_shape)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         (indices,) = ctx.saved_tensors
+         assert grad_output.ndim >= 2
+         other_shape = grad_output.shape[1:]
+         grad_output = rearrange(grad_output, "b ... -> b (...)")
+         grad_input = torch.zeros(
+             [ctx.first_axis_dim, grad_output.shape[1]],
+             device=grad_output.device,
+             dtype=grad_output.dtype,
+         )
+         grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
+         return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+ index_first_axis = IndexFirstAxis.apply
+
+
+ class IndexPutFirstAxis(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, values, indices, first_axis_dim):
+         ctx.save_for_backward(indices)
+         assert indices.ndim == 1
+         assert values.ndim >= 2
+         output = torch.zeros(
+             first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
+         )
+         output[indices] = values
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         (indices,) = ctx.saved_tensors
+         grad_values = grad_output[indices]
+         return grad_values, None, None
+
+
+ index_put_first_axis = IndexPutFirstAxis.apply
+
+
+ def unpad_input(hidden_states, attention_mask, unused_mask=None):
+     all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+     seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+     used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+     indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+     max_seqlen_in_batch = seqlens_in_batch.max().item()
+     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+     return (
+         index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+         indices,
+         cu_seqlens,
+         max_seqlen_in_batch,
+         used_seqlens_in_batch,
+     )
+
+
+ def pad_input(hidden_states, indices, batch, seqlen):
+     output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+     return rearrange(output, "(b s) ... -> b s ...", b=batch)
+
+
+ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
+     assert mode in ["full", "random", "third"]
+     if mode == "full":
+         lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
+     elif mode == "random":
+         lengths = torch.randint(
+             max(0 if zero_lengths else 1, max_seqlen - 20),
+             max_seqlen + 1,
+             (batch_size, 1),
+             device=device,
+         )
+     else:
+         lengths = torch.randint(
+             max(0 if zero_lengths else 1, max_seqlen // 3),
+             max_seqlen + 1,
+             (batch_size, 1),
+             device=device,
+         )
+
+     if zero_lengths:
+         for i in range(batch_size):
+             if i % 5 == 0:
+                 lengths[i] = 0
+         lengths[-1] = 0
+     padding_mask = (
+         repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
+     )
+     return padding_mask
+
+
+ def generate_qkv(
+     q,
+     k,
+     v,
+     query_padding_mask=None,
+     key_padding_mask=None,
+     qv=None,
+     kvpacked=False,
+     qkvpacked=False,
+     query_unused_mask=None,
+     key_unused_mask=None,
+ ):
+     assert not (kvpacked and qkvpacked)
+     batch_size, seqlen_q, nheads, d = q.shape
+     d_v = v.shape[-1]
+     _, seqlen_k, nheads_k, _ = k.shape
+     assert k.shape == (batch_size, seqlen_k, nheads_k, d)
+     assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
+     if query_unused_mask is not None or key_unused_mask is not None:
+         assert not kvpacked
+         assert not qkvpacked
+
+     if query_padding_mask is not None:
+         q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
+             q, query_padding_mask, query_unused_mask
+         )
+         output_pad_fn = lambda output_unpad: pad_input(
+             output_unpad, indices_q, batch_size, seqlen_q
+         )
+         qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
+     else:
+         q_unpad = rearrange(q, "b s h d -> (b s) h d")
+         cu_seqlens_q = torch.arange(
+             0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
+         )
+         seqused_q = None
+         max_seqlen_q = seqlen_q
+         output_pad_fn = lambda output_unpad: rearrange(
+             output_unpad, "(b s) h d -> b s h d", b=batch_size
+         )
+         qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None
+
+     if key_padding_mask is not None:
+         k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
+             k, key_padding_mask, key_unused_mask
+         )
+         v_unpad, *_ = unpad_input(v, key_padding_mask, key_unused_mask)
+     else:
+         k_unpad = rearrange(k, "b s h d -> (b s) h d")
+         v_unpad = rearrange(v, "b s h d -> (b s) h d")
+         cu_seqlens_k = torch.arange(
+             0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
+         )
+         seqused_k = None
+         max_seqlen_k = seqlen_k
+
+     if qkvpacked:
+         assert (query_padding_mask == key_padding_mask).all()
+         assert nheads == nheads_k
+         qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
+         qkv = torch.stack([q, k, v], dim=2)
+         if query_padding_mask is not None:
+             dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
+         else:
+             dqkv_pad_fn = lambda dqkv_unpad: rearrange(
+                 dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+             )
+         return (
+             qkv_unpad.detach().requires_grad_(),
+             cu_seqlens_q,
+             max_seqlen_q,
+             qkv.detach().requires_grad_(),
+             output_pad_fn,
+             dqkv_pad_fn,
+         )
+     elif kvpacked:
+         kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
+         kv = torch.stack([k, v], dim=2)
+         dq_pad_fn = output_pad_fn
+         if key_padding_mask is not None:
+             dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
+         else:
+             dkv_pad_fn = lambda dkv_unpad: rearrange(
+                 dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+             )
+         return (
+             q_unpad.detach().requires_grad_(),
+             kv_unpad.detach().requires_grad_(),
+             cu_seqlens_q,
+             cu_seqlens_k,
+             max_seqlen_q,
+             max_seqlen_k,
+             q.detach().requires_grad_(),
+             kv.detach().requires_grad_(),
+             output_pad_fn,
+             dq_pad_fn,
+             dkv_pad_fn,
+         )
+     else:
+         dq_pad_fn = output_pad_fn
+         if key_padding_mask is not None:
+             dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
+         else:
+             dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
+         return (
+             q_unpad.detach().requires_grad_(),
+             k_unpad.detach().requires_grad_(),
+             v_unpad.detach().requires_grad_(),
+             qv_unpad.detach() if qv is not None else None,
+             cu_seqlens_q,
+             cu_seqlens_k,
+             seqused_q,
+             seqused_k,
+             max_seqlen_q,
+             max_seqlen_k,
+             q.detach().requires_grad_(),
+             k.detach().requires_grad_(),
+             v.detach().requires_grad_(),
+             qv.detach() if qv is not None else None,
+             output_pad_fn,
+             dq_pad_fn,
+             dk_pad_fn,
+         )
+
+
+ def construct_local_mask(
+     seqlen_q,
+     seqlen_k,
+     window_size=(None, None),
+     sink_token_length=0,
+     query_padding_mask=None,
+     key_padding_mask=None,
+     key_leftpad=None,
+     device=None,
+ ):
+     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+     if key_leftpad is not None:
+         key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+         col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+         col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+     sk = (
+         seqlen_k
+         if key_padding_mask is None
+         else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+     )
+     sq = (
+         seqlen_q
+         if query_padding_mask is None
+         else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+     )
+     if window_size[0] is None:
+         return col_idx > row_idx + sk - sq + window_size[1]
+     else:
+         sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+         if window_size[1] is None:
+             local_mask_left = col_idx > sk
+         else:
+             local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk)
+         return torch.logical_or(
+             local_mask_left,
+             torch.logical_and(
+                 col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length
+             ),
+         )
+
+
+ def construct_chunk_mask(
+     seqlen_q,
+     seqlen_k,
+     attention_chunk,
+     query_padding_mask=None,
+     key_padding_mask=None,
+     key_leftpad=None,
+     device=None,
+ ):
+     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+     if key_leftpad is not None:
+         key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+         col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+         col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+     sk = (
+         seqlen_k
+         if key_padding_mask is None
+         else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+     )
+     sq = (
+         seqlen_q
+         if query_padding_mask is None
+         else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+     )
+     sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+     col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk
+     return torch.logical_or(
+         col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk
+     )
+
+
+ def attention_ref(
+     q,
+     k,
+     v,
+     query_padding_mask=None,
+     key_padding_mask=None,
+     key_leftpad=None,
+     attn_bias=None,
+     dropout_p=0.0,
+     dropout_mask=None,
+     causal=False,
+     qv=None,
+     q_descale=None,
+     k_descale=None,
+     v_descale=None,
+     window_size=(None, None),
+     attention_chunk=0,
+     sink_token_length=0,
+     learnable_sink: Optional[torch.Tensor] = None,
+     softcap=0.0,
+     upcast=True,
+     reorder_ops=False,
+     intermediate_dtype=None,
+ ):
+     if causal:
+         window_size = (window_size[0], 0)
+     dtype_og = q.dtype
+     if upcast:
+         q, k, v = q.float(), k.float(), v.float()
+         qv = qv.float() if qv is not None else None
+     if q_descale is not None:
+         q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
+         q = (q.float() * q_descale).to(q.dtype)
+         qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
+     if k_descale is not None:
+         k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
+     if v_descale is not None:
+         v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
+     seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+     k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
+     v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+     d = q.shape[-1]
+     dv = v.shape[-1]
+     softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
+     if not reorder_ops:
+         scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
+     else:
+         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+     if qv is not None:
+         scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
+     if softcap > 0:
+         scores = torch.tanh(scores / softcap) * softcap
+     if key_padding_mask is not None:
+         scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+     local_mask = None
+     if window_size[0] is not None or window_size[1] is not None:
+         local_mask = construct_local_mask(
+             seqlen_q,
+             seqlen_k,
+             window_size,
+             sink_token_length,
+             query_padding_mask,
+             key_padding_mask,
+             key_leftpad=key_leftpad,
+             device=q.device,
+         )
+     if attention_chunk > 0:
+         chunk_mask = construct_chunk_mask(
+             seqlen_q,
+             seqlen_k,
+             attention_chunk,
+             query_padding_mask,
+             key_padding_mask,
+             key_leftpad=key_leftpad,
+             device=q.device,
+         )
+         local_mask = (
+             torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
+         )
+     if local_mask is not None:
+         scores.masked_fill_(local_mask, float("-inf"))
+     if attn_bias is not None:
+         scores = scores + attn_bias
+     if learnable_sink is None:
+         attention = torch.softmax(scores, dim=-1).to(v.dtype)
+     else:
+         scores_fp32 = scores.to(torch.float32)
+         logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
+         learnable_sink = rearrange(learnable_sink, "h -> h 1 1")
+         logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)
+         unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
+         normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
+             learnable_sink - logits_or_sinks_max
+         )
+         attention = (unnormalized_scores / normalizer).to(v.dtype)
+     if query_padding_mask is not None:
+         attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+     if key_padding_mask is not None:
+         attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
+     if local_mask is not None:
+         attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
+     dropout_scaling = 1.0 / (1 - dropout_p)
+     if dropout_mask is not None:
+         attention_drop = attention.masked_fill(~dropout_mask, 0.0)
+     else:
+         attention_drop = attention
+     if intermediate_dtype is not None:
+         attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
+     output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
+     if query_padding_mask is not None:
+         output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
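
For orientation, the hunk above consists of pure-PyTorch reference and padding helpers of the kind typically used to validate fused attention kernels. Below is a minimal, hypothetical usage sketch (not part of the wheel) that cross-checks attention_ref against torch.nn.functional.scaled_dot_product_attention; the import path is assumed from entry 36 of the file listing, and the shapes and tolerances are illustrative only.

import torch
import torch.nn.functional as F

from mslk.attention.flash_attn.testing import attention_ref  # assumed module path (file listing entry 36)

# Illustrative shapes; attention_ref expects (batch, seqlen, heads, headdim) tensors.
batch, seqlen, nheads, headdim = 2, 128, 4, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = (torch.randn(batch, seqlen, nheads, headdim, device=device) for _ in range(3))

# Reference causal attention (upcast to fp32 by default); returns (output, attention_probs).
out_ref, _ = attention_ref(q, k, v, causal=True)

# PyTorch's SDPA uses (batch, heads, seqlen, headdim), so transpose in and out.
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

torch.testing.assert_close(out_ref, out_sdpa, atol=2e-3, rtol=2e-3)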