fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py
@@ -0,0 +1,667 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+import torch
+import triton
+import triton.language as tl
+
+from .common import expect_contiguous
+
+
+@triton.jit
+def jagged_flash_attention_basic_kernel(
+    q_ptr,
+    k_ptr,
+    v_ptr,
+    offset_ptr,
+    o_ptr,
+    lse_i_ptr,
+    stride_qm,
+    stride_qd,
+    stride_kd,
+    stride_kn,
+    stride_vn,
+    stride_vd,
+    stride_om,
+    stride_od,
+    max_seq_len,
+    D: tl.constexpr,
+    NEXT_D: tl.constexpr,
+    use_mask: tl.constexpr,
+    allow_tf32: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_D: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    pid_batch = tl.program_id(axis=1)
+
+    begin = tl.load(offset_ptr + pid_batch)
+    end = tl.load(offset_ptr + pid_batch + 1)
+
+    seqlen = end - begin
+    seqlen = tl.minimum(seqlen, max_seq_len)
+
+    if pid_m * BLOCK_SIZE_M >= seqlen:
+        return
+
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_d = tl.arange(0, BLOCK_SIZE_D)
+    # Offset till next power of 2 for D
+    offs_nextd = tl.arange(0, NEXT_D)
+
+    acc = tl.zeros([BLOCK_SIZE_M, NEXT_D], dtype=tl.float32)
+
+    m_i = tl.zeros([BLOCK_SIZE_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_SIZE_M], dtype=tl.float32)
+    for j in range(0, seqlen, BLOCK_SIZE_N):
+        offs_n = tl.arange(0, BLOCK_SIZE_N) + j
+        q_ptrs = (
+            q_ptr
+            + (offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd)
+            + begin * stride_qm
+        )
+
+        k_ptrs = (
+            k_ptr
+            + (offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kn)
+            + begin * stride_kn
+        )
+
+        qk = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+        for d in range(0, D, BLOCK_SIZE_D):
+            updated_offset = d + offs_d
+            q = tl.load(
+                q_ptrs,
+                # pyre-fixme[16]: `int` has no attribute `__getitem__`.
+                mask=((updated_offset[None, :] < D) & (offs_m[:, None] < seqlen)),
+                other=0.0,
+            )
+            k = tl.load(
+                k_ptrs,
+                mask=((updated_offset[:, None] < D) & (offs_n[None, :] < seqlen)),
+                other=0.0,
+            )
+            qk += tl.dot(q, k, allow_tf32=allow_tf32)
+
+            q_ptrs += BLOCK_SIZE_D * stride_qd
+            k_ptrs += BLOCK_SIZE_D * stride_kd
+
+        m_ij = tl.maximum(tl.max(qk, axis=1), m_i)
+        # Add the correct mask here
+        mn_mask = (offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)
+
+        p = tl.exp(qk - m_ij[:, None])
+        p = tl.where(mn_mask, p, 0.0)
+
+        l_ij = tl.sum(p, axis=1)
+        alpha = tl.exp(m_i - m_ij)
+
+        l_i = l_i * alpha + l_ij
+        acc = acc * alpha[:, None]
+
+        # Load V
+        v_ptrs = (
+            v_ptr
+            + (offs_nextd[None, :] * stride_vd + offs_n[:, None] * stride_vn)
+            + begin * stride_vn
+        )
+        v = tl.load(
+            v_ptrs,
+            mask=((offs_nextd[None, :] < D) & (offs_n[:, None] < seqlen)),
+            other=0.0,
+        )
+
+        p /= max_seq_len
+
+        if use_mask:
+            attn_mask = offs_m[:, None] - offs_n[None, :]
+            attn_mask = tl.where(mn_mask, attn_mask, 0.0)
+            attn_mask = tl.where(attn_mask > 0, 0.0, 1.0)
+            p = tl.where(attn_mask > 0, p, 0.0)
+
+        p = p.to(v_ptr.dtype.element_ty)
+        acc_j = tl.dot(p, v, allow_tf32=allow_tf32)
+        acc += acc_j
+        m_i = m_ij
+
+    lse_i = m_i + tl.math.log(l_i)
+    lse_i_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    lse_i_ptrs = lse_i_ptr + lse_i_offsets + begin
+
+    tl.store(lse_i_ptrs, lse_i, mask=lse_i_offsets < seqlen)
+
+    acc = acc / l_i[:, None]
+
+    # Store O
+    o_ptrs = o_ptr + (
+        offs_m[:, None] * stride_om
+        + offs_nextd[None, :] * stride_od
+        + begin * stride_om
+    )
+    o_mask = (offs_m[:, None] < seqlen) & (offs_nextd[None, :] < D)
+    tl.store(o_ptrs, acc, mask=o_mask)
+
+
+def jagged_flash_attention_basic_fwd(
+    jagged_Q,
+    jagged_K,
+    jagged_V,
+    offsets,
+    max_seq_len,
+    use_mask,
+    allow_tf32=False,
+):
+    assert jagged_Q.size(1) == jagged_K.size(0), "incompatible dimensions"
+
+    B = offsets.size(0) - 1
+    D = jagged_Q.size(1)
+
+    jagged_O = torch.zeros_like(jagged_Q)
+    lse = torch.empty((jagged_Q.size(0)), device=jagged_Q.device, dtype=jagged_Q.dtype)
+
+    BLOCK_SIZE_M = 32
+    BLOCK_SIZE_N = 32
+    BLOCK_SIZE_D = 32
+
+    grid = (triton.cdiv(max_seq_len, BLOCK_SIZE_M), B)
+
+    jagged_flash_attention_basic_kernel[grid](
+        jagged_Q,
+        jagged_K,
+        jagged_V,
+        offsets,
+        jagged_O,
+        lse,
+        jagged_Q.stride(0),
+        jagged_Q.stride(1),
+        jagged_K.stride(0),
+        jagged_K.stride(1),
+        jagged_V.stride(0),
+        jagged_V.stride(1),
+        jagged_O.stride(0),
+        jagged_O.stride(1),
+        max_seq_len,
+        D,
+        triton.next_power_of_2(D),
+        use_mask,
+        allow_tf32,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_D=BLOCK_SIZE_D,
+    )
+
+    return jagged_O, lse
+
+
+# Similar to fwd kernel, this one is using a grid of
+# (num_blocks_m, B) where num_blocks_m is seq_len / BLOCK_SIZE_M
+@triton.jit
+def _jagged_flash_attention_bwd_preprocess_basic_kernel(
+    o_ptr,
+    o_offset_ptr,
+    do_ptr,
+    delta_ptr,
+    stride_om,
+    stride_od,
+    max_seq_len,
+    D: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_D: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    pid_batch = tl.program_id(axis=1)
+
+    begin_o = tl.load(o_offset_ptr + pid_batch)
+    end_o = tl.load(o_offset_ptr + pid_batch + 1)
+
+    M = end_o - begin_o
+    M = tl.minimum(M, max_seq_len)
+
+    offs_om = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_od = tl.arange(0, BLOCK_SIZE_D)
+
+    o_offsets = (
+        offs_om[:, None] * stride_om
+        + offs_od[None, :] * stride_od
+        + begin_o * stride_om
+    )
+    o_ptrs = o_ptr + o_offsets
+    do_ptrs = do_ptr + o_offsets
+    o_mask = (offs_om[:, None] < M) & (offs_od[None, :] < D)
+
+    # Load O
+    o = tl.load(o_ptrs, mask=o_mask)
+    do = tl.load(do_ptrs, mask=o_mask)
+
+    delta = tl.sum(o * do, axis=1)
+
+    tl.store(delta_ptr + begin_o + offs_om, delta, mask=offs_om < M)
+
+
+@triton.jit
+def _jagged_flash_attention_bwd_basic_kernel(
+    q_ptr,
+    k_ptr,
+    v_ptr,
+    o_ptr,
+    offset_ptr,
+    dq_ptr,
+    dk_ptr,
+    dv_ptr,
+    do_ptr,
+    delta_ptr,
+    lse_ptr,
+    stride_qm,
+    stride_qd,
+    stride_kn,
+    stride_kd,
+    stride_vn,
+    stride_vd,
+    stride_om,
+    stride_od,
+    stride_dqm,
+    stride_dqd,
+    stride_dkn,
+    stride_dkd,
+    stride_dvn,
+    stride_dvd,
+    stride_dom,
+    stride_dod,
+    max_seq_len,
+    D: tl.constexpr,
+    use_mask: tl.constexpr,
+    allow_tf32: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_D: tl.constexpr,
+):
+    pid_batch = tl.program_id(axis=1)
+
+    begin = tl.load(offset_ptr + pid_batch)
+    end = tl.load(offset_ptr + pid_batch + 1)
+
+    M = tl.minimum(end - begin, max_seq_len)
+
+    pid_n = tl.program_id(axis=0)
+    offs_d = tl.arange(0, BLOCK_SIZE_D)
+
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_m = tl.arange(0, BLOCK_SIZE_M)
+
+    q_ptrs = (
+        q_ptr
+        + begin * stride_qm
+        + (offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd)
+    )
+
+    k_ptrs = (
+        k_ptr
+        + begin * stride_kn
+        + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
+    )
+
+    v_ptrs = (
+        v_ptr
+        + begin * stride_vn
+        + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)
+    )
+
+    do_ptrs = (
+        do_ptr
+        + begin * stride_dom
+        + (offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod)
+    )
+
+    # Load K and V
+    k = tl.load(k_ptrs, mask=((offs_d[None, :] < D) & (offs_n[:, None] < M)))
+    v = tl.load(v_ptrs, mask=((offs_d[None, :] < D) & (offs_n[:, None] < M)))
+
+    # Initialize dv and dk
+    dv = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_D], dtype=tl.float32)
+    dk = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_D], dtype=tl.float32)
+
+    for begin_m in range(0, M, BLOCK_SIZE_M):
+        offs_m_temp = begin_m + offs_m
+
+        # Load Q
+        # pyre-fixme[16]: `int` has no attribute `__getitem__`.
+        q = tl.load(q_ptrs, mask=((offs_d[None, :] < D) & (offs_m_temp[:, None] < M)))
+        qk = tl.dot(q, tl.trans(k), allow_tf32=allow_tf32)
+
+        mn_mask = (offs_m_temp[:, None] < M) & (offs_n[None, :] < M)
+
+        # Load lse_i
+        lse_i = tl.load(lse_ptr + offs_m_temp + begin, mask=offs_m_temp < M)
+
+        p = tl.exp(qk - lse_i[:, None])
+        p = tl.where(mn_mask, p, 0.0)
+        p /= max_seq_len
+        p_masked = p
+
+        attn_mask = None
+        if use_mask:
+            attn_mask = offs_m_temp[:, None] - offs_n[None, :]
+            attn_mask = tl.where(mn_mask, attn_mask, 0.0)
+            attn_mask = tl.where(attn_mask > 0, 0.0, 1.0)
+            p_masked = tl.where(attn_mask > 0, p, 0.0)
+
+        p_masked = p_masked.to(do_ptr.dtype.element_ty)
+        do = tl.load(do_ptrs, mask=((offs_d[None, :] < D) & (offs_m_temp[:, None] < M)))
+        dv += tl.dot(tl.trans(p_masked), do, allow_tf32=allow_tf32)
+        dp = tl.dot(do, tl.trans(v), allow_tf32=allow_tf32)
+
+        # compute ds = p * (dp - delta[:, None])
+        Di = tl.load(delta_ptr + offs_m_temp + begin, mask=offs_m_temp < M)
+        dp_masked = dp
+        if use_mask:
+            dp_masked = tl.where(attn_mask > 0, dp, 0.0)
+
+        ds = p * (dp_masked - Di[:, None] * max_seq_len)
+
+        # compute dk = dot(ds.T, q)
+        ds = ds.to(q_ptr.dtype.element_ty)
+        dk += tl.dot(tl.trans(ds), q, allow_tf32=allow_tf32)
+
+        q_ptrs += BLOCK_SIZE_M * stride_qm
+        do_ptrs += BLOCK_SIZE_M * stride_dom
+
+    # store back dk and dv
+    dk_ptrs = (
+        dk_ptr
+        + begin * stride_dkn
+        + (offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd)
+    )
+
+    dv_ptrs = (
+        dv_ptr
+        + begin * stride_dvn
+        + (offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd)
+    )
+
+    tl.store(dk_ptrs, dk, mask=((offs_d[None, :] < D) & (offs_n[:, None] < M)))
+    tl.store(dv_ptrs, dv, mask=((offs_d[None, :] < D) & (offs_n[:, None] < M)))
+
+    start_m = tl.program_id(axis=0) * BLOCK_SIZE_N
+    offs_m_curr = start_m + tl.arange(0, BLOCK_SIZE_N)
+
+    dq_ptrs_curr = (
+        dq_ptr
+        + begin * stride_dqm
+        + (offs_m_curr[:, None] * stride_dqm + offs_d[None, :] * stride_dqd)
+    )
+
+    dq_curr = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_D], dtype=tl.float32)
+
+    q_ptrs_curr = (
+        q_ptr
+        + begin * stride_qm
+        + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
+    )
+
+    q_curr = tl.load(
+        q_ptrs_curr, mask=((offs_d[None, :] < D) & (offs_m_curr[:, None] < M))
+    )
+
+    # Load lse_i
+    lse_i_curr = tl.load(lse_ptr + offs_m_curr + begin, mask=offs_m_curr < M)
+
+    do_ptrs_curr = (
+        do_ptr
+        + begin * stride_dom
+        + (offs_m_curr[:, None] * stride_dom + offs_d[None, :] * stride_dod)
+    )
+
+    # Load do
+    do_curr = tl.load(
+        do_ptrs_curr, mask=((offs_d[None, :] < D) & (offs_m_curr[:, None] < M))
+    )
+    Di_curr = tl.load(delta_ptr + offs_m_curr + begin, mask=offs_m_curr < M)
+
+    # When computing dV, we want to compute [BLOCK_SIZE_N] rows of dV.
+    # Therefore, the loop's block size is BLOCK_SIZE_M instead of BLOCK_SIZE_N.
+    block_start = 0
+    while block_start < M:
+        offs_n_curr = block_start + tl.arange(0, BLOCK_SIZE_M)
+
+        k_ptrs_curr = (
+            k_ptr
+            + begin * stride_kn
+            + (offs_n_curr[:, None] * stride_kn + offs_d[None, :] * stride_kd)
+        )
+        v_ptrs_curr = (
+            v_ptr
+            + begin * stride_vn
+            + (offs_n_curr[:, None] * stride_vn + offs_d[None, :] * stride_vd)
+        )
+
+        k_curr = tl.load(
+            k_ptrs_curr, mask=((offs_d[None, :] < D) & (offs_n_curr[:, None] < M))
+        )
+        v_curr = tl.load(
+            v_ptrs_curr, mask=((offs_d[None, :] < D) & (offs_n_curr[:, None] < M))
+        )
+
+        qk_curr = tl.dot(q_curr, tl.trans(k_curr), allow_tf32=allow_tf32)
+        mn_mask_curr = (offs_m_curr[:, None] < M) & (offs_n_curr[None, :] < M)
+
+        p_curr = tl.exp(qk_curr - lse_i_curr[:, None])
+        p_curr = tl.where(mn_mask_curr, p_curr, 0.0)
+        p_curr /= max_seq_len
+
+        # compute dp = dot(v, do)
+        dp_curr = tl.dot(do_curr, tl.trans(v_curr), allow_tf32=allow_tf32)
+        dp_curr_masked = dp_curr
+
+        # compute ds = p * (dp - delta[:, None])
+        if use_mask:
+            attn_mask = offs_m_curr[:, None] - offs_n_curr[None, :]
+            attn_mask = tl.where(mn_mask_curr, attn_mask, 0.0)
+            attn_mask = tl.where(attn_mask > 0, 0.0, 1.0)
+            dp_curr_masked = tl.where(attn_mask > 0, dp_curr, 0.0)
+
+        ds_curr = p_curr * (dp_curr_masked - Di_curr[:, None] * max_seq_len)
+
+        ds_curr = ds_curr.to(k_ptr.dtype.element_ty)
+        dq_curr += tl.dot(ds_curr, k_curr, allow_tf32=allow_tf32)
+        block_start += BLOCK_SIZE_M
+
+    tl.store(
+        dq_ptrs_curr, dq_curr, mask=((offs_d[None, :] < D) & (offs_m_curr[:, None] < M))
+    )
+
+
+def jagged_flash_attention_basic_backward(
+    jagged_Q,
+    # K is non-transposed
+    jagged_K,
+    jagged_V,
+    jagged_O,
+    offsets,
+    dO,
+    lse,
+    max_seq_len,
+    use_mask,
+    allow_tf32=False,
+):
+    BLOCK_SIZE_M = 32
+    BLOCK_SIZE_N = 32
+
+    B = offsets.size(0) - 1
+    num_blocks_m = triton.cdiv(max_seq_len, BLOCK_SIZE_M)
+    pre_grid = (num_blocks_m, B)
+
+    BLOCK_SIZE_D = max(triton.next_power_of_2(jagged_Q.size(1)), 16)
+
+    delta = torch.empty_like(lse)
+    if not dO.is_contiguous():
+        dO = dO.contiguous()
+
+    _jagged_flash_attention_bwd_preprocess_basic_kernel[pre_grid](
+        jagged_O,
+        offsets,
+        dO,
+        delta,
+        jagged_O.stride(0),
+        jagged_O.stride(1),
+        max_seq_len,
+        jagged_O.size(1),
+        BLOCK_SIZE_M,
+        BLOCK_SIZE_D,
+    )
+
+    grid = (triton.cdiv(max_seq_len, BLOCK_SIZE_N), B)
+
+    dq = torch.zeros_like(jagged_Q)
+    dk = torch.zeros_like(jagged_K)
+    dv = torch.zeros_like(jagged_V)
+
+    D = jagged_Q.size(1)
+    _jagged_flash_attention_bwd_basic_kernel[grid](
+        jagged_Q,
+        jagged_K,
+        jagged_V,
+        jagged_O,
+        offsets,
+        dq,
+        dk,
+        dv,
+        dO,
+        delta,
+        lse,
+        jagged_Q.stride(0),
+        jagged_Q.stride(1),
+        jagged_K.stride(0),
+        jagged_K.stride(1),
+        jagged_V.stride(0),
+        jagged_V.stride(1),
+        jagged_O.stride(0),
+        jagged_O.stride(1),
+        dq.stride(0),
+        dq.stride(1),
+        dk.stride(0),
+        dk.stride(1),
+        dv.stride(0),
+        dv.stride(1),
+        dO.stride(0),
+        dO.stride(1),
+        max_seq_len,
+        D,
+        use_mask=use_mask,
+        allow_tf32=allow_tf32,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_D=BLOCK_SIZE_D,
+    )
+
+    return dq, dk, dv
+
+
+class JaggedFlashAttentionBasic(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx,
+        jagged_Q: torch.Tensor,
+        jagged_K: torch.Tensor,
+        jagged_V: torch.Tensor,
+        offsets: torch.Tensor,
+        max_seq_len: int,
+        use_mask: bool = True,
+        allow_tf32: bool = False,
+    ) -> torch.Tensor:
+        ctx.allow_tf32 = allow_tf32
+        ctx.max_seq_len = max_seq_len
+        ctx.use_mask = use_mask
+
+        jagged_O, lse = jagged_flash_attention_basic_fwd(
+            jagged_Q,
+            jagged_K.T,
+            jagged_V,
+            offsets,
+            max_seq_len,
+            use_mask,
+            allow_tf32,
+        )
+
+        ctx.save_for_backward(
+            jagged_Q,
+            jagged_K,
+            jagged_V,
+            offsets,
+            jagged_O,
+            lse,
+        )
+
+        return jagged_O
+
+    @staticmethod
+    # pyre-fixme
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None]:
+        (
+            jagged_Q,
+            jagged_K,
+            jagged_V,
+            offsets,
+            jagged_O,
+            lse,
+        ) = ctx.saved_tensors
+
+        dq, dk, dv = jagged_flash_attention_basic_backward(
+            jagged_Q=jagged_Q,
+            jagged_K=jagged_K,
+            jagged_V=jagged_V,
+            jagged_O=jagged_O,
+            offsets=offsets,
+            dO=grad_output,
+            lse=lse,
+            max_seq_len=ctx.max_seq_len,
+            use_mask=ctx.use_mask,
+            allow_tf32=ctx.allow_tf32,
+        )
+
+        return (
+            dq,
+            dk,
+            dv,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def jagged_flash_attention_basic(
+    q_weights: torch.Tensor,
+    k_weights: torch.Tensor,
+    v_weights: torch.Tensor,
+    offsets: torch.Tensor,
+    max_seq_len: int,
+    use_mask: bool = False,
+    allow_tf32: bool = True,
+) -> torch.Tensor:
+    q_weights = expect_contiguous(q_weights)
+    k_weights = expect_contiguous(k_weights)
+    v_weights = expect_contiguous(v_weights)
+    jagged_offsets = expect_contiguous(offsets)
+
+    jagged_O = JaggedFlashAttentionBasic.apply(
+        q_weights,
+        k_weights,
+        v_weights,
+        jagged_offsets,
+        max_seq_len,
+        use_mask,
+        allow_tf32,
+    )
+
+    return jagged_O
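
For orientation, the snippet below is an illustrative usage sketch, not part of the wheel. It shows how the jagged_flash_attention_basic wrapper added above might be called; it assumes a CUDA device, that the module is importable under the path shown in the file list (fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py), and made-up tensor sizes. Q, K, and V are packed (total_length, D) tensors whose per-sequence boundaries are given by offsets.

# Illustrative sketch only -- not shipped in the wheel. Assumes a CUDA build of
# fbgemm-gpu-genai and the module path shown in the file list above.
import torch

from fbgemm_gpu.sll.triton.triton_jagged_flash_attention_basic import (
    jagged_flash_attention_basic,
)

# Two sequences packed back to back (lengths 3 and 5), head dimension 16.
offsets = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
total_len, D, max_seq_len = 8, 16, 8

q = torch.randn(total_len, D, device="cuda", requires_grad=True)
k = torch.randn(total_len, D, device="cuda", requires_grad=True)  # non-transposed K
v = torch.randn(total_len, D, device="cuda", requires_grad=True)

# Forward pass: output has the same packed (total_len, D) layout as q.
out = jagged_flash_attention_basic(q, k, v, offsets, max_seq_len)

# Backward pass dispatches to the Triton bwd kernels via the autograd.Function.
out.sum().backward()
print(out.shape, q.grad.shape)

Note that the kernel divides the attention weights by max_seq_len (see "p /= max_seq_len" above) and applies no 1/sqrt(D) scaling to QK, so the output is a scaled variant of standard softmax attention rather than a drop-in replacement.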