fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py
@@ -0,0 +1,553 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def array_jagged_bmm_kernel(
+     a_ptr,  # 1D array
+     b_ptr,  # jagged matrix
+     c_ptr,  # output, jagged matrix
+     a_offsets_ptr,
+     b_offsets_ptr,
+     c_offsets_ptr,
+     D,  # emb dimension
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     transpose,  # one if a is transposed, otherwise zero
+     max_seq_len,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     allow_tf32: tl.constexpr,
+ ):
+
+     pid_batch = tl.program_id(2)
+     pid_m = tl.program_id(1)
+     pid_n = tl.program_id(0)
+
+     batch_offset_am = tl.load(a_offsets_ptr + pid_batch)
+     batch_offset_bk = tl.load(b_offsets_ptr + pid_batch)
+     batch_offset_cm = tl.load(c_offsets_ptr + pid_batch)
+
+     # calculate M, N, K
+     batch_K = tl.load(b_offsets_ptr + pid_batch + 1) - batch_offset_bk  # b [batch_K, D]
+     batch_M = tl.load(c_offsets_ptr + pid_batch + 1) - batch_offset_cm
+
+     # use uncapped seq length to determine strides of a
+     stride_am = batch_M * (1 - transpose) + 1 * transpose
+     stride_ak = batch_M * transpose + 1 * (1 - transpose)
+
+     # truncate seq length
+     batch_K = tl.minimum(batch_K, max_seq_len)
+     batch_M = tl.minimum(batch_M, max_seq_len)
+
+     if batch_K == 0:
+         return
+
+     batch_N = D
+
+     # c [batch_M, D] boundary check
+     if pid_m * BLOCK_SIZE_M >= batch_M or pid_n * BLOCK_SIZE_N >= batch_N:
+         return
+
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % batch_M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % batch_N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = (
+         a_ptr
+         + batch_offset_am
+         + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     )
+     b_ptrs = (
+         b_ptr
+         + batch_offset_bk * stride_bk
+         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+     )
+
+     c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(batch_K, BLOCK_SIZE_K)):
+         a = tl.load(
+             a_ptrs, mask=offs_k[None, :] < batch_K - k * BLOCK_SIZE_K, other=0.0
+         )
+         b = tl.load(
+             b_ptrs, mask=offs_k[:, None] < batch_K - k * BLOCK_SIZE_K, other=0.0
+         )
+         c += tl.dot(a, b, allow_tf32=allow_tf32)
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = (
+         c_ptr
+         + stride_cm * batch_offset_cm
+         + stride_cm * offs_cm[:, None]
+         + stride_cn * offs_cn[None, :]
+     )
+     c_mask = (offs_cm[:, None] < batch_M) & (offs_cn[None, :] < batch_N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ @triton.jit
+ def jagged_jagged_bmm_jagged_out_kernel(
+     a_ptr,
+     a_offset_ptr,
+     b_ptr,
+     b_offset_ptr,
+     c_ptr,
+     offsets_mn_ptr,
+     max_seq_len,
+     num_blocks_n,
+     K,
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     allow_tf32: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+ ):
+     """
+     Kernel for computing C = A x B.
+     A has shape (sum_B(Mi), K), B has shape (K, sum_B(Ni))
+     and C has shape (sum_B(Mi * Ni))
+     """
+
+     pid = tl.program_id(axis=0)
+     pid_batch = tl.program_id(axis=1)
+
+     begin_a = tl.load(a_offset_ptr + pid_batch)
+     end_a = tl.load(a_offset_ptr + pid_batch + 1)
+
+     begin_b = tl.load(b_offset_ptr + pid_batch)
+     end_b = tl.load(b_offset_ptr + pid_batch + 1)
+
+     offset_mn = tl.load(offsets_mn_ptr + pid_batch)
+
+     M = end_a - begin_a
+     M = tl.minimum(M, max_seq_len)
+
+     N = end_b - begin_b
+     N = tl.minimum(N, max_seq_len)
+
+     pid_m = pid // num_blocks_n
+     pid_n = pid % num_blocks_n
+
+     if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
+         return
+
+     offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+     a_ptrs = (
+         a_ptr
+         + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+         + begin_a * stride_am
+     )
+
+     b_ptrs = (
+         b_ptr
+         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+         + begin_b * stride_bn
+     )
+
+     c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+     for k in range(0, K, BLOCK_SIZE_K):
+         updated_offset = k + offs_k
+         a = tl.load(
+             a_ptrs,
+             # pyre-fixme[16]: `int` has no attribute `__getitem__`.
+             mask=((updated_offset[None, :] < K) & (offs_am[:, None] < M)),
+             other=0.0,
+         )
+         b = tl.load(
+             b_ptrs,
+             mask=((updated_offset[:, None] < K) & (offs_bn[None, :] < N)),
+             other=0.0,
+         )
+         c += tl.dot(a, b, allow_tf32=allow_tf32)
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + offset_mn + N * offs_cm[:, None] + offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ def triton_array_jagged_bmm_jagged_out(
+     array_A,
+     jagged_B,
+     lengths_am,
+     lengths_bk,
+     lengths_cm,
+     offsets_am,
+     offsets_bk,
+     offsets_cm,
+     max_seq_len,
+     allow_tf32=False,
+     transpose=0,  # one if a is transposed, otherwise zero
+ ):
+     B = lengths_am.size(0)
+     D = jagged_B.size(1)
+     L = jagged_B.size(0)
+     # gradients of the emb vectors beyond max_seq_len are set to zero
+     jagged_C = torch.zeros((L, D), device=jagged_B.device, dtype=jagged_B.dtype)
+
+     BLOCK_SIZE_M = 32
+     BLOCK_SIZE_N = 32
+     BLOCK_SIZE_K = 32
+
+     num_blocks_m = triton.cdiv(max_seq_len, BLOCK_SIZE_M)
+     num_blocks_n = triton.cdiv(D, BLOCK_SIZE_N)
+     grid = (num_blocks_n, num_blocks_m, B)
+
+     array_jagged_bmm_kernel[grid](
+         array_A,
+         jagged_B,
+         jagged_C,
+         offsets_am,
+         offsets_bk,
+         offsets_cm,
+         D,
+         jagged_B.stride(0),
+         jagged_B.stride(1),
+         jagged_C.stride(0),
+         jagged_C.stride(1),
+         transpose,
+         max_seq_len,
+         BLOCK_SIZE_M,
+         BLOCK_SIZE_N,
+         BLOCK_SIZE_K,
+         allow_tf32,
+     )
+
+     return jagged_C
+
+
+ def triton_jagged_jagged_bmm_jagged_out(
+     jagged_A,
+     jagged_B,
+     max_seq_len,
+     lengths_m,
+     lengths_n,
+     lengths_mn,
+     offsets_m,
+     offsets_n,
+     offsets_mn,
+     allow_tf32=False,
+ ):
+     assert jagged_A.size(1) == jagged_B.size(0), "incompatible dimensions"
+     assert offsets_mn.is_contiguous(), "mn offsets must be contiguous"
+     assert offsets_m.is_contiguous(), "m offsets must be contiguous"
+     assert offsets_n.is_contiguous(), "n offsets must be contiguous"
+
+     B = lengths_m.size(0)
+     jagged_C = torch.zeros(
+         (lengths_mn.sum()), device=jagged_A.device, dtype=jagged_A.dtype
+     )
+
+     BLOCK_SIZE_M = 32
+     BLOCK_SIZE_N = 32
+     BLOCK_SIZE_K = 32
+
+     num_blocks_m = triton.cdiv(max_seq_len, BLOCK_SIZE_M)
+     num_blocks_n = triton.cdiv(max_seq_len, BLOCK_SIZE_N)
+     grid = (num_blocks_m * num_blocks_n, B)
+
+     jagged_jagged_bmm_jagged_out_kernel[grid](
+         jagged_A,
+         offsets_m,
+         jagged_B,
+         offsets_n,
+         jagged_C,
+         offsets_mn,
+         max_seq_len,
+         num_blocks_n,
+         jagged_A.size(1),
+         jagged_A.stride(0),
+         jagged_A.stride(1),
+         jagged_B.stride(0),
+         jagged_B.stride(1),
+         allow_tf32,
+         BLOCK_SIZE_M,
+         BLOCK_SIZE_N,
+         BLOCK_SIZE_K,
+     )
+
+     return jagged_C
+
+
+ class ArrayJaggedBmmNopadding(torch.autograd.Function):
+     """
+     Compute batch matrix multiplication between JaggedTensor and JaggedTensor without padding.
+     z = X * Y
+     x: [sum_B(N_i * N_i)]
+     y: [sum_B(N_i), D]
+     z: [sum_B(N_i), D]
+     """
+
+     @staticmethod
+     # pyre-fixme
+     def forward(
+         ctx,
+         x: torch.Tensor,
+         y: torch.Tensor,
+         x_lengths: torch.Tensor,
+         x_offsets: torch.Tensor,
+         y_lengths: torch.Tensor,
+         y_offsets: torch.Tensor,
+         z_lengths: torch.Tensor,
+         z_offsets: torch.Tensor,
+         max_seq_len: int,
+         allow_tf32,
+     ):
+         ctx.allow_tf32 = allow_tf32
+         ctx.max_seq_len = max_seq_len
+
+         ctx.save_for_backward(
+             x,
+             y,
+             x_lengths,
+             y_lengths,
+             z_lengths,
+             x_offsets,
+             y_offsets,
+             z_offsets,
+         )
+
+         return triton_array_jagged_bmm_jagged_out(
+             x,
+             y,
+             x_lengths,
+             y_lengths,
+             z_lengths,
+             x_offsets,
+             y_offsets,
+             z_offsets,
+             max_seq_len,
+             allow_tf32,
+             0,
+         )
+
+     @staticmethod
+     # pyre-fixme
+     def backward(ctx, grad_output: torch.Tensor):
+         """
+         z = X * Y
+         dX = dZ * YT
+         dY = XT * dZ
+
+         dZ: [sum_B(N_i), D]
+         YT: [D, sum_B(N_i)], i.e. Y.T
+         XT: X transposed
+         Z: [sum_B(N_i), D]
+         """
+
+         (
+             x,
+             y,
+             x_lengths,
+             y_lengths,
+             z_lengths,
+             x_offsets,
+             y_offsets,
+             z_offsets,
+         ) = ctx.saved_tensors
+
+         grad_x = triton_jagged_jagged_bmm_jagged_out(
+             grad_output,
+             y.T,
+             ctx.max_seq_len,
+             z_lengths,
+             y_lengths,
+             x_lengths,
+             z_offsets,
+             y_offsets,
+             x_offsets,
+             ctx.allow_tf32,
+         )
+
+         grad_y = triton_array_jagged_bmm_jagged_out(
+             x,
+             grad_output,
+             x_lengths,
+             y_lengths,
+             z_lengths,
+             x_offsets,
+             y_offsets,
+             z_offsets,
+             ctx.max_seq_len,
+             ctx.allow_tf32,
+             1,
+         )
+         return grad_x, grad_y, None, None, None, None, None, None, None, None
+
+
+ class JaggedJaggedBmmNoPadding(torch.autograd.Function):
+     """
+     Compute batch matrix multiplication between JaggedTensor and JaggedTensor without padding.
+     z = x x y^T
+     x: [sum_B(M_i), D]
+     y: [sum_B(N_i), D]
+     z: [sum_B(M_i * N_i)], assuming M_i = N_i
+     """
+
+     @staticmethod
+     # pyre-fixme
+     def forward(
+         ctx,
+         x: torch.Tensor,
+         y: torch.Tensor,
+         x_lengths: torch.Tensor,
+         x_offsets: torch.Tensor,
+         y_lengths: torch.Tensor,
+         y_offsets: torch.Tensor,
+         z_lengths: torch.Tensor,
+         z_offsets: torch.Tensor,
+         max_seq_len: int,
+         allow_tf32,
+     ):
+         ctx.allow_tf32 = allow_tf32
+         ctx.max_seq_len = max_seq_len
+
+         ctx.save_for_backward(
+             x,
+             y,
+             x_lengths,
+             y_lengths,
+             z_lengths,
+             x_offsets,
+             y_offsets,
+             z_offsets,
+         )
+
+         return triton_jagged_jagged_bmm_jagged_out(
+             x,
+             y.T,
+             max_seq_len,
+             x_lengths,
+             y_lengths,
+             z_lengths,
+             x_offsets,
+             y_offsets,
+             z_offsets,
+             allow_tf32,
+         )
+
+     @staticmethod
+     # pyre-fixme
+     def backward(ctx, grad_output: torch.Tensor):
+         """
+         z = x x y^T
+         x: [sum_B(M_i), D]
+         y: [sum_B(N_i), D]
+         z: [sum_B(M_i * N_i)], assuming M_i = N_i
+         dx = dz x (y^T)^T => dx = dz x y
+         d(y^T) = x^T x dz => dy = dz^T x x
+         """
+         (
+             x,
+             y,
+             x_lengths,
+             y_lengths,
+             z_lengths,
+             x_offsets,
+             y_offsets,
+             z_offsets,
+         ) = ctx.saved_tensors
+
+         grad_x = triton_array_jagged_bmm_jagged_out(
+             grad_output,
+             y,
+             z_lengths,
+             y_lengths,
+             x_lengths,
+             z_offsets,
+             y_offsets,
+             x_offsets,
+             ctx.max_seq_len,
+             ctx.allow_tf32,
+             transpose=0,
+         )
+         grad_y = triton_array_jagged_bmm_jagged_out(
+             grad_output,
+             x,
+             z_lengths,
+             x_lengths,
+             y_lengths,
+             z_offsets,
+             x_offsets,
+             y_offsets,
+             ctx.max_seq_len,
+             ctx.allow_tf32,
+             transpose=1,
+         )
+         return grad_x, grad_y, None, None, None, None, None, None, None, None
+
+
+ def array_jagged_bmm_jagged_out(
+     x: torch.Tensor,
+     y: torch.Tensor,
+     x_lengths: torch.Tensor,
+     x_offsets: torch.Tensor,
+     y_lengths: torch.Tensor,
+     y_offsets: torch.Tensor,
+     z_lengths: torch.Tensor,
+     z_offsets: torch.Tensor,
+     max_seq_len: int,
+     allow_tf32: bool = True,
+ ):
+     return ArrayJaggedBmmNopadding.apply(
+         x,
+         y,
+         x_lengths,
+         x_offsets,
+         y_lengths,
+         y_offsets,
+         z_lengths,
+         z_offsets,
+         max_seq_len,
+         allow_tf32,
+     )
+
+
+ def jagged_jagged_bmm_jagged_out(
+     x: torch.Tensor,
+     y: torch.Tensor,
+     x_lengths: torch.Tensor,
+     x_offsets: torch.Tensor,
+     y_lengths: torch.Tensor,
+     y_offsets: torch.Tensor,
+     z_lengths: torch.Tensor,
+     z_offsets: torch.Tensor,
+     max_seq_len: int,
+     allow_tf32: bool = True,
+ ):
+     return JaggedJaggedBmmNoPadding.apply(
+         x,
+         y,
+         x_lengths,
+         x_offsets,
+         y_lengths,
+         y_offsets,
+         z_lengths,
+         z_offsets,
+         max_seq_len,
+         allow_tf32,
+     )
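
Below is a minimal usage sketch for the wrappers defined above. It is not part of the wheel: the import path is inferred from file-list entry 61, the offsets helper and all tensor names, sizes, and device choices are illustrative assumptions, and a CUDA device with Triton available is assumed.

import torch

from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # path assumed from the file list
    jagged_jagged_bmm_jagged_out,
)


def complete_cumsum(lengths: torch.Tensor) -> torch.Tensor:
    # Offsets [0, l_0, l_0 + l_1, ...] built from per-batch lengths (hypothetical helper).
    return torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])


D = 8                                               # embedding dimension (illustrative)
x_lengths = torch.tensor([2, 3], device="cuda")     # M_i per batch
y_lengths = torch.tensor([2, 3], device="cuda")     # N_i per batch
z_lengths = x_lengths * y_lengths                   # each output block is M_i x N_i
x_offsets = complete_cumsum(x_lengths)
y_offsets = complete_cumsum(y_lengths)
z_offsets = complete_cumsum(z_lengths)
max_seq_len = int(torch.maximum(x_lengths.max(), y_lengths.max()))

x = torch.randn(int(x_lengths.sum()), D, device="cuda")  # [sum_B(M_i), D]
y = torch.randn(int(y_lengths.sum()), D, device="cuda")  # [sum_B(N_i), D]

# z is a flat tensor of length sum_B(M_i * N_i); batch b's M_i x N_i block starts at z_offsets[b].
z = jagged_jagged_bmm_jagged_out(
    x, y, x_lengths, x_offsets, y_lengths, y_offsets,
    z_lengths, z_offsets, max_seq_len, allow_tf32=False,
)

array_jagged_bmm_jagged_out takes the same argument order, with x holding the flattened per-batch N_i x N_i blocks and the output having shape [sum_B(N_i), D], per the docstring of ArrayJaggedBmmNopadding.
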
fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py
@@ -0,0 +1,52 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ import torch
+
+ from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
+     dense_to_jagged,
+     jagged_to_dense,
+ )
+
+
+ class JaggedDenseAdd(torch.autograd.Function):
+     @staticmethod
+     # pyre-fixme
+     def forward(
+         ctx, x: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor, max_seq_len: int
+     ):
+         ctx.save_for_backward(x_offsets)
+         ctx.max_seq_len = max_seq_len
+         # TODO: what should the correct behavior be when the jagged values are longer than max_seq_len?
+         # The current behavior is to not truncate the jagged values;
+         # the same applies to grad_output in backward.
+         return dense_to_jagged(
+             y, [x_offsets], operation_function="add", operation_jagged_values=x
+         )[0]
+
+     @staticmethod
+     # pyre-fixme
+     def backward(ctx, grad_output: torch.Tensor):
+         (offsets,) = ctx.saved_tensors
+         grad_dense = jagged_to_dense(grad_output, [offsets], [ctx.max_seq_len])
+         return grad_output, None, grad_dense, None
+
+
+ def jagged_dense_elementwise_add(
+     x: torch.Tensor,
+     x_offsets: torch.Tensor,
+     y: torch.Tensor,
+     max_seq_len: int,
+     use_fbgemm_kernel: bool = True,
+ ):
+     if use_fbgemm_kernel:
+         return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
+             x, [x_offsets], y
+         )[0]
+     else:
+         return JaggedDenseAdd.apply(x, x_offsets, y, max_seq_len)
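
A minimal usage sketch for jagged_dense_elementwise_add follows. It is not part of the wheel: the import path is inferred from file-list entry 62, the shapes and names are illustrative, and the torch.ops.fbgemm path assumes the wheel's fbgemm operators are registered on import.

import torch

from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import (  # path assumed
    jagged_dense_elementwise_add,
)

lengths = torch.tensor([2, 3], device="cuda")                     # rows per batch
offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])    # [0, 2, 5]
max_seq_len = int(lengths.max())
D = 4

x = torch.randn(int(lengths.sum()), D, device="cuda")             # jagged values, [sum_B(L_i), D]
y = torch.randn(lengths.numel(), max_seq_len, D, device="cuda")   # dense counterpart, [B, max_seq_len, D]

# Jagged output with the same layout as x; padded rows of y beyond each L_i do not contribute.
out = jagged_dense_elementwise_add(x, offsets, y, max_seq_len)
# Or route through the Triton-backed autograd path defined above:
out_triton = jagged_dense_elementwise_add(x, offsets, y, max_seq_len, use_fbgemm_kernel=False)
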