fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.

Potentially problematic release.

This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/cpu/cpu_sll.py
@@ -0,0 +1,1001 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import Any

import torch


def cpu_jagged_jagged_bmm_kernel(
    x: torch.Tensor, y: torch.Tensor, x_offsets: torch.Tensor, max_seq_len: int
) -> torch.Tensor:
    assert x.size(1) == y.size(0), "incompatible dimensions"
    B = x_offsets.size(0) - 1
    D, _ = x.size()
    _, T = y.size()
    z = torch.empty((B, D, T), dtype=x.dtype, device=x.device)

    for b in range(B):
        z[b, :, :] = torch.mm(
            x[:, x_offsets[b] : x_offsets[b + 1]],
            y[x_offsets[b] : x_offsets[b + 1], :],
        )
    return z


def cpu_jagged_dense_bmm_kernel(
    x: torch.Tensor, y: torch.Tensor, x_offsets: torch.Tensor, max_seq_len: int
) -> torch.Tensor:
    assert x.size(1) == y.size(1), "incompatible dimensions"
    B = x_offsets.size(0) - 1
    z = torch.zeros((x.size(0), y.size(2)), dtype=x.dtype, device=x.device)

    for b in range(B):
        z[x_offsets[b] : x_offsets[b + 1], :] = torch.mm(
            x[x_offsets[b] : x_offsets[b + 1], :], y[b, :, :]
        )
    return z


class JaggedDenseBmmCPU(torch.autograd.Function):
    """
    Compute batch matrix multiplication between JaggedTensor and dense tensor
    dense: [B, N, D] * [B, D, T] = [B, N, T]
    jagged: [Sum_B, D] * [B, D, T] = [Sum_B, T]
    """

    @staticmethod
    # pyre-fixme
    def forward(
        ctx: Any,  # pyre-ignore
        x: torch.Tensor,
        y: torch.Tensor,
        x_offsets: torch.Tensor,
        N: int,
    ) -> torch.Tensor:
        ctx.save_for_backward(x, y, x_offsets)
        ctx.N = N
        return cpu_jagged_dense_bmm_kernel(x, y, x_offsets, N)

    @staticmethod
    # pyre-fixme
    def backward(
        ctx: Any, grad_output: torch.Tensor  # pyre-ignore
    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
        """
        # X = [Sum_B, D]
        # Y = [B, D, T]
        # Z = X * Y = [Sum_B, T]
        # dX = dZ * YT # [Sum_B, T] * [B, T, D] = [Sum_B, D]
        # dY = XT * dZ # [D, sum_B] * [sum_B, T] = [D, B, T]
        """
        (x, y, x_offsets) = ctx.saved_tensors
        N = ctx.N
        grad_x = cpu_jagged_dense_bmm_kernel(
            grad_output, y.permute(0, 2, 1), x_offsets, N
        )
        grad_y = cpu_jagged_jagged_bmm_kernel(x.T, grad_output, x_offsets, N)
        return grad_x, grad_y, None, None, None


def cpu_jagged_dense_bmm(
    x: torch.Tensor,
    y: torch.Tensor,
    x_offsets: torch.Tensor,
    N: int,
    allow_tf32: bool,
    use_fbgemm_kernel: bool = True,
) -> torch.Tensor:
    """
    Compute batch matrix multiplication between JaggedTensor and dense tensor
    dense: [B, N, D] * [B, D, T] = [B, N, T]
    jagged: [Sum_B, D] * [B, D, T] = [Sum_B, T]
    """

    # Force the CPU backend to use fbgemm kernel as it has better performance
    use_fbgemm_kernel = True
    if use_fbgemm_kernel:
        return torch.ops.fbgemm.jagged_dense_bmm(x, x_offsets, y, N)[0]
    else:
        return JaggedDenseBmmCPU.apply(x, y, x_offsets, N)


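Usage sketch (illustrative, not part of the packaged file): the shapes below are made up and exercise the pure-PyTorch fallback kernel defined above. Calling cpu_jagged_dense_bmm itself additionally requires the fbgemm ops to be loaded, since it forwards to torch.ops.fbgemm.jagged_dense_bmm.

    import torch

    # B = 2 bags with lengths 3 and 2 (Sum_B = 5), D = 4, T = 6
    x = torch.randn(5, 4)                # jagged values, [Sum_B, D]
    x_offsets = torch.tensor([0, 3, 5])  # [B + 1]
    y = torch.randn(2, 4, 6)             # dense, [B, D, T]

    z = cpu_jagged_dense_bmm_kernel(x, y, x_offsets, max_seq_len=3)
    print(z.shape)                       # torch.Size([5, 6]) -> [Sum_B, T]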
class JaggedJaggedBmm(torch.autograd.Function):
    """
    Compute batch matrix multiplication between JaggedTensor and Jagged Tensor
    dense: [B, D, N] * [B, N, T] = [B, D, T]
    jagged: [Sum_B, D].T * [Sum_B, T] = [B, D, T]
    """

    @staticmethod
    # pyre-fixme
    def forward(
        ctx: Any,  # pyre-ignore
        x: torch.Tensor,
        y: torch.Tensor,
        x_offsets: torch.Tensor,
        N: int,
    ) -> torch.Tensor:
        ctx.save_for_backward(x, y, x_offsets)
        ctx.N = N
        return cpu_jagged_jagged_bmm_kernel(x.T, y, x_offsets, N)

    @staticmethod
    # pyre-fixme
    def backward(
        ctx: Any, grad_output: torch.Tensor  # pyre-ignore
    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
        """
        # X = [Sum_B, D]
        # Y = [Sum_B, T]
        # Z = XT * Y = [B, D, T]
        # dXT = dZ * YT -> dX = Y * dZT
        # dY = X * dZ -> X * dZ
        """
        (x, y, offsets) = ctx.saved_tensors
        N = ctx.N
        grad_x = cpu_jagged_dense_bmm_kernel(
            y, grad_output.permute(0, 2, 1), offsets, N
        )
        grad_y = cpu_jagged_dense_bmm_kernel(x, grad_output, offsets, N)
        return grad_x, grad_y, None, None, None


def cpu_jagged_jagged_bmm(
    x: torch.Tensor,
    y: torch.Tensor,
    x_offsets: torch.Tensor,
    N: int,
    allow_tf32: bool,
    use_fbgemm_kernel: bool = True,
) -> torch.Tensor:
    """
    Compute batch matrix multiplication between JaggedTensor and Jagged Tensor
    dense: [B, D, N] * [B, N, T] = [B, D, T]
    jagged: [Sum_B, D].T * [Sum_B, T] = [B, D, T]
    """

    # Force the CPU backend to use fbgemm kernel as it has better performance
    use_fbgemm_kernel = True
    if use_fbgemm_kernel:
        return torch.ops.fbgemm.jagged_jagged_bmm(x, y, x_offsets, N)
    else:
        return JaggedJaggedBmm.apply(x, y, x_offsets, N)


def cpu_dense_jagged_cat_jagged_out(
    a: torch.Tensor,
    b: torch.Tensor,
    b_offsets: torch.Tensor,
    max_seq_len: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert a.size(0) == b_offsets.size(0) - 1
    c = torch.empty(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
    c_offsets = b_offsets + torch.arange(
        b_offsets.size(0), dtype=torch.int64, device=a.device
    )
    lengths = torch.diff(b_offsets)
    c = torch.cat(
        [
            (
                torch.cat((a[i : i + 1], b[b_offsets[i] : b_offsets[i + 1]]), dim=-1)
                if lengths[i] > 0
                else a[i : i + 1]
            )
            for i in range(a.size(0))
        ],
        dim=-1,
    )
    return c, c_offsets


def cpu_jagged_self_substraction_jagged_out(
    jagged_A: torch.Tensor,
    offsets_a: torch.Tensor,
    offsets_b: torch.Tensor,
    max_seq_len: int,
) -> torch.Tensor:
    jagged_B = torch.empty(
        (int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype
    )
    for i in range(len(offsets_a) - 1):
        if offsets_a[i + 1] - offsets_a[i] == 1:
            continue

        a = jagged_A[offsets_a[i] : offsets_a[i + 1]]
        jagged_B[offsets_b[i] : offsets_b[i + 1]] = (
            a[:-1].unsqueeze(1) - a[1:].unsqueeze(0)
        ).flatten()
    return jagged_B


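Usage sketch (illustrative, not part of the packaged file): for each bag of length n in jagged_A, the output holds the (n - 1) x (n - 1) matrix of pairwise differences a[:-1] - a[1:], flattened, so offsets_b must advance by (n - 1) ** 2 per bag.

    import torch

    jagged_A = torch.tensor([1.0, 4.0, 6.0, 2.0, 5.0])  # bags of length 3 and 2
    offsets_a = torch.tensor([0, 3, 5])
    offsets_b = torch.tensor([0, 4, 5])                  # (3 - 1) ** 2 = 4, (2 - 1) ** 2 = 1

    out = cpu_jagged_self_substraction_jagged_out(jagged_A, offsets_a, offsets_b, max_seq_len=3)
    # bag 0 -> [-3., -5., 0., -2.], bag 1 -> [-3.]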
def cpu_jagged2_to_padded_dense(
    values: torch.Tensor,
    offsets: torch.Tensor,
    max_length: int,
    padding_value: float = 0.0,
) -> torch.Tensor:
    """
    values: jagged tensor with size [sum(Ni * Ni)]
    offsets: offsets for jagged tensor, with size [B + 1]
    max_length: maximum sequence length in the batch
    padding_value: value to use for padding
    return padded dense tensor of size [B, N, N]
    """
    B = offsets.size(0) - 1
    dense_output = torch.full(
        (B, max_length, max_length),
        padding_value,
        dtype=values.dtype,
        device=values.device,
    )
    for b in range(B):
        begin = offsets[b]
        end = offsets[b + 1]
        Ni = int(torch.sqrt(end - begin))
        if Ni == 0:
            continue
        dense_output[b, 0:Ni, 0:Ni] = values[begin:end].view(Ni, Ni)

    return dense_output


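Usage sketch (illustrative, not part of the packaged file): each bag of values is an Ni x Ni block stored flat; the block is copied into the top-left corner of a padded [max_length, max_length] slice.

    import torch

    values = torch.arange(13, dtype=torch.float32)  # a 2x2 block (4 values) then a 3x3 block (9 values)
    offsets = torch.tensor([0, 4, 13])

    dense = cpu_jagged2_to_padded_dense(values, offsets, max_length=3, padding_value=0.0)
    print(dense.shape)  # torch.Size([2, 3, 3]); block 0 fills dense[0, :2, :2]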
class CPUJaggedDenseElementwiseMul(torch.autograd.Function):
    # NOTE: CPU, GPU, CUDA versions all have their own autograd.Function implementations,
    # ideally we should use one autograd.Function for all of them and do the dispatching
    # inside the autograd.Function.
    """
    Compute elementwise multiplication between jagged tensor and dense tensor.
    z = x * y
    x: [sum_B(L_i)]
    y: dense tensor
    z: [sum_B(L_i)]
    """

    @staticmethod
    def jagged_dense_elementwise_mul_jagged_out(
        jagged: torch.Tensor,
        dense: torch.Tensor,
        seq_lengths: torch.Tensor,
        offsets: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        out = torch.empty_like(jagged)
        for i in range(seq_lengths.size(0)):
            if seq_lengths[i] == 0:
                continue
            a = jagged[offsets[i] : offsets[i + 1]]
            a = a.view(int(seq_lengths[i]), int(seq_lengths[i]))
            out[offsets[i] : offsets[i + 1]] = (
                a * dense[0 : seq_lengths[i], 0 : seq_lengths[i]]
            ).flatten()
        return out

    @staticmethod
    # pyre-fixme
    def forward(
        ctx,  # pyre-ignore [2]
        x: torch.Tensor,
        y: torch.Tensor,
        x_seq_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        max_seq_len: int,
    ):
        ctx.max_seq_len = max_seq_len

        ctx.save_for_backward(
            x,
            y,
            x_seq_lengths,
            x_offsets,
        )

        return CPUJaggedDenseElementwiseMul.jagged_dense_elementwise_mul_jagged_out(
            x,
            y,
            x_seq_lengths,
            x_offsets,
            max_seq_len,
        )

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        (
            x,
            y,
            x_seq_lengths,
            x_offsets,
        ) = ctx.saved_tensors

        grad_x = CPUJaggedDenseElementwiseMul.jagged_dense_elementwise_mul_jagged_out(
            grad_output,
            y,
            x_seq_lengths,
            x_offsets,
            ctx.max_seq_len,
        )

        return grad_x, None, None, None, None


def cpu_jagged_dense_elementwise_mul_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_seq_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    max_seq_len: int,
) -> torch.Tensor:
    return CPUJaggedDenseElementwiseMul.apply(
        x,
        y,
        x_seq_lengths,
        x_offsets,
        max_seq_len,
    )


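Usage sketch (illustrative, not part of the packaged file): x stores one Ni x Ni block per bag, flattened; each block is multiplied elementwise by the top-left Ni x Ni corner of the dense tensor y (a lower-triangular mask here). This path is pure PyTorch, so it runs without the compiled fbgemm ops.

    import torch

    x = torch.randn(13)                   # a 2x2 block and a 3x3 block, flattened
    x_seq_lengths = torch.tensor([2, 3])
    x_offsets = torch.tensor([0, 4, 13])
    y = torch.ones(3, 3).tril()           # dense mask; only the top-left Ni x Ni corner is used

    z = cpu_jagged_dense_elementwise_mul_jagged_out(x, y, x_seq_lengths, x_offsets, max_seq_len=3)
    print(z.shape)                        # torch.Size([13]); same jagged layout as x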
class JaggedSoftmaxCPU(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        ctx: Any,  # pyre-ignore
        x: torch.Tensor,
        x_offsets: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        """
        input shape is [SUM_B, D]
        output shape is [SUM_B, D]
        """
        B = x_offsets.size(0) - 1
        y = torch.zeros(x.size(), device=x.device, dtype=x.dtype)

        for b in range(B):
            y[x_offsets[b] : x_offsets[b + 1], :] = torch.nn.functional.softmax(
                x[x_offsets[b] : x_offsets[b + 1], :], dim=0
            )

        ctx.save_for_backward(y, x_offsets)

        return y

    @staticmethod
    # pyre-fixme
    def backward(
        ctx: Any, grad_output: torch.Tensor  # pyre-ignore
    ) -> tuple[torch.Tensor, None, None]:
        y, x_offsets = ctx.saved_tensors

        B = x_offsets.size(0) - 1
        grad = torch.zeros(y.size(), device=y.device, dtype=y.dtype)

        for b in range(B):
            curr_y = y[x_offsets[b] : x_offsets[b + 1]]
            curr_grad = grad_output[x_offsets[b] : x_offsets[b + 1]]
            grad[x_offsets[b] : x_offsets[b + 1]] = curr_y * (
                curr_grad - torch.sum(curr_grad * curr_y, dim=0, keepdim=True)
            )

        return grad, None, None


def cpu_jagged_softmax(
    x: torch.Tensor,
    x_offsets: torch.Tensor,
    max_seq_len: int,
    use_fbgemm_kernel: bool = True,
) -> torch.Tensor:
    """
    CPU version of jagged softmax: [sum(softmax([B_i, D]))]
    """
    # Force the CPU backend to use fbgemm kernel as it has better performance
    use_fbgemm_kernel = True
    if use_fbgemm_kernel:
        return torch.ops.fbgemm.jagged_softmax(x, x_offsets, max_seq_len)[0]
    else:
        return JaggedSoftmaxCPU.apply(x, x_offsets, max_seq_len)


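Usage sketch (illustrative, not part of the packaged file): the wrapper above forwards to torch.ops.fbgemm.jagged_softmax, so this example drives the pure-PyTorch autograd.Function directly; the softmax is taken over the rows of each bag, per feature column.

    import torch

    x = torch.randn(5, 4, requires_grad=True)   # [Sum_B, D], bags of length 3 and 2
    x_offsets = torch.tensor([0, 3, 5])

    y = JaggedSoftmaxCPU.apply(x, x_offsets, 3)
    y.sum().backward()                          # exercises the hand-written backward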
class Jagged2SoftmaxCPU(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        x: torch.Tensor,
        x_offsets: torch.Tensor,
        row_offsets: torch.Tensor,
        head_offsets: torch.Tensor,
        max_seq_len_row: int,
        max_seq_len_head: int,
        transpose: bool = True,
    ) -> torch.Tensor:
        B = x_offsets.size(0) - 1
        y = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)

        for i in range(B):
            submatrix = x[x_offsets[i] : x_offsets[i + 1]]
            Ni = int(row_offsets[i + 1] - row_offsets[i])
            softmax_dim = 0 if transpose else 1
            y[x_offsets[i] : x_offsets[i + 1]] = torch.nn.functional.softmax(
                submatrix.reshape((Ni, Ni)), dim=softmax_dim
            ).view(-1)

        ctx.save_for_backward(y, x_offsets, row_offsets, head_offsets)
        ctx.max_seq_len_row = max_seq_len_row
        ctx.max_seq_len_head = max_seq_len_head
        ctx.transpose = transpose

        return y

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        y, x_offsets, row_offsets, head_offsets = ctx.saved_tensors
        B = x_offsets.size(0) - 1
        transpose = ctx.transpose
        softmax_dim = 0 if transpose else -1
        grad = torch.zeros(y.size(0), device=y.device, dtype=y.dtype)

        for i in range(B):
            Ni = row_offsets[i + 1] - row_offsets[i]
            curr_y = y[x_offsets[i] : x_offsets[i + 1]].reshape(Ni, Ni)
            curr_grad = grad_output[x_offsets[i] : x_offsets[i + 1]].reshape(Ni, Ni)
            grad[x_offsets[i] : x_offsets[i + 1]] = (
                curr_y
                * (
                    curr_grad
                    - torch.sum(curr_grad * curr_y, dim=softmax_dim, keepdim=True)
                )
            ).view(-1)

        return grad, None, None, None, None, None, None


def cpu_jagged2_softmax(
    x: torch.Tensor,
    offsets: torch.Tensor,
    offsets_total: torch.Tensor,
    max_seq_len: int,
    transpose: bool,
) -> torch.Tensor:
    """
    CPU version of jagged2 softmax: [sum(softmax([B_i, B_i]))]
    """
    return Jagged2SoftmaxCPU.apply(
        x,
        offsets_total,
        offsets,
        offsets,
        max_seq_len,
        max_seq_len,
        transpose,
    )


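Usage sketch (illustrative, not part of the packaged file): cpu_jagged2_softmax is pure PyTorch; note that it takes both the per-row offsets and the offsets into the flattened Ni * Ni values, and internally passes offsets_total as x_offsets to Jagged2SoftmaxCPU.

    import torch

    x = torch.randn(13)                        # a 2x2 score block and a 3x3 score block, flattened
    offsets = torch.tensor([0, 2, 5])          # per-sequence row offsets (lengths 2 and 3)
    offsets_total = torch.tensor([0, 4, 13])   # offsets into the flattened Ni * Ni values

    y = cpu_jagged2_softmax(x, offsets, offsets_total, max_seq_len=3, transpose=False)
    # each Ni x Ni block is softmaxed along dim=1 (dim=0 when transpose=True)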
# pyre-fixme[3]: Return type must be annotated.
def cpu_jagged_jagged_bmm_jagged_out_kernel(
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_A,
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_B,
    # pyre-fixme[2]: Parameter must be annotated.
    max_seq_len,
    # pyre-fixme[2]: Parameter must be annotated.
    lengths_m,
    # pyre-fixme[2]: Parameter must be annotated.
    lengths_n,
    # pyre-fixme[2]: Parameter must be annotated.
    lengths_mn,
    # pyre-fixme[2]: Parameter must be annotated.
    offsets_m,
    # pyre-fixme[2]: Parameter must be annotated.
    offsets_n,
    # pyre-fixme[2]: Parameter must be annotated.
    offsets_mn,
    # pyre-fixme[2]: Parameter must be annotated.
    allow_tf32=False,
):
    jagged_C = torch.empty((int(lengths_mn.sum().item())), dtype=jagged_A.dtype).to(
        jagged_A.device
    )
    B = lengths_m.size(0)

    for i in range(B):
        jagged_C[offsets_mn[i] : offsets_mn[i + 1]] = torch.matmul(
            jagged_A[offsets_m[i] : offsets_m[i + 1]],
            jagged_B[offsets_n[i] : offsets_n[i + 1]].T,
        ).flatten()
    return jagged_C


# pyre-fixme[3]: Return type must be annotated.
def cpu_array_jagged_bmm_jagged_out_kernel(
    # pyre-fixme[2]: Parameter must be annotated.
    array_A,
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_B,
    # pyre-fixme[2]: Parameter must be annotated.
    lengths_am,
    # pyre-fixme[2]: Parameter must be annotated.
    lengths_bk,
    # pyre-fixme[2]: Parameter must be annotated.
    lengths_cm,
    # pyre-fixme[2]: Parameter must be annotated.
    offsets_am,
    # pyre-fixme[2]: Parameter must be annotated.
    offsets_bk,
    # pyre-fixme[2]: Parameter must be annotated.
    offsets_cm,
    # pyre-fixme[2]: Parameter must be annotated.
    max_seq_len,
    # pyre-fixme[2]: Parameter must be annotated.
    allow_tf32=False,
    # pyre-fixme[2]: Parameter must be annotated.
    transpose=0,  # one if a is transposed, otherwise zero
):
    B = lengths_am.size(0)
    D = jagged_B.size(1)
    jagged_C = torch.zeros(
        (int(lengths_cm.sum()), D), device=jagged_B.device, dtype=jagged_B.dtype
    )

    for i in range(B):
        seq_len = int(lengths_bk[i])
        capped_seq_len = min(seq_len, max_seq_len)
        a = array_A[offsets_am[i] : offsets_am[i + 1]].view(seq_len, seq_len)
        a = a[:capped_seq_len, :capped_seq_len]

        if transpose:
            a = a.T
        b = jagged_B[offsets_bk[i] : offsets_bk[i] + capped_seq_len]
        jagged_C[offsets_cm[i] : offsets_cm[i] + capped_seq_len] = torch.matmul(a, b)

    return jagged_C


class ArrayJaggedBmmNopaddingCPU(torch.autograd.Function):
    """
    Compute batch matrix multiplication between JaggedTensor and JaggedTensor without padding.
    z = X * Y
    x: [Sum_B(N_i, N_i)]
    y: [sum_B(N_i), D]
    z: [sum_B(N_i), D]
    """

    @staticmethod
    # pyre-fixme
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        x: torch.Tensor,
        y: torch.Tensor,
        x_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        y_lengths: torch.Tensor,
        y_offsets: torch.Tensor,
        z_lengths: torch.Tensor,
        z_offsets: torch.Tensor,
        max_seq_len: int,
        # pyre-fixme[2]: Parameter must be annotated.
        allow_tf32,
    ):
        ctx.allow_tf32 = allow_tf32
        ctx.max_seq_len = max_seq_len

        ctx.save_for_backward(
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        )

        return cpu_array_jagged_bmm_jagged_out_kernel(
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
            max_seq_len,
            allow_tf32,
            0,
        )

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        """
        z = X * Y
        dX = dZ * YT
        dY = XT * dZ

        dZ: [sum_B(N_i), D]
        YT: [D, sum_B(N_i)] call Y.T
        XT: transposed
        Z: [sum_B(N_i), D]
        """

        (
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        ) = ctx.saved_tensors

        grad_x = cpu_jagged_jagged_bmm_jagged_out_kernel(
            grad_output,
            y,
            ctx.max_seq_len,
            z_lengths,
            y_lengths,
            x_lengths,
            z_offsets,
            y_offsets,
            x_offsets,
            ctx.allow_tf32,
        )

        grad_y = cpu_array_jagged_bmm_jagged_out_kernel(
            x,
            grad_output,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
            ctx.max_seq_len,
            ctx.allow_tf32,
            1,
        )
        return grad_x, grad_y, None, None, None, None, None, None, None, None


# pyre-fixme[3]: Return type must be annotated.
def cpu_array_jagged_bmm_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    y_lengths: torch.Tensor,
    y_offsets: torch.Tensor,
    z_lengths: torch.Tensor,
    z_offsets: torch.Tensor,
    max_seq_len: int,
    allow_tf32: bool = True,
):
    return ArrayJaggedBmmNopaddingCPU.apply(
        x,
        y,
        x_lengths,
        x_offsets,
        y_lengths,
        y_offsets,
        z_lengths,
        z_offsets,
        max_seq_len,
        allow_tf32,
    )


class JaggedJaggedBmmNoPaddingCPU(torch.autograd.Function):
    """
    Compute batch matrix multiplication between JaggedTensor and JaggedTensor without padding.
    z = x x y^T
    x: [sum_B(M_i), D]
    y: [sum_B(N_i), D]
    z: [sum_B(M_i * N_i)], assuming M_i = N_i
    """

    @staticmethod
    # pyre-fixme
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        x: torch.Tensor,
        y: torch.Tensor,
        x_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        y_lengths: torch.Tensor,
        y_offsets: torch.Tensor,
        z_lengths: torch.Tensor,
        z_offsets: torch.Tensor,
        max_seq_len: int,
        # pyre-fixme[2]: Parameter must be annotated.
        allow_tf32,
    ):
        ctx.allow_tf32 = allow_tf32
        ctx.max_seq_len = max_seq_len

        ctx.save_for_backward(
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        )

        return cpu_jagged_jagged_bmm_jagged_out_kernel(
            x,
            y,
            max_seq_len,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
            allow_tf32,
        )

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        """
        z = x x y^T
        x: [sum_B(M_i), D]
        y: [sum_B(N_i), D]
        z: [sum_B(M_i * N_i)], assuming M_i = N_i
        dx = dz x (y^T)^T => dx = dz x y
        d(y^T) = x^T x dz => dy = dz^T x x
        """
        (
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        ) = ctx.saved_tensors

        grad_x = cpu_array_jagged_bmm_jagged_out_kernel(
            grad_output,
            y,
            z_lengths,
            y_lengths,
            x_lengths,
            z_offsets,
            y_offsets,
            x_offsets,
            ctx.max_seq_len,
            ctx.allow_tf32,
            transpose=0,
        )
        grad_y = cpu_array_jagged_bmm_jagged_out_kernel(
            grad_output,
            x,
            z_lengths,
            x_lengths,
            y_lengths,
            z_offsets,
            x_offsets,
            y_offsets,
            ctx.max_seq_len,
            ctx.allow_tf32,
            transpose=1,
        )
        return grad_x, grad_y, None, None, None, None, None, None, None, None


# pyre-fixme[3]: Return type must be annotated.
def cpu_jagged_jagged_bmm_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    y_lengths: torch.Tensor,
    y_offsets: torch.Tensor,
    z_lengths: torch.Tensor,
    z_offsets: torch.Tensor,
    max_seq_len: int,
    allow_tf32: bool = True,
):
    return JaggedJaggedBmmNoPaddingCPU.apply(
        x,
        y,
        x_lengths,
        x_offsets,
        y_lengths,
        y_offsets,
        z_lengths,
        z_offsets,
        max_seq_len,
        allow_tf32,
    )


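Usage sketch (illustrative, not part of the packaged file): the lengths/offsets triples describe the jagged layouts of the two inputs and of the flattened M_i x N_i output blocks. The kernel below is pure PyTorch, so the bookkeeping can be checked without the registered sll ops.

    import torch

    D = 4
    x = torch.randn(5, D)                 # [sum_B(M_i), D], bags of 2 and 3 rows
    y = torch.randn(5, D)                 # [sum_B(N_i), D], same lengths here
    lengths = torch.tensor([2, 3])
    offsets = torch.tensor([0, 2, 5])
    lengths_mn = lengths * lengths        # [4, 9]
    offsets_mn = torch.tensor([0, 4, 13])

    z = cpu_jagged_jagged_bmm_jagged_out_kernel(
        x, y, 3, lengths, lengths, lengths_mn, offsets, offsets, offsets_mn
    )
    print(z.shape)                        # torch.Size([13]) -> [sum_B(M_i * N_i)]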
def cpu_jagged_flash_attention_basic(
    q_weights: torch.Tensor,
    k_weights: torch.Tensor,
    v_weights: torch.Tensor,
    offsets: torch.Tensor,
    max_seq_len: int,
    use_mask: bool = False,
    allow_tf32: bool = True,
) -> torch.Tensor:
    num_objects = offsets[1:] - offsets[0:-1:1]
    attn_lengths = num_objects * num_objects
    attn_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(attn_lengths)

    s = torch.ops.fbgemm.sll_jagged_jagged_bmm_jagged_out(
        x=q_weights,
        y=k_weights,  # transpose is done inside the function
        x_lengths=num_objects,
        x_offsets=offsets,
        y_lengths=num_objects,
        y_offsets=offsets,
        z_lengths=attn_lengths,
        z_offsets=attn_offsets,
        max_seq_len=max_seq_len,
        allow_tf32=allow_tf32,
    )

    p = (
        torch.ops.fbgemm.sll_jagged2_softmax(
            x=s,
            offsets=offsets,
            offsets_total=attn_offsets,
            max_seq_len=max_seq_len,
            transpose=False,
        )
        / max_seq_len
    )

    if use_mask:
        attn_mask = torch.triu(
            torch.ones(
                (max_seq_len, max_seq_len),
                dtype=torch.bool,
                device=q_weights.device,
            ),
        ).requires_grad_(False)
        # p = p * attn_mask
        p = torch.ops.fbgemm.sll_jagged_dense_elementwise_mul_jagged_out(
            x=p,
            y=attn_mask,
            x_seq_lengths=num_objects,
            x_offsets=attn_offsets,
            max_seq_len=max_seq_len,
        )

    jagged_O = torch.ops.fbgemm.sll_array_jagged_bmm_jagged_out(
        x=p,
        y=v_weights,
        x_lengths=attn_lengths,
        x_offsets=attn_offsets,
        y_lengths=num_objects,
        y_offsets=offsets,
        z_lengths=num_objects,
        z_offsets=offsets,
        max_seq_len=max_seq_len,
        allow_tf32=allow_tf32,
    )

    return jagged_O


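Usage sketch (illustrative, not part of the packaged file): this function is built entirely from torch.ops.fbgemm.sll_* ops, so it assumes those ops have been registered first (which this wheel appears to set up in fbgemm_gpu/sll/__init__.py); the shapes below are made up.

    import torch
    # assumption: `import fbgemm_gpu.sll` (or equivalent) has registered the sll ops

    D, max_seq_len = 8, 4
    offsets = torch.tensor([0, 3, 7])     # two sequences, lengths 3 and 4 (Sum_B = 7)
    q = torch.randn(7, D)                 # [Sum_B, D]
    k = torch.randn(7, D)
    v = torch.randn(7, D)

    out = cpu_jagged_flash_attention_basic(q, k, v, offsets, max_seq_len)
    print(out.shape)                      # torch.Size([7, 8]) -> [Sum_B, D]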
class JaggedDenseAddCPU(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        ctx: Any,  # pyre-ignore
        x: torch.Tensor,
        x_offsets: torch.Tensor,
        y: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        ctx.save_for_backward(x_offsets)
        ctx.max_seq_len = max_seq_len
        # TODO: what should be the correct behavior when jagged values has length > max seq len?
        # current behavior is to not truncate jagged values
        # similar for backward grad_output
        padded_x = torch.ops.fbgemm.jagged_to_padded_dense(
            x,
            [x_offsets],
            max_lengths=[max_seq_len],
            padding_value=0.0,
        )  # [B, max_seq_len, D]
        return torch.ops.fbgemm.dense_to_jagged(padded_x + y, [x_offsets])[0]

    @staticmethod
    # pyre-fixme
    def backward(
        ctx,  # pyre-ignore
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, None, torch.Tensor, None]:
        (offsets,) = ctx.saved_tensors
        grad_dense = torch.ops.fbgemm.jagged_to_padded_dense(
            grad_output, [offsets], [ctx.max_seq_len]
        )
        return grad_output, None, grad_dense, None


def cpu_jagged_dense_elementwise_add(
    x: torch.Tensor,
    x_offsets: torch.Tensor,
    y: torch.Tensor,
    max_seq_len: int,
    use_fbgemm_kernel: bool = True,
) -> torch.Tensor:
    # Force the CPU backend to use fbgemm kernel as it has better performance
    use_fbgemm_kernel = True
    if use_fbgemm_kernel:
        return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
            x, [x_offsets], y
        )[0]
    else:
        return JaggedDenseAddCPU.apply(x, x_offsets, y, max_seq_len)


def cpu_jagged_dense_flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_bias: torch.Tensor,
    offsets: torch.Tensor,
    max_seq_len: int,
    allow_tf32: bool = True,
) -> torch.Tensor:
    """
    q: jagged tensor, [sum_B, D]
    k: dense tensor, [B, D, T]
    v: jagged tensor [sum_B, D]
    attn_bias: dense tensor [B, N, T]
    offsets: offsets for jagged tensor [B + 1]
    """

    # [sum_B, D] * [B, D, T] = [sum_B, T]
    qk = torch.ops.fbgemm.sll_jagged_dense_bmm(
        q,
        k.to(q.dtype),
        offsets,
        max_seq_len,
        allow_tf32=allow_tf32,
        use_fbgemm_kernel=True,
    )

    softmax_input = torch.ops.fbgemm.sll_jagged_dense_elementwise_add(
        qk,
        offsets,
        attn_bias,
        max_seq_len,
        use_fbgemm_kernel=True,
    )

    normed_attn_weights = torch.ops.fbgemm.sll_jagged_softmax(
        softmax_input,
        offsets,
        max_seq_len,
        use_fbgemm_kernel=True,
    )  # [sum_B, T]

    # [sum_B, T] * [sum_B, D] = [B, T, D]
    return torch.ops.fbgemm.sll_jagged_jagged_bmm(
        normed_attn_weights,
        v.to(normed_attn_weights.dtype),
        offsets,
        max_seq_len,
        allow_tf32=allow_tf32,
        use_fbgemm_kernel=True,
    )
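Usage sketch (illustrative, not part of the packaged file): the composition above also goes through the registered sll ops, so the same registration assumption applies; the shapes follow the docstring, with the output laid out as [B, T, D].

    import torch
    # assumption: the torch.ops.fbgemm.sll_* ops from this wheel are registered

    B, D, T, max_seq_len = 2, 8, 4, 4
    offsets = torch.tensor([0, 3, 7])           # Sum_B = 7
    q = torch.randn(7, D)                       # jagged, [Sum_B, D]
    k = torch.randn(B, D, T)                    # dense,  [B, D, T]
    v = torch.randn(7, D)                       # jagged, [Sum_B, D]
    attn_bias = torch.randn(B, max_seq_len, T)  # dense,  [B, N, T]

    out = cpu_jagged_dense_flash_attention(q, k, v, attn_bias, offsets, max_seq_len)
    print(out.shape)                            # torch.Size([2, 4, 8]) -> [B, T, D]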