fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py
@@ -0,0 +1,861 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from .common import expect_contiguous
+
+
+ @triton.jit
+ def jagged_dense_flash_attention_fwd_kernel(
+     q_ptr,
+     k_ptr,
+     v_ptr,
+     ab_ptr,  # attn bias ptr
+     o_ptr,
+     lse_ptr,
+     jagged_offsets_ptr,
+     max_seq_len,
+     stride_ql,
+     stride_qd,
+     stride_kb,
+     stride_kd,
+     stride_kt,
+     stride_vn,
+     stride_vd,
+     stride_ab_b,  # attn bias stride batch
+     stride_ab_n,
+     stride_ab_t,
+     stride_ob,
+     stride_ot,
+     stride_od,
+     D: tl.constexpr,
+     T: tl.constexpr,
+     allow_tf32: tl.constexpr,
+     BLOCK_T: tl.constexpr,
+     BLOCK_L: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+ ):
+     pid_t = tl.program_id(0)
+     pid_batch = tl.program_id(1)
+
+     # begin offset of the current sample
+     begin = tl.load(jagged_offsets_ptr + pid_batch)
+     # end offset of the current sample
+     end = tl.load(jagged_offsets_ptr + pid_batch + 1)
+
+     # The seq length of the current sample
+     length = end - begin
+     length = tl.minimum(length, max_seq_len)
+
+     if length == 0:
+         return
+
+     q_start_ptr = q_ptr + begin * stride_ql
+     k_start_ptr = k_ptr + pid_batch * stride_kb
+     ab_start_ptr = ab_ptr + pid_batch * stride_ab_b
+     v_start_ptr = v_ptr + begin * stride_vn
+
+     offs_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
+     offs_d = tl.arange(0, BLOCK_D)
+
+     # Load a block of K into [BLOCK_D, BLOCK_T]
+     ki_ptrs = k_start_ptr + offs_d[:, None] * stride_kd + offs_t[None, :] * stride_kt
+
+     ki = tl.load(
+         ki_ptrs,
+         mask=((offs_d[:, None] < D) & (offs_t[None, :] < T)),
+         other=0.0,
+     )
+
+     mi = tl.zeros([BLOCK_T], dtype=tl.float32) - float("inf")
+     li = tl.zeros([BLOCK_T], dtype=tl.float32)
+     oi = tl.zeros([BLOCK_T, BLOCK_D], dtype=tl.float32)
+
+     # Loop through the seq length dimension
+     for start_l in range(0, length, BLOCK_L):
+         offs_l = start_l + tl.arange(0, BLOCK_L)
+
+         # Load a block of Q into [BLOCK_L, BLOCK_D]
+         qj_ptrs = (
+             q_start_ptr
+             # pyre-fixme[16]: `int` has no attribute `__getitem__`.
+             + offs_l[:, None] * stride_ql
+             + offs_d[None, :] * stride_qd
+         )
+
+         qj = tl.load(
+             qj_ptrs,
+             mask=((offs_l[:, None] < length) & (offs_d[None, :] < D)),
+             other=0.0,
+         )
+
+         # gemm [BLOCK_L, BLOCK_D] x [BLOCK_D, BLOCK_T] = [BLOCK_L, BLOCK_T]
+         qk = tl.dot(qj, ki, allow_tf32=allow_tf32)
+
+         # Load a block of attn bias into [BLOCK_L, BLOCK_T]
+         ab_ptrs = (
+             ab_start_ptr + offs_l[:, None] * stride_ab_n + offs_t[None, :] * stride_ab_t
+         )
+
+         abij = tl.load(
+             ab_ptrs,
+             mask=((offs_l[:, None] < length) & (offs_t[None, :] < T)),
+             other=0.0,
+         )
+
+         # q*k output + attn bias
+         qk = qk + abij
+
+         # Note: softmax on axis 0
+         mij_hat = tl.max(qk, axis=0)
+         mi_new = tl.maximum(mi, mij_hat)
+         pij_hat = tl.exp(qk - mi_new[None, :])
+         pij_hat = tl.where(
+             (offs_l[:, None] < length) & (offs_t[None, :] < T), pij_hat, 0.0
+         )
+         lij_hat = tl.sum(pij_hat, axis=0)
+         alpha = tl.exp(mi - mi_new)
+         li_new = alpha * li + lij_hat
+         oi = alpha[:, None] * oi
+
+         # Load a block of V into [BLOCK_L, BLOCK_D]
+         vj_ptrs = (
+             v_start_ptr + offs_l[:, None] * stride_vn + offs_d[None, :] * stride_vd
+         )
+
+         vj = tl.load(
+             vj_ptrs,
+             mask=((offs_l[:, None] < length) & (offs_d[None, :] < D)),
+             other=0.0,
+         )
+
+         pij_hat = pij_hat.to(v_ptr.dtype.element_ty)
+         # gemm [BLOCK_T, BLOCK_L] x [BLOCK_L, BLOCK_D] = [BLOCK_T, BLOCK_D]
+         oi = oi + tl.dot(tl.trans(pij_hat), vj, allow_tf32=allow_tf32)
+
+         mi = mi_new
+         li = li_new
+
+     oi = oi / li[:, None]
+
+     lse_ptrs = lse_ptr + pid_batch * T + offs_t
+     # Save both mi and li to avoid recomputation in backward
+     lse_i = mi + tl.log(li)
+     tl.store(lse_ptrs, lse_i, mask=(offs_t < T))
+
+     # Write the output [BLOCK_T, BLOCK_D]
+     attn_out_ptrs = (
+         o_ptr
+         + pid_batch * stride_ob
+         + offs_t[:, None] * stride_ot
+         + offs_d[None, :] * stride_od
+     )
+     tl.store(attn_out_ptrs, oi, mask=((offs_t[:, None] < T) & (offs_d[None, :] < D)))
+
+
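The running `mi`/`li`/`oi` updates in the loop above are the standard online-softmax recurrence used by flash attention. As orientation only (this sketch is by the editor, not part of the package), here is a dense PyTorch analogue of what one program computes for a single sample, given scores `s = q @ k + bias` of shape `[L, T]` and values `v` of shape `[L, D]`:

import torch

def online_softmax_reference(s, v, block_l=32):
    # s: [L, T] attention scores, v: [L, D] values; softmax is taken over the L axis.
    L, T = s.shape
    D = v.shape[1]
    m = torch.full((T,), float("-inf"))  # running max over L
    l = torch.zeros(T)                   # running sum of exponentials
    o = torch.zeros(T, D)                # running (unnormalized) output
    for start in range(0, L, block_l):
        s_blk = s[start : start + block_l]                 # [l_blk, T]
        v_blk = v[start : start + block_l]                 # [l_blk, D]
        m_new = torch.maximum(m, s_blk.max(dim=0).values)
        p = torch.exp(s_blk - m_new)                       # shifted probabilities
        alpha = torch.exp(m - m_new)                       # rescale old accumulators
        l = alpha * l + p.sum(dim=0)
        o = alpha[:, None] * o + p.T @ v_blk               # [T, D]
        m = m_new
    return o / l[:, None], m + torch.log(l)                # output and the stored lse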
+ def jagged_dense_flash_attention_fwd(
+     Q,
+     K,
+     V,
+     attn_bias,
+     jagged_offsets,
+     max_seq_len,
+     allow_tf32=False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Q: jagged tensor, [sum_B, D]
+     K: dense tensor, [B, D, T]
+     V: jagged tensor [sum_B, D]
+     attn_bias: dense tensor [B, N, T]
+     out: dense tensor [B, T, D]
+
+     Attention steps:
+     1. Q * K: [sum_B, D] * [B, D, T] = [sum_B, T]
+     2. softmax_input = Q * K + attn_bias
+        [sum_B, T] + [B, N, T] = [sum_B, T]
+     3. softmax_out = softmax(softmax_input):
+        softmax([sum_B, T]) = [sum_B, T]
+     4. softmax_out * V:
+        [sum_B, T] * [sum_B, D] = [B, T, D]
+     """
+     assert Q.size(1) == K.size(1), "incompatible dimensions for Q and K"
+     assert Q.size() == V.size(), "incompatible dimensions for Q and V"
+     assert jagged_offsets.is_contiguous(), "jagged_offsets must be contiguous"
+
+     (B, D, T) = K.size()
+     assert D > 0 and (D & (D - 1)) == 0, "D needs to be a power of two"
+
+     attn_out = torch.zeros(B, T, D, dtype=Q.dtype, device=Q.device)
+     lse = torch.empty((B, T), dtype=K.dtype, device=K.device)
+
+     BLOCK_T = 32
+     BLOCK_L = 32
+     BLOCK_D = D
+
+     num_blocks_t = triton.cdiv(T, BLOCK_T)
+     grid = (num_blocks_t, B)
+
+     jagged_dense_flash_attention_fwd_kernel[grid](
+         Q,
+         K,
+         V,
+         attn_bias,
+         attn_out,
+         lse,
+         jagged_offsets,
+         max_seq_len,
+         Q.stride(0),
+         Q.stride(1),
+         K.stride(0),
+         K.stride(1),
+         K.stride(2),
+         V.stride(0),
+         V.stride(1),
+         attn_bias.stride(0),
+         attn_bias.stride(1),
+         attn_bias.stride(2),
+         attn_out.stride(0),
+         attn_out.stride(1),
+         attn_out.stride(2),
+         D,
+         T,
+         allow_tf32,
+         BLOCK_T,  # pyre-ignore
+         BLOCK_L,  # pyre-ignore
+         BLOCK_D,
+     )
+
+     return attn_out, lse
+
+
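For a cross-check of the docstring's shape algebra, the wrapper above is numerically equivalent to the following dense per-sample computation (an illustrative sketch by the editor, not shipped in the wheel; it assumes the bias has `N >= max_seq_len` rows, as the kernel does):

import torch

def jagged_dense_flash_attention_fwd_reference(Q, K, V, attn_bias, jagged_offsets, max_seq_len):
    # Q, V: [sum_B, D] jagged values; K: [B, D, T]; attn_bias: [B, N, T].
    B, D, T = K.shape
    out = torch.zeros(B, T, D, dtype=Q.dtype, device=Q.device)
    lse = torch.empty(B, T, dtype=K.dtype, device=K.device)
    for b in range(B):
        begin, end = int(jagged_offsets[b]), int(jagged_offsets[b + 1])
        L = min(end - begin, max_seq_len)
        if L == 0:
            continue  # the kernel also skips empty samples
        q_b, v_b = Q[begin : begin + L], V[begin : begin + L]   # [L, D]
        s = q_b @ K[b] + attn_bias[b, :L]                       # [L, T]
        out[b] = torch.softmax(s, dim=0).T @ v_b                # [T, D]
        lse[b] = torch.logsumexp(s, dim=0)
    return out, lse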
+ @triton.jit
+ def _bwd_preprocess_do_o_dot(
+     o_ptr,
+     do_ptr,
+     delta_ptr,
+     T,
+     stride_ob,
+     stride_ot,
+     stride_od,
+     stride_do_b,
+     stride_do_t,
+     stride_do_d,
+     BLOCK_T: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+ ):
+     start_t = tl.program_id(0)
+     offs_t = start_t * BLOCK_T + tl.arange(0, BLOCK_T)
+     pid_b = tl.program_id(1)
+     offs_d = tl.arange(0, BLOCK_D)
+
+     o_ptrs = (
+         o_ptr
+         + pid_b * stride_ob
+         + offs_t[:, None] * stride_ot
+         + offs_d[None, :] * stride_od
+     )
+     do_ptrs = (
+         do_ptr
+         + pid_b * stride_do_b
+         + offs_t[:, None] * stride_do_t
+         + offs_d[None, :] * stride_do_d
+     )
+     o = tl.load(o_ptrs, mask=(offs_t[:, None] < T), other=0.0)
+     do = tl.load(do_ptrs, mask=(offs_t[:, None] < T), other=0.0)
+     delta = tl.sum(o * do, axis=1)
+
+     delta_ptrs = delta_ptr + pid_b * T + offs_t
+     tl.store(delta_ptrs, delta, mask=(offs_t < T))
+
+
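The preprocessing kernel above only materializes the per-position dot product between the forward output and its incoming gradient, the quantity needed later in softmax backward. In plain PyTorch terms (illustrative only):

import torch

def bwd_preprocess_reference(out: torch.Tensor, grad_out: torch.Tensor) -> torch.Tensor:
    # out, grad_out: [B, T, D] -> delta: [B, T]
    return (out * grad_out).sum(dim=-1)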
+ @triton.jit
+ def _jagged_dense_flash_attention_bwd_dv_db_dq_kernel(
+     q_ptr,
+     k_ptr,
+     v_ptr,
+     ab_ptr,  # attn bias
+     jagged_offsets_ptr,
+     out_ptr,
+     do_ptr,
+     lse_ptr,
+     delta_ptr,
+     dq_ptr,
+     dk_ptr,
+     dv_ptr,
+     dbias_ptr,
+     max_seq_len,
+     stride_ql,
+     stride_qd,
+     stride_kb,
+     stride_kd,
+     stride_kt,
+     stride_vl,
+     stride_vd,
+     stride_ab_b,  # attn bias stride batch
+     stride_ab_l,
+     stride_ab_t,
+     stride_ob,
+     stride_ot,
+     stride_od,
+     stride_dq_l,
+     stride_dq_d,
+     stride_dv_l,
+     stride_dv_d,
+     stride_db_b,
+     stride_db_l,
+     stride_db_t,
+     stride_do_b,
+     stride_do_t,
+     stride_do_d,
+     T: tl.constexpr,
+     BLOCK_T: tl.constexpr,
+     BLOCK_L: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+     allow_tf32: tl.constexpr,
+ ):
+     pid_l = tl.program_id(0)
+     pid_b = tl.program_id(1)
+     # begin offset of the current sample
+     begin = tl.load(jagged_offsets_ptr + pid_b)
+     # end offset of the current sample
+     end = tl.load(jagged_offsets_ptr + pid_b + 1)
+
+     # The seq length of the current sample
+     seqlen = end - begin
+     seqlen = tl.minimum(seqlen, max_seq_len)
+
+     if seqlen == 0:
+         return
+
+     q_start_ptr = q_ptr + begin * stride_ql
+     k_start_ptr = k_ptr + pid_b * stride_kb
+     ab_start_ptr = ab_ptr + pid_b * stride_ab_b
+     v_start_ptr = v_ptr + begin * stride_vl
+     do_start_ptr = do_ptr + pid_b * stride_do_b
+     dq_start_ptr = dq_ptr + begin * stride_dq_l
+     dv_start_ptr = dv_ptr + begin * stride_dv_l
+     dbias_start_ptr = dbias_ptr + pid_b * stride_db_b
+     delta_ptrs = delta_ptr + pid_b * T
+     lse_ptrs = lse_ptr + pid_b * T
+
+     start_l = pid_l * BLOCK_L
+     offs_l_curr = start_l + tl.arange(0, BLOCK_L)
+     offs_d = tl.arange(0, BLOCK_D)
+     offs_t = tl.arange(0, BLOCK_T)
+
+     q_ptrs = (
+         q_start_ptr + offs_l_curr[:, None] * stride_ql + offs_d[None, :] * stride_qd
+     )
+     k_ptrs = k_start_ptr + offs_d[:, None] * stride_kd + offs_t[None, :] * stride_kt
+     v_ptrs = (
+         v_start_ptr + offs_l_curr[:, None] * stride_vl + offs_d[None, :] * stride_vd
+     )
+
+     do_ptrs = (
+         do_start_ptr + offs_t[:, None] * stride_do_t + offs_d[None, :] * stride_do_d
+     )
+
+     dq = tl.zeros([BLOCK_L, BLOCK_D], dtype=tl.float32)
+     dv = tl.zeros([BLOCK_L, BLOCK_D], dtype=tl.float32)
+
+     # Load a block of q into [BLOCK_L, BLOCK_D]
+     q = tl.load(
+         q_ptrs,
+         mask=((offs_l_curr[:, None] < seqlen) & (offs_d[None, :] < BLOCK_D)),
+         other=0.0,
+     )
+     v = tl.load(v_ptrs, mask=(offs_l_curr[:, None] < seqlen), other=0.0)
+
+     # for start_t in range(0, T, BLOCK_T):
+     start_t = 0
+     while start_t < T:
+         offs_t_curr = start_t + tl.arange(0, BLOCK_T)
+
+         # Load a block of k into [BLOCK_D, BLOCK_T]
+         k = tl.load(
+             k_ptrs,
+             mask=((offs_t_curr[None, :] < T) & (offs_d[:, None] < BLOCK_D)),
+             other=0.0,
+         )
+         qk = tl.zeros([BLOCK_L, BLOCK_T], dtype=tl.float32)
+
+         # gemm [BLOCK_L, BLOCK_D] x [BLOCK_D, BLOCK_T] -> [BLOCK_L, BLOCK_T]
+         qk += tl.dot(q, k, allow_tf32=allow_tf32)
+
+         ab_ptrs = (
+             ab_start_ptr
+             + offs_l_curr[:, None] * stride_ab_l
+             + offs_t_curr[None, :] * stride_ab_t
+         )
+
+         ab = tl.load(
+             ab_ptrs,
+             mask=((offs_l_curr[:, None] < seqlen) & (offs_t_curr[None, :] < T)),
+             other=0.0,
+         )
+
+         # q*k output + attn bias
+         qk = qk + ab
+
+         # Mask out invalid positions for softmax
+         qk_mask = (offs_l_curr[:, None] < seqlen) & (offs_t_curr[None, :] < T)
+         qk = tl.where(qk_mask, qk, float("-inf"))
+
+         lse_t = tl.load(
+             lse_ptrs + offs_t_curr, mask=(offs_t_curr < T), other=float("inf")
+         )
+         # Perform softmax
+         p = tl.exp(qk - lse_t[None, :])
+         p = tl.where(qk_mask, p, 0.0)
+
+         # Compute dv
+         # Load a block of do into [BLOCK_T, BLOCK_D]
+         do = tl.load(do_ptrs, mask=(offs_t_curr[:, None] < T), other=0.0)
+
+         # gemm [BLOCK_L, BLOCK_T] x [BLOCK_T, BLOCK_D] -> [BLOCK_L, BLOCK_D]
+         dv += tl.dot(p, do, allow_tf32=allow_tf32)
+
+         # Compute dp
+         delta = tl.load(delta_ptrs + offs_t_curr, mask=(offs_t_curr < T))
+         dp = tl.zeros([BLOCK_L, BLOCK_T], dtype=tl.float32)
+
+         # gemm [BLOCK_T, BLOCK_D] x [BLOCK_D, BLOCK_L] = [BLOCK_T, BLOCK_L]
+         # [BLOCK_T, BLOCK_L]^T -> [BLOCK_L, BLOCK_T]
+         dp += tl.trans(tl.dot(do, tl.trans(v), allow_tf32=allow_tf32))
+
+         # Compute ds = p * (dp - delta)
+         ds = p * (dp - delta[None, :])
+
+         # Save dbias = ds
+         dbias_ptrs = (
+             dbias_start_ptr
+             + offs_l_curr[:, None] * stride_db_l
+             + offs_t_curr[None, :] * stride_db_t
+         )
+         tl.store(
+             dbias_ptrs,
+             ds,
+             mask=((offs_l_curr[:, None] < seqlen) & (offs_t_curr[None, :] < T)),
+         )
+
+         # Compute dq
+         # gemm [BLOCK_L, BLOCK_T] x [BLOCK_T, BLOCK_D] -> [BLOCK_L, BLOCK_D]
+         dq += tl.dot(ds, tl.trans(k), allow_tf32=allow_tf32)
+
+         k_ptrs += BLOCK_T * stride_kt
+         do_ptrs += BLOCK_T * stride_do_t
+         start_t += BLOCK_T
+
+     dq_ptrs = (
+         dq_start_ptr
+         + offs_l_curr[:, None] * stride_dq_l
+         + offs_d[None, :] * stride_dq_d
+     )
+     dv_ptrs = (
+         dv_start_ptr
+         + offs_l_curr[:, None] * stride_dv_l
+         + offs_d[None, :] * stride_dv_d
+     )
+     tl.store(dq_ptrs, dq, mask=(offs_l_curr[:, None] < seqlen))
+     tl.store(dv_ptrs, dv, mask=(offs_l_curr[:, None] < seqlen))
+
+
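The `ds = p * (dp - delta)` line above is the usual softmax-through-logsumexp backward identity. As a compact dense per-sample sketch of the gradients that this kernel (dq, dv, dbias) and the following dk kernel accumulate, using the same names as the docstrings (editor's illustration only, not part of the package):

import torch

def jagged_dense_attention_bwd_reference(q_b, K_b, v_b, bias_b, do_b, lse_b, delta_b):
    # q_b, v_b: [L, D]; K_b: [D, T]; bias_b: [L, T]; do_b: [T, D]; lse_b, delta_b: [T]
    s = q_b @ K_b + bias_b                # recompute scores, [L, T]
    p = torch.exp(s - lse_b[None, :])     # softmax over L, recovered from the saved lse
    dp = v_b @ do_b.T                     # [L, T]
    ds = p * (dp - delta_b[None, :])      # softmax backward
    dq_b = ds @ K_b.T                     # [L, D]
    dv_b = p @ do_b                       # [L, D]
    dbias_b = ds                          # [L, T]
    dk_b = q_b.T @ ds                     # [D, T], produced by the separate dk kernel below
    return dq_b, dk_b, dv_b, dbias_b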
+ @triton.jit
+ def _jagged_dense_flash_attention_bwd_dk_kernel(
+     q_ptr,
+     k_ptr,
+     v_ptr,
+     ab_ptr,  # attn bias
+     jagged_offsets_ptr,
+     out_ptr,
+     do_ptr,
+     lse_ptr,
+     delta_ptr,
+     dq_ptr,
+     dk_ptr,
+     dv_ptr,
+     dbias_ptr,
+     max_seq_len,
+     stride_ql,
+     stride_qd,
+     stride_kb,
+     stride_kd,
+     stride_kt,
+     stride_vl,
+     stride_vd,
+     stride_ab_b,  # attn bias stride batch
+     stride_ab_l,
+     stride_ab_t,
+     stride_ob,
+     stride_ot,
+     stride_od,
+     stride_dk_b,
+     stride_dk_d,
+     stride_dk_t,
+     stride_do_b,
+     stride_do_t,
+     stride_do_d,
+     D,
+     T: tl.constexpr,
+     BLOCK_T: tl.constexpr,
+     BLOCK_L: tl.constexpr,
+     BLOCK_D: tl.constexpr,
+     allow_tf32: tl.constexpr,
+ ):
+     pid_t = tl.program_id(0)
+     pid_b = tl.program_id(1)
+     # begin offset of the current sample
+     begin = tl.load(jagged_offsets_ptr + pid_b)
+     # end offset of the current sample
+     end = tl.load(jagged_offsets_ptr + pid_b + 1)
+
+     # The seq length of the current sample
+     seqlen = end - begin
+     seqlen = tl.minimum(seqlen, max_seq_len)
+
+     if seqlen == 0:
+         return
+
+     q_start_ptr = q_ptr + begin * stride_ql
+     k_start_ptr = k_ptr + pid_b * stride_kb
+     ab_start_ptr = ab_ptr + pid_b * stride_ab_b
+     v_start_ptr = v_ptr + begin * stride_vl
+     do_start_ptr = do_ptr + pid_b * stride_do_b
+     dk_start_ptr = dk_ptr + pid_b * stride_dk_b
+     delta_ptrs = delta_ptr + pid_b * T
+     lse_ptrs = lse_ptr + pid_b * T
+
+     offs_t_curr = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
+     offs_d = tl.arange(0, BLOCK_D)
+
+     k_ptrs = (
+         k_start_ptr + offs_d[:, None] * stride_kd + offs_t_curr[None, :] * stride_kt
+     )
+
+     do_ptrs = (
+         do_start_ptr
+         + offs_t_curr[:, None] * stride_do_t
+         + offs_d[None, :] * stride_do_d
+     )
+
+     dk_ptrs = (
+         dk_start_ptr
+         + offs_d[:, None] * stride_dk_d
+         + offs_t_curr[None, :] * stride_dk_t
+     )
+
+     dk = tl.zeros([BLOCK_D, BLOCK_T], dtype=tl.float32)
+
+     # Load a block of k into [BLOCK_D, BLOCK_T]
+     k = tl.load(
+         k_ptrs,
+         mask=((offs_t_curr[None, :] < T) & (offs_d[:, None] < BLOCK_D)),
+         other=0.0,
+     )
+
+     start_l = 0
+     while start_l < seqlen:
+         offs_l_curr = start_l + tl.arange(0, BLOCK_L)
+
+         # Load a block of q into [BLOCK_L, BLOCK_D]
+         q_ptrs = (
+             q_start_ptr + offs_l_curr[:, None] * stride_ql + offs_d[None, :] * stride_qd
+         )
+
+         q = tl.load(
+             q_ptrs,
+             mask=(offs_l_curr[:, None] < seqlen),
+             other=0.0,
+         )
+
+         v_ptrs = (
+             v_start_ptr + offs_l_curr[:, None] * stride_vl + offs_d[None, :] * stride_vd
+         )
+
+         v = tl.load(v_ptrs, mask=(offs_l_curr[:, None] < seqlen), other=0.0)
+
+         qk = tl.zeros([BLOCK_L, BLOCK_T], dtype=tl.float32)
+         # gemm [BLOCK_L, BLOCK_D] x [BLOCK_D, BLOCK_T] -> [BLOCK_L, BLOCK_T]
+
+         qk = tl.dot(q, k, allow_tf32=allow_tf32)
+         qk = tl.where(
+             (offs_l_curr[:, None] < seqlen) & (offs_t_curr[None, :] < T), qk, 0.0
+         )
+
+         ab_ptrs = (
+             ab_start_ptr
+             + offs_l_curr[:, None] * stride_ab_l
+             + offs_t_curr[None, :] * stride_ab_t
+         )
+
+         ab = tl.load(
+             ab_ptrs,
+             mask=((offs_l_curr[:, None] < seqlen) & (offs_t_curr[None, :] < T)),
+             other=0.0,
+         )
+
+         # q*k output + attn bias
+         qk = qk + ab
+
+         # Mask out invalid positions for softmax
+         qk_mask = (offs_l_curr[:, None] < seqlen) & (offs_t_curr[None, :] < T)
+         qk = tl.where(qk_mask, qk, float("-inf"))
+
+         lse_t = tl.load(lse_ptrs + offs_t_curr, mask=(offs_t_curr < T))
+         # Perform softmax
+         p = tl.exp(qk - lse_t[None, :])
+         p = tl.where(qk_mask, p, 0.0)
+
+         # Load a block of do into [BLOCK_T, BLOCK_D]
+         do = tl.load(do_ptrs, mask=(offs_t_curr[:, None] < T), other=0.0)
+
+         # Compute dp
+         delta = tl.load(delta_ptrs + offs_t_curr, mask=(offs_t_curr < T))
+
+         # gemm [BLOCK_T, BLOCK_D] x [BLOCK_D, BLOCK_L] = [BLOCK_T, BLOCK_L]
+         # [BLOCK_T, BLOCK_L]^T -> [BLOCK_L, BLOCK_T]
+         dp = tl.trans(tl.dot(do, tl.trans(v), allow_tf32=allow_tf32))
+
+         # Compute ds = p * (dp - delta)
+         ds = p * (dp - delta[None, :])
+
+         # Compute dk
+         # gemm [BLOCK_D, BLOCK_L] x [BLOCK_L, BLOCK_T] = [BLOCK_D, BLOCK_T]
+         dk += tl.dot(tl.trans(q), ds, allow_tf32=allow_tf32)
+
+         start_l += BLOCK_L
+
+     tl.store(dk_ptrs, dk, mask=(offs_t_curr[None, :] < T))
+
+
+ def jagged_dense_flash_attention_bwd(
+     Q,
+     K,
+     V,
+     Out,
+     lse,
+     do,  # derivative of attn_out
+     attn_bias,
+     jagged_offsets,
+     max_seq_len,
+     allow_tf32=False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Q: jagged tensor, [sum_B, D]
+     K: dense tensor, [B, D, T]
+     V: jagged tensor [sum_B, D]
+     Out: dense tensor [B, T, D]
+     lse: dense tensor [B, T]
+     do: dense tensor [B, T, D]
+     attn_bias: dense tensor [B, N, T]
+     jagged_offsets: tensor [B + 1]
+     """
+     assert Q.size(1) == K.size(1), "incompatible dimensions for Q and K"
+     assert Q.size() == V.size(), "incompatible dimensions for Q and V"
+     assert lse.size(1) == K.size(2), "incompatible dimensions for LSE and K"
+
+     if not do.is_contiguous():
+         do = do.contiguous()
+
+     (B, D, T) = K.size()
+     BLOCK_T = 32
+     BLOCK_L = 32
+     BLOCK_D = D
+     num_blocks_k = triton.cdiv(T, BLOCK_T)
+
+     dk = torch.zeros_like(K)
+     dq = torch.zeros_like(Q)
+     dv = torch.zeros_like(V)
+     dbias = torch.zeros_like(attn_bias)
+
+     delta = torch.empty_like(lse)
+     _bwd_preprocess_do_o_dot[(num_blocks_k, B)](
+         Out,
+         do,
+         delta,
+         T,
+         Out.stride(0),
+         Out.stride(1),
+         Out.stride(2),
+         do.stride(0),
+         do.stride(1),
+         do.stride(2),
+         BLOCK_T,  # pyre-ignore
+         BLOCK_D,
+     )
+
+     num_blocks_l = triton.cdiv(max_seq_len, BLOCK_L)
+     _jagged_dense_flash_attention_bwd_dv_db_dq_kernel[(num_blocks_l, B)](
+         Q,
+         K,
+         V,
+         attn_bias,
+         jagged_offsets,
+         Out,
+         do,
+         lse,
+         delta,
+         dq,
+         dk,
+         dv,
+         dbias,
+         max_seq_len,
+         Q.stride(0),
+         Q.stride(1),
+         K.stride(0),
+         K.stride(1),
+         K.stride(2),
+         V.stride(0),
+         V.stride(1),
+         attn_bias.stride(0),
+         attn_bias.stride(1),
+         attn_bias.stride(2),
+         Out.stride(0),
+         Out.stride(1),
+         Out.stride(2),
+         dq.stride(0),
+         dq.stride(1),
+         dv.stride(0),
+         dv.stride(1),
+         dbias.stride(0),
+         dbias.stride(1),
+         dbias.stride(2),
+         do.stride(0),
+         do.stride(1),
+         do.stride(2),
+         T,
+         BLOCK_T,  # pyre-ignore
+         BLOCK_L,  # pyre-ignore
+         BLOCK_D,
+         allow_tf32,
+     )
+
+     num_blocks_t = triton.cdiv(T, BLOCK_T)
+     _jagged_dense_flash_attention_bwd_dk_kernel[(num_blocks_t, B)](
+         Q,
+         K,
+         V,
+         attn_bias,
+         jagged_offsets,
+         Out,
+         do,
+         lse,
+         delta,
+         dq,
+         dk,
+         dv,
+         dbias,
+         max_seq_len,
+         Q.stride(0),
+         Q.stride(1),
+         K.stride(0),
+         K.stride(1),
+         K.stride(2),
+         V.stride(0),
+         V.stride(1),
+         attn_bias.stride(0),
+         attn_bias.stride(1),
+         attn_bias.stride(2),
+         Out.stride(0),
+         Out.stride(1),
+         Out.stride(2),
+         dk.stride(0),
+         dk.stride(1),
+         dk.stride(2),
+         do.stride(0),
+         do.stride(1),
+         do.stride(2),
+         D,
+         T,
+         BLOCK_T,  # pyre-ignore
+         BLOCK_L,  # pyre-ignore
+         BLOCK_D,
+         allow_tf32,
+     )
+
+     return dq, dk, dv, dbias
+
+
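One quick way to sanity-check the kernels launched above (an editor's sketch, not part of the package) is to compare the returned per-sample gradients against autograd through the dense formula for a single sample:

import torch

def dense_sample_grads(q_b, K_b, v_b, bias_b, do_b):
    # q_b, v_b: [L, D]; K_b: [D, T]; bias_b: [L, T]; do_b: [T, D] incoming gradient
    q_b = q_b.detach().clone().requires_grad_(True)
    K_b = K_b.detach().clone().requires_grad_(True)
    v_b = v_b.detach().clone().requires_grad_(True)
    bias_b = bias_b.detach().clone().requires_grad_(True)
    out_b = torch.softmax(q_b @ K_b + bias_b, dim=0).T @ v_b   # [T, D]
    out_b.backward(do_b)
    return q_b.grad, K_b.grad, v_b.grad, bias_b.grad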
+ class JaggedDenseFlashAttention(torch.autograd.Function):
+     @staticmethod
+     # pyre-fixme
+     def forward(
+         ctx,
+         Q: torch.Tensor,
+         K: torch.Tensor,
+         V: torch.Tensor,
+         attn_bias: torch.Tensor,
+         jagged_offsets: torch.Tensor,
+         max_seq_len: int,
+         allow_tf32: bool = False,
+     ) -> torch.Tensor:
+         attn_out, lse = jagged_dense_flash_attention_fwd(
+             Q, K, V, attn_bias, jagged_offsets, max_seq_len, allow_tf32
+         )
+         ctx.save_for_backward(Q, K, V, attn_bias, jagged_offsets, lse, attn_out)
+         ctx.max_seq_len = max_seq_len
+         ctx.allow_tf32 = allow_tf32
+         return attn_out
+
+     @staticmethod
+     # pyre-fixme
+     def backward(
+         ctx, do: torch.Tensor
+     ) -> tuple[
+         torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, None, None, None
+     ]:
+         Q, K, V, attn_bias, jagged_offsets, lse, attn_out = ctx.saved_tensors
+         max_seq_len = ctx.max_seq_len
+         allow_tf32 = ctx.allow_tf32
+
+         dq, dk, dv, dbias = jagged_dense_flash_attention_bwd(
+             Q,
+             K,
+             V,
+             attn_out,
+             lse,
+             do,
+             attn_bias,
+             jagged_offsets,
+             max_seq_len,
+             allow_tf32,
+         )
+         return dq, dk, dv, dbias, None, None, None
+
+
+ def jagged_dense_flash_attention(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     attn_bias: torch.Tensor,
+     offsets: torch.Tensor,
+     max_seq_len: int,
+     allow_tf32: bool = True,
+ ):
+     """
+     q: jagged tensor, [sum_B, D]
+     k: dense tensor, [B, D, T]
+     v: jagged tensor [sum_B, D]
+     attn_bias: dense tensor [B, N, T]
+     offsets: offsets for jagged tensor [B + 1]
+     """
+
+     q = expect_contiguous(q)
+     k = expect_contiguous(k)
+     v = expect_contiguous(v)
+     attn_bias = expect_contiguous(attn_bias)
+     offsets = expect_contiguous(offsets)
+
+     return JaggedDenseFlashAttention.apply(
+         q, k, v, attn_bias, offsets, max_seq_len, allow_tf32
+     )
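
End to end, the public `jagged_dense_flash_attention` entry point can be exercised roughly as follows. This is a hypothetical usage sketch with made-up shapes, written by the editor; it assumes a CUDA device, importing directly from the module shown in this diff, and it respects the code's constraint that D be a power of two:

import torch
from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import (
    jagged_dense_flash_attention,
)

device = "cuda"
B, D, T, max_seq_len = 3, 64, 128, 32
lengths = torch.tensor([5, 0, 17], device=device)
offsets = torch.zeros(B + 1, dtype=torch.int64, device=device)
offsets[1:] = torch.cumsum(lengths, dim=0)      # jagged offsets, [B + 1]

q = torch.randn(int(offsets[-1]), D, device=device, requires_grad=True)   # [sum_B, D]
v = torch.randn(int(offsets[-1]), D, device=device, requires_grad=True)   # [sum_B, D]
k = torch.randn(B, D, T, device=device, requires_grad=True)               # [B, D, T]
bias = torch.randn(B, max_seq_len, T, device=device, requires_grad=True)  # [B, N, T]

out = jagged_dense_flash_attention(q, k, v, bias, offsets, max_seq_len)   # [B, T, D]
out.sum().backward()    # gradients flow through the custom autograd.Function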