fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py
@@ -0,0 +1,751 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from .common import expect_contiguous
+
+
+ @triton.jit
+ def _multi_head_jagged_flash_attention_fwd_kernel(
+     q_ptr,
+     k_ptr,
+     v_ptr,
+     offset_ptr,
+     o_ptr,
+     lse_i_ptr,
+     stride_qh,
+     stride_qm,
+     stride_qd,
+     stride_kh,
+     stride_kn,
+     stride_kd,
+     stride_vh,
+     stride_vn,
+     stride_vd,
+     stride_oh,
+     stride_om,
+     stride_od,
+     stride_lse_h,
+     num_heads: tl.constexpr,
+     max_seq_len: tl.constexpr,
+     D: tl.constexpr,
+     allow_tf32: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+ ):
+     pid_m = tl.program_id(axis=0)
+     pid_bh = tl.program_id(axis=1)
+     pid_batch = pid_bh // num_heads
+     pid_head = pid_bh % num_heads
+
+     begin = tl.load(offset_ptr + pid_batch)
+     end = tl.load(offset_ptr + pid_batch + 1)
+
+     seqlen = end - begin
+     seqlen = tl.minimum(seqlen, max_seq_len)
+
+     if pid_m * BLOCK_M >= seqlen:
+         return
+
+     offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_d = tl.arange(0, BLOCK_D)
+
+     acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
+
+     mi = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+     li = tl.zeros([BLOCK_M], dtype=tl.float32)
+     for j in range(0, seqlen, BLOCK_N):
+         offs_n = tl.arange(0, BLOCK_N) + j
+         q_ptrs = (
+             q_ptr
+             + pid_head * stride_qh
+             + begin * stride_qm
+             + (offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd)
+         )
+
+         k_ptrs = (
+             k_ptr
+             + pid_head * stride_kh
+             + begin * stride_kn
+             + (offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd)
+         )
+
+         qk = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+         for d in range(0, D, BLOCK_D):
+             curr_d = d + offs_d
+
+             # Load a block of q into [BLOCK_M, BLOCK_D]
+             q = tl.load(
+                 q_ptrs,
+                 # pyre-fixme[16]: `int` has no attribute `__getitem__`.
+                 mask=((curr_d[None, :] < D) & (offs_m[:, None] < seqlen)),
+                 other=0.0,
+             )
+
+             # Load a block of k into [BLOCK_D, BLOCK_N]
+             k = tl.load(
+                 k_ptrs,
+                 mask=((curr_d[:, None] < D) & (offs_n[None, :] < seqlen)),
+                 other=0.0,
+             )
+
+             # gemm [BLOCK_M, BLOCK_D] x [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
+             qk += tl.dot(q, k, allow_tf32=allow_tf32)
+
+             q_ptrs += BLOCK_D * stride_qd
+             k_ptrs += BLOCK_D * stride_kd
+
+         mi_new = tl.maximum(tl.max(qk, axis=1), mi)
+         # Add the correct mask here
+         mn_mask = (offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)
+
+         p = tl.exp(qk - mi_new[:, None])
+         p = tl.where(mn_mask, p, 0.0)
+
+         lij_hat = tl.sum(p, axis=1)
+         alpha = tl.exp(mi - mi_new)
+
+         li = alpha * li + lij_hat
+         acc = alpha[:, None] * acc
+
+         # Load V into block [BLOCK_N, BLOCK_D]
+         v_ptrs = (
+             v_ptr
+             + pid_head * stride_vh
+             + begin * stride_vn
+             + (offs_d[None, :] * stride_vd + offs_n[:, None] * stride_vn)
+         )
+         v = tl.load(
+             v_ptrs,
+             mask=((offs_d[None, :] < D) & (offs_n[:, None] < seqlen)),
+             other=0.0,
+         )
+
+         p /= max_seq_len
+
+         p = p.to(v_ptr.dtype.element_ty)
+         # gemm [BLOCK_M, BLOCK_N] x [BLOCK_N, BLOCK_D] -> [BLOCK_M, BLOCK_D]
+         acc += tl.dot(p, v, allow_tf32=allow_tf32)
+         mi = mi_new
+
+     lse_i = mi + tl.math.log(li)
+     lse_i_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     lse_i_ptrs = lse_i_ptr + pid_head * stride_lse_h + begin + lse_i_offsets
+
+     tl.store(lse_i_ptrs, lse_i, mask=lse_i_offsets < seqlen)
+
+     acc = acc / li[:, None]
+
+     # Store O
+     o_ptrs = o_ptr + (
+         pid_head * stride_oh
+         + begin * stride_om
+         + offs_m[:, None] * stride_om
+         + offs_d[None, :] * stride_od
+     )
+     o_mask = (offs_m[:, None] < seqlen) & (offs_d[None, :] < D)
+     tl.store(o_ptrs, acc, mask=o_mask)
+
+
+ def multi_head_jagged_flash_attention_fwd(
+     jagged_Q,
+     jagged_K,
+     jagged_V,
+     offsets,
+     max_seq_len,
+     allow_tf32=False,
+ ):
+     assert jagged_Q.size(2) == jagged_K.size(2), "incompatible dimensions"
+
+     B = offsets.size(0) - 1
+     D = jagged_Q.size(2)
+     num_heads = jagged_Q.size(0)
+
+     jagged_O = torch.zeros_like(jagged_Q)
+     lse = torch.zeros(
+         (num_heads, jagged_Q.size(1)), device=jagged_Q.device, dtype=jagged_Q.dtype
+     )
+
+     BLOCK_M = 32
+     BLOCK_N = 32
+     BLOCK_D = max(triton.next_power_of_2(D), 16)
+
+     grid = (triton.cdiv(max_seq_len, BLOCK_M), B * num_heads)
+
+     _multi_head_jagged_flash_attention_fwd_kernel[grid](
+         jagged_Q,
+         jagged_K,
+         jagged_V,
+         offsets,
+         jagged_O,
+         lse,
+         jagged_Q.stride(0),
+         jagged_Q.stride(1),
+         jagged_Q.stride(2),
+         jagged_K.stride(0),
+         jagged_K.stride(1),
+         jagged_K.stride(2),
+         jagged_V.stride(0),
+         jagged_V.stride(1),
+         jagged_V.stride(2),
+         jagged_O.stride(0),
+         jagged_O.stride(1),
+         jagged_O.stride(2),
+         lse.stride(0),
+         num_heads,
+         max_seq_len,
+         D,
+         allow_tf32,
+         BLOCK_M=BLOCK_M,
+         BLOCK_N=BLOCK_N,
+         BLOCK_D=BLOCK_D,
+     )
+
+     return jagged_O, lse
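
Editorial note: the wrapper above launches one Triton program per (query-row tile, sequence x head) grid cell and accumulates the softmax online, block by block. As a reading aid, here is a minimal dense sketch of what the forward pass appears to compute per head and per sequence, up to floating-point accumulation differences: no 1/sqrt(D) softmax scale is applied, and an extra 1/max_seq_len factor is folded into the output (the backward kernels below apply the same factor when recomputing p). The name dense_reference_fwd and this loop are illustrative only, not part of the wheel.

import torch

def dense_reference_fwd(q, k, v, offsets, max_seq_len):
    # q, k, v: [H, sum_B, D]; offsets: [B + 1]; returns [H, sum_B, D]
    H = q.size(0)
    out = torch.zeros_like(q)
    for b in range(offsets.numel() - 1):
        begin = int(offsets[b])
        end = min(int(offsets[b + 1]), begin + max_seq_len)  # rows past max_seq_len stay zero
        for h in range(H):
            scores = q[h, begin:end] @ k[h, begin:end].T      # [L, L], no 1/sqrt(D) scale
            p = torch.softmax(scores, dim=-1) / max_seq_len   # extra 1/max_seq_len factor
            out[h, begin:end] = p @ v[h, begin:end]           # lse would be logsumexp(scores, -1)
    return out
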
+
+
+ @triton.jit
+ def _multi_head_jagged_flash_attention_bwd_preprocess_kernel(
+     o_ptr,
+     o_offset_ptr,
+     do_ptr,
+     delta_ptr,
+     stride_oh,
+     stride_om,
+     stride_od,
+     stride_delta_h,
+     num_heads: tl.constexpr,
+     max_seq_len: tl.constexpr,
+     D: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+ ):
+     pid_m = tl.program_id(axis=0)
+     pid_bh = tl.program_id(axis=1)
+     pid_batch = pid_bh // num_heads
+     pid_head = pid_bh % num_heads
+
+     begin_o = tl.load(o_offset_ptr + pid_batch)
+     end_o = tl.load(o_offset_ptr + pid_batch + 1)
+
+     M = end_o - begin_o
+     M = tl.minimum(M, max_seq_len)
+
+     if M == 0:
+         return
+
+     offs_om = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_od = tl.arange(0, BLOCK_D)
+
+     o_offsets = (
+         offs_om[:, None] * stride_om
+         + offs_od[None, :] * stride_od
+         + pid_head * stride_oh
+         + begin_o * stride_om
+     )
+     o_ptrs = o_ptr + o_offsets
+     do_ptrs = do_ptr + o_offsets
+     o_mask = (offs_om[:, None] < M) & (offs_od[None, :] < D)
+
+     # Load o and do
+     o = tl.load(o_ptrs, mask=o_mask)
+     do = tl.load(do_ptrs, mask=o_mask)
+
+     delta = tl.sum(o * do, axis=1)
+
+     tl.store(
+         delta_ptr + pid_head * stride_delta_h + begin_o + offs_om,
+         delta,
+         mask=offs_om < M,
+     )
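
Editorial note: the preprocess kernel above computes the usual flash-attention backward helper: for every valid row of every head, delta is the inner product of the output row with the incoming gradient row. A dense equivalent of what it writes is shown below (illustrative only; rows beyond max_seq_len in a sequence are simply never written).

# jagged_O and dO are [H, sum_B, D] tensors, as in multi_head_jagged_flash_attention_bwd below
# delta[h, i] = sum_d jagged_O[h, i, d] * dO[h, i, d]
delta_dense = (jagged_O * dO).sum(dim=-1)  # shape [H, sum_B], same layout as lse
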
+
+
+ @triton.jit
+ def _multi_head_jagged_flash_attention_bwd_kernel(
+     q_ptr,
+     k_ptr,
+     v_ptr,
+     o_ptr,
+     offset_ptr,
+     dq_ptr,
+     dk_ptr,
+     dv_ptr,
+     do_ptr,
+     delta_ptr,
+     lse_ptr,
+     stride_qh,
+     stride_qm,
+     stride_qd,
+     stride_kh,
+     stride_kn,
+     stride_kd,
+     stride_vh,
+     stride_vn,
+     stride_vd,
+     stride_oh,
+     stride_om,
+     stride_od,
+     stride_lse_h,
+     stride_delta_h,
+     stride_dq_h,
+     stride_dq_m,
+     stride_dq_d,
+     stride_dk_h,
+     stride_dk_n,
+     stride_dk_d,
+     stride_dv_h,
+     stride_dv_n,
+     stride_dv_d,
+     stride_do_h,
+     stride_do_m,
+     stride_do_d,
+     num_heads: tl.constexpr,
+     max_seq_len: tl.constexpr,
+     D: tl.constexpr,
+     allow_tf32: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+ ):
+     pid_bh = tl.program_id(axis=1)
+     pid_batch = pid_bh // num_heads
+     pid_head = pid_bh % num_heads
+
+     begin = tl.load(offset_ptr + pid_batch)
+     end = tl.load(offset_ptr + pid_batch + 1)
+
+     seqlen = tl.minimum(end - begin, max_seq_len)
+
+     if seqlen == 0:
+         return
+
+     pid_n = tl.program_id(axis=0)
+     offs_d = tl.arange(0, BLOCK_D)
+
+     offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+     offs_m = tl.arange(0, BLOCK_M)
+
+     q_ptrs = (
+         q_ptr
+         + pid_head * stride_qh
+         + begin * stride_qm
+         + (offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd)
+     )
+
+     k_ptrs = (
+         k_ptr
+         + pid_head * stride_kh
+         + begin * stride_kn
+         + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
+     )
+
+     v_ptrs = (
+         v_ptr
+         + pid_head * stride_vh
+         + begin * stride_vn
+         + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)
+     )
+
+     do_ptrs = (
+         do_ptr
+         + pid_head * stride_do_h
+         + begin * stride_do_m
+         + (offs_m[:, None] * stride_do_m + offs_d[None, :] * stride_do_d)
+     )
+
+     # Load a block of K into [BLOCK_N, BLOCK_D]
+     k = tl.load(
+         k_ptrs, mask=((offs_d[None, :] < D) & (offs_n[:, None] < seqlen)), other=0.0
+     )
+     # Load a block of V into [BLOCK_N, BLOCK_D]
+     v = tl.load(
+         v_ptrs, mask=((offs_d[None, :] < D) & (offs_n[:, None] < seqlen)), other=0.0
+     )
+
+     # Initialize dv and dk
+     dv = tl.zeros([BLOCK_N, BLOCK_D], dtype=tl.float32)
+     dk = tl.zeros([BLOCK_N, BLOCK_D], dtype=tl.float32)
+
+     for begin_m in range(0, seqlen, BLOCK_M):
+         offs_m_curr = begin_m + offs_m
+
+         # Load a block of Q into [BLOCK_M, BLOCK_D]
+         q = tl.load(
+             q_ptrs,
+             # pyre-fixme[16]: `int` has no attribute `__getitem__`.
+             mask=((offs_d[None, :] < D) & (offs_m_curr[:, None] < seqlen)),
+             other=0.0,
+         )
+         # gemm [BLOCK_M, BLOCK_D] x [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
+         qk = tl.dot(q, tl.trans(k), allow_tf32=allow_tf32)
+
+         mn_mask = (offs_m_curr[:, None] < seqlen) & (offs_n[None, :] < seqlen)
+
+         # Load a block of lse_i into [BLOCK_M]
+         lse_i = tl.load(
+             lse_ptr + pid_head * stride_lse_h + begin + offs_m_curr,
+             mask=offs_m_curr < seqlen,
+             other=float("inf"),
+         )
+
+         p = tl.exp(qk - lse_i[:, None])
+         p = tl.where(mn_mask, p, 0.0)
+         p /= max_seq_len
+
+         p = p.to(do_ptr.dtype.element_ty)
+         do = tl.load(
+             do_ptrs,
+             mask=((offs_d[None, :] < D) & (offs_m_curr[:, None] < seqlen)),
+             other=0.0,
+         )
+
+         # gemm [BLOCK_N, BLOCK_M] x [BLOCK_M, BLOCK_D] -> [BLOCK_N, BLOCK_D]
+         dv += tl.dot(tl.trans(p), do, allow_tf32=allow_tf32)
+         # gemm [BLOCK_M, BLOCK_D] x [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
+         dp = tl.dot(do, tl.trans(v), allow_tf32=allow_tf32)
+
+         # compute ds = p * (dp - delta[:, None])
+         Di = tl.load(
+             delta_ptr + pid_head * stride_delta_h + begin + offs_m_curr,
+             mask=offs_m_curr < seqlen,
+         )
+         ds = p * (dp - Di[:, None] * max_seq_len)
+
+         # compute dk = dot(ds.T, q)
+         ds = ds.to(q_ptr.dtype.element_ty)
+         # gemm [BLOCK_N, BLOCK_M] x [BLOCK_M, BLOCK_D] -> [BLOCK_N, BLOCK_D]
+         dk += tl.dot(tl.trans(ds), q, allow_tf32=allow_tf32)
+
+         q_ptrs += BLOCK_M * stride_qm
+         do_ptrs += BLOCK_M * stride_do_m
+
+     # store back dk and dv
+     dk_ptrs = (
+         dk_ptr
+         + pid_head * stride_dk_h
+         + begin * stride_dk_n
+         + (offs_n[:, None] * stride_dk_n + offs_d[None, :] * stride_dk_d)
+     )
+
+     dv_ptrs = (
+         dv_ptr
+         + pid_head * stride_dv_h
+         + begin * stride_dv_n
+         + (offs_n[:, None] * stride_dv_n + offs_d[None, :] * stride_dv_d)
+     )
+
+     tl.store(dk_ptrs, dk, mask=((offs_d[None, :] < D) & (offs_n[:, None] < seqlen)))
+     tl.store(dv_ptrs, dv, mask=((offs_d[None, :] < D) & (offs_n[:, None] < seqlen)))
+
+     # Start to compute dq
+
+     start_m = tl.program_id(axis=0) * BLOCK_M
+     offs_m_curr = start_m + tl.arange(0, BLOCK_M)
+
+     dq_ptrs_curr = (
+         dq_ptr
+         + pid_head * stride_dq_h
+         + begin * stride_dq_m
+         + (offs_m_curr[:, None] * stride_dq_m + offs_d[None, :] * stride_dq_d)
+     )
+
+     dq_curr = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
+
+     q_ptrs_curr = (
+         q_ptr
+         + pid_head * stride_qh
+         + begin * stride_qm
+         + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
+     )
+
+     q_curr = tl.load(
+         q_ptrs_curr, mask=((offs_d[None, :] < D) & (offs_m_curr[:, None] < seqlen))
+     )
+
+     # Load a block of lse_i into [BLOCK_M]
+     lse_i_curr = tl.load(
+         lse_ptr + pid_head * stride_lse_h + begin + offs_m_curr,
+         mask=offs_m_curr < seqlen,
+     )
+
+     do_ptrs_curr = (
+         do_ptr
+         + pid_head * stride_do_h
+         + begin * stride_do_m
+         + (offs_m_curr[:, None] * stride_do_m + offs_d[None, :] * stride_do_d)
+     )
+
+     # Load do
+     do_curr = tl.load(
+         do_ptrs_curr, mask=((offs_d[None, :] < D) & (offs_m_curr[:, None] < seqlen))
+     )
+     Di_curr = tl.load(
+         delta_ptr + pid_head * stride_delta_h + begin + offs_m_curr,
+         mask=offs_m_curr < seqlen,
+     )
+
+     block_start = 0
+     while block_start < seqlen:
+         offs_n_curr = block_start + tl.arange(0, BLOCK_N)
+
+         k_ptrs_curr = (
+             k_ptr
+             + pid_head * stride_kh
+             + begin * stride_kn
+             + (offs_n_curr[:, None] * stride_kn + offs_d[None, :] * stride_kd)
+         )
+         v_ptrs_curr = (
+             v_ptr
+             + pid_head * stride_vh
+             + begin * stride_vn
+             + (offs_n_curr[:, None] * stride_vn + offs_d[None, :] * stride_vd)
+         )
+
+         k_curr = tl.load(
+             k_ptrs_curr, mask=((offs_d[None, :] < D) & (offs_n_curr[:, None] < seqlen))
+         )
+         v_curr = tl.load(
+             v_ptrs_curr, mask=((offs_d[None, :] < D) & (offs_n_curr[:, None] < seqlen))
+         )
+
+         # gemm [BLOCK_M, BLOCK_D] x [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
+         qk_curr = tl.dot(q_curr, tl.trans(k_curr), allow_tf32=allow_tf32)
+         mn_mask_curr = (offs_m_curr[:, None] < seqlen) & (offs_n_curr[None, :] < seqlen)
+
+         # Perform softmax
+         p_curr = tl.exp(qk_curr - lse_i_curr[:, None])
+         p_curr = tl.where(mn_mask_curr, p_curr, 0.0)
+         p_curr /= max_seq_len
+
+         # compute dp = dot(do, v.T)
+         # gemm [BLOCK_M, BLOCK_D] x [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
+         dp_curr = tl.dot(do_curr, tl.trans(v_curr), allow_tf32=allow_tf32)
+
+         # compute ds = p * (dp - delta[:, None])
+         ds_curr = p_curr * (dp_curr - Di_curr[:, None] * max_seq_len)
+
+         ds_curr = ds_curr.to(k_ptr.dtype.element_ty)
+         # compute dq = dot(ds, k)
+         # gemm [BLOCK_M, BLOCK_N] x [BLOCK_N, BLOCK_D] -> [BLOCK_M, BLOCK_D]
+         dq_curr += tl.dot(ds_curr, k_curr, allow_tf32=allow_tf32)
+         block_start += BLOCK_N
+
+     tl.store(
+         dq_ptrs_curr,
+         dq_curr,
+         mask=((offs_d[None, :] < D) & (offs_m_curr[:, None] < seqlen)),
+     )
+
+
+ def multi_head_jagged_flash_attention_bwd(
+     jagged_Q,
+     jagged_K,
+     jagged_V,
+     jagged_O,
+     offsets,
+     dO,
+     lse,
+     max_seq_len,
+     allow_tf32=False,
+ ):
+     BLOCK_M = 32
+     BLOCK_N = 32
+
+     B = offsets.size(0) - 1
+     num_heads = jagged_Q.size(0)
+     D = jagged_Q.size(2)
+
+     num_blocks_m = triton.cdiv(max_seq_len, BLOCK_M)
+     pre_grid = (num_blocks_m, B * num_heads)
+
+     # Triton requires the block size to be at least 16
+     BLOCK_D = max(triton.next_power_of_2(D), 16)
+
+     delta = torch.empty_like(lse)
+     if not dO.is_contiguous():
+         dO = dO.contiguous()
+
+     _multi_head_jagged_flash_attention_bwd_preprocess_kernel[pre_grid](
+         jagged_O,
+         offsets,
+         dO,
+         delta,
+         jagged_O.stride(0),
+         jagged_O.stride(1),
+         jagged_O.stride(2),
+         delta.stride(0),
+         num_heads,
+         max_seq_len,
+         D,
+         BLOCK_M,
+         BLOCK_D,
+     )
+
+     grid = (triton.cdiv(max_seq_len, BLOCK_N), B * num_heads)
+
+     dq = torch.zeros_like(jagged_Q)
+     dk = torch.zeros_like(jagged_K)
+     dv = torch.zeros_like(jagged_V)
+
+     _multi_head_jagged_flash_attention_bwd_kernel[grid](
+         jagged_Q,
+         jagged_K,
+         jagged_V,
+         jagged_O,
+         offsets,
+         dq,
+         dk,
+         dv,
+         dO,
+         delta,
+         lse,
+         jagged_Q.stride(0),
+         jagged_Q.stride(1),
+         jagged_Q.stride(2),
+         jagged_K.stride(0),
+         jagged_K.stride(1),
+         jagged_K.stride(2),
+         jagged_V.stride(0),
+         jagged_V.stride(1),
+         jagged_V.stride(2),
+         jagged_O.stride(0),
+         jagged_O.stride(1),
+         jagged_O.stride(2),
+         lse.stride(0),
+         delta.stride(0),
+         dq.stride(0),
+         dq.stride(1),
+         dq.stride(2),
+         dk.stride(0),
+         dk.stride(1),
+         dk.stride(2),
+         dv.stride(0),
+         dv.stride(1),
+         dv.stride(2),
+         dO.stride(0),
+         dO.stride(1),
+         dO.stride(2),
+         num_heads,
+         max_seq_len,
+         D,
+         allow_tf32=allow_tf32,
+         BLOCK_M=BLOCK_M,
+         BLOCK_N=BLOCK_N,
+         BLOCK_D=BLOCK_D,
+     )
+
+     return dq, dk, dv
+
+
+ class MultiHeadJaggedFlashAttention(torch.autograd.Function):
+     @staticmethod
+     # pyre-fixme
+     def forward(
+         ctx,
+         jagged_Q: torch.Tensor,
+         jagged_K: torch.Tensor,
+         jagged_V: torch.Tensor,
+         offsets: torch.Tensor,
+         max_seq_len: int,
+         allow_tf32: bool = True,
+     ) -> torch.Tensor:
+         ctx.allow_tf32 = allow_tf32
+         ctx.max_seq_len = max_seq_len
+
+         jagged_O, lse = multi_head_jagged_flash_attention_fwd(
+             jagged_Q,
+             jagged_K,
+             jagged_V,
+             offsets,
+             max_seq_len,
+             allow_tf32,
+         )
+
+         ctx.save_for_backward(
+             jagged_Q,
+             jagged_K,
+             jagged_V,
+             offsets,
+             jagged_O,
+             lse,
+         )
+
+         return jagged_O
+
+     @staticmethod
+     # pyre-fixme
+     def backward(
+         ctx, grad_output: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]:
+         (
+             jagged_Q,
+             jagged_K,
+             jagged_V,
+             offsets,
+             jagged_O,
+             lse,
+         ) = ctx.saved_tensors
+
+         dq, dk, dv = multi_head_jagged_flash_attention_bwd(
+             jagged_Q=jagged_Q,
+             jagged_K=jagged_K,
+             jagged_V=jagged_V,
+             jagged_O=jagged_O,
+             offsets=offsets,
+             dO=grad_output,
+             lse=lse,
+             max_seq_len=ctx.max_seq_len,
+             allow_tf32=ctx.allow_tf32,
+         )
+
+         return (
+             dq,
+             dk,
+             dv,
+             None,
+             None,
+             None,
+         )
+
+
+ def multi_head_jagged_flash_attention(
+     q_weights: torch.Tensor,
+     k_weights: torch.Tensor,
+     v_weights: torch.Tensor,
+     offsets: torch.Tensor,
+     max_seq_len: int,
+     allow_tf32: bool = True,
+ ) -> torch.Tensor:
+     """
+     q_weights: jagged tensor with size [H, sum_B, D]
+     k_weights: jagged tensor with size [H, sum_B, D]
+     v_weights: jagged tensor with size [H, sum_B, D]
+     offsets: offsets for jagged tensor, with size [B + 1]
+     max_seq_len: max sequence length
+     """
+     q_weights = expect_contiguous(q_weights)
+     k_weights = expect_contiguous(k_weights)
+     v_weights = expect_contiguous(v_weights)
+     offsets = expect_contiguous(offsets)
+
+     jagged_O = MultiHeadJaggedFlashAttention.apply(
+         q_weights,
+         k_weights,
+         v_weights,
+         offsets,
+         max_seq_len,
+         allow_tf32,
+     )
+
+     return jagged_O
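
Editorial note: a hedged usage sketch for the public entry point above, based only on its docstring (H heads, B variable-length sequences packed along dim 1, offsets of size [B + 1]). The import path assumes the function is re-exported by fbgemm_gpu.sll.triton (see the __init__.py files in the listing); if it is not, import it from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention directly. A GPU is required, since these are Triton JIT kernels.

import torch
from fbgemm_gpu.sll.triton import multi_head_jagged_flash_attention  # assumed re-export

H, D = 4, 64                                                # heads, head dim
lengths = torch.tensor([3, 5, 2], device="cuda")            # B = 3 jagged sequences
offsets = torch.zeros(lengths.numel() + 1, dtype=torch.int64, device="cuda")
offsets[1:] = torch.cumsum(lengths, dim=0)                  # [0, 3, 8, 10]
sum_B = int(offsets[-1])

q = torch.randn(H, sum_B, D, device="cuda", requires_grad=True)
k = torch.randn(H, sum_B, D, device="cuda", requires_grad=True)
v = torch.randn(H, sum_B, D, device="cuda", requires_grad=True)

out = multi_head_jagged_flash_attention(q, k, v, offsets, max_seq_len=8)
print(out.shape)                                            # torch.Size([4, 10, 64])
out.sum().backward()                                        # exercises the custom backward above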