fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Note: this release of fbgemm-gpu-genai-nightly has been flagged as potentially problematic.

Files changed (127):
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/meta/__init__.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from fbgemm_gpu.sll.meta.meta_sll import (  # noqa F401
+    meta_array_jagged_bmm_jagged_out,
+    meta_jagged2_softmax,
+    meta_jagged_dense_elementwise_mul_jagged_out,
+    meta_jagged_jagged_bmm_jagged_out,
+    meta_jagged_self_substraction_jagged_out,
+)
+
+# pyre-ignore[5]
+op_registrations = {
+    "sll_jagged_self_substraction_jagged_out": {
+        "Meta": meta_jagged_self_substraction_jagged_out,
+    },
+    "sll_jagged_dense_elementwise_mul_jagged_out": {
+        "Meta": meta_jagged_dense_elementwise_mul_jagged_out,
+    },
+    "sll_jagged2_softmax": {
+        "AutogradMeta": meta_jagged2_softmax,
+    },
+    "sll_array_jagged_bmm_jagged_out": {
+        "AutogradMeta": meta_array_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_jagged_bmm_jagged_out": {
+        "AutogradMeta": meta_jagged_jagged_bmm_jagged_out,
+    },
+}
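
In the table above, each SLL op name maps to a dispatch key ("Meta" for shape-only kernels, "AutogradMeta" for ops that also need an autograd wrapper during tracing) and the Python function implementing it. A minimal sketch of how such a table could be consumed via torch.library (a hypothetical loader under an assumed "fbgemm" namespace; the package's actual registration logic lives in fbgemm_gpu/sll/__init__.py):

import torch

# Hypothetical: attach each impl to an already-defined "fbgemm::<op_name>" schema.
lib = torch.library.Library("fbgemm", "FRAGMENT")
for op_name, dispatch_map in op_registrations.items():
    for dispatch_key, impl_fn in dispatch_map.items():
        lib.impl(op_name, impl_fn, dispatch_key)
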
fbgemm_gpu/sll/meta/meta_sll.py
@@ -0,0 +1,337 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import torch
+
+
+def meta_jagged_self_substraction_jagged_out(
+    jagged_A: torch.Tensor,
+    offsets_a: torch.Tensor,
+    offsets_b: torch.Tensor,
+    max_seq_len: int,
+) -> torch.Tensor:
+    return torch.empty(
+        [torch.library.get_ctx().new_dynamic_size()],
+        dtype=jagged_A.dtype,
+        device=jagged_A.device,
+    )
+
+
+class MetaJaggedDenseElementwiseMul(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx,  # pyre-ignore [2]
+        x: torch.Tensor,
+        y: torch.Tensor,
+        x_seq_lengths: torch.Tensor,
+        x_offsets: torch.Tensor,
+        max_seq_len: int,
+    ) -> torch.Tensor:
+        ctx.max_seq_len = max_seq_len
+
+        ctx.save_for_backward(
+            x,
+            y,
+            x_seq_lengths,
+            x_offsets,
+        )
+
+        total_L = x.size(0)
+        jagged_C = torch.zeros((total_L), device=x.device, dtype=x.dtype)
+
+        return jagged_C
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        (
+            x,
+            y,
+            x_seq_lengths,
+            x_offsets,
+        ) = ctx.saved_tensors
+
+        total_L = grad_output.size(0)
+        jagged_C = torch.zeros(
+            (total_L), device=grad_output.device, dtype=grad_output.dtype
+        )
+
+        return jagged_C, None, None, None, None
+
+
+def meta_jagged_dense_elementwise_mul_jagged_out(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_seq_lengths: torch.Tensor,
+    x_offsets: torch.Tensor,
+    max_seq_len: int,
+) -> torch.Tensor:
+    return MetaJaggedDenseElementwiseMul.apply(
+        x,
+        y,
+        x_seq_lengths,
+        x_offsets,
+        max_seq_len,
+    )
+
+
+class Jagged2SoftmaxMeta(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        # pyre-fixme[2]: Parameter must be annotated.
+        ctx,
+        x: torch.Tensor,
+        x_offsets: torch.Tensor,
+        row_offsets: torch.Tensor,
+        head_offsets: torch.Tensor,
+        max_seq_len_row: int,
+        max_seq_len_head: int,
+        transpose: bool = True,
+    ) -> torch.Tensor:
+        y = torch.rand(x.size(0), device=x.device, dtype=x.dtype)
+
+        ctx.save_for_backward(y, x_offsets, row_offsets, head_offsets)
+        ctx.max_seq_len_row = max_seq_len_row
+        ctx.max_seq_len_head = max_seq_len_head
+
+        return y
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        y, x_offsets, row_offsets, head_offsets = ctx.saved_tensors
+        grad = torch.rand(y.size(), device=y.device, dtype=y.dtype)
+
+        return grad, None, None, None, None, None, None
+
+
+def meta_jagged2_softmax(
+    x: torch.Tensor,
+    offsets: torch.Tensor,
+    offsets_total: torch.Tensor,
+    max_seq_len: int,
+    transpose: bool,
+) -> torch.Tensor:
+    """
+    Meta version of jagged2 softmax: [sum(softmax([B_i, B_i]))]
+    """
+    return Jagged2SoftmaxMeta.apply(
+        x,
+        offsets_total,
+        offsets,
+        offsets,
+        max_seq_len,
+        max_seq_len,
+        transpose,
+    )
+
+
+class ArrayJaggedBmmNopadding(torch.autograd.Function):
+    """
+    Compute batch matrix multiplication between JaggedTensor and JaggedTensor without padding.
+    z = X * Y
+    x: [Sum_B(N_i, N_i)]
+    y: [sum_B(N_i), D]
+    z: [sum_B(N_i), D]
+    """
+
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        # pyre-fixme[2]: Parameter must be annotated.
+        ctx,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        x_lengths: torch.Tensor,
+        x_offsets: torch.Tensor,
+        y_lengths: torch.Tensor,
+        y_offsets: torch.Tensor,
+        z_lengths: torch.Tensor,
+        z_offsets: torch.Tensor,
+        max_seq_len: int,
+        # pyre-fixme[2]: Parameter must be annotated.
+        allow_tf32,
+    ) -> torch.Tensor:
+        ctx.allow_tf32 = allow_tf32
+        ctx.max_seq_len = max_seq_len
+
+        ctx.save_for_backward(
+            x,
+            y,
+            x_lengths,
+            y_lengths,
+            z_lengths,
+            x_offsets,
+            y_offsets,
+            z_offsets,
+        )
+
+        D = y.size(1)
+        L = y.size(0)
+        # gradients of the emb vectors beyond max_seq_len is set to zeros
+        jagged_C = torch.zeros((L, D), device=y.device, dtype=y.dtype)
+        return jagged_C
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        """
+        z = X * Y
+        dX = dZ * YT
+        dY = XT * dZ
+
+        dZ: [sum_B(N_i), D]
+        YT: [D, sum_B(N_i)] call Y.T
+        XT: transposed
+        Z: [sum_B(N_i), D]
+        """
+
+        (
+            x,
+            y,
+            x_lengths,
+            y_lengths,
+            z_lengths,
+            x_offsets,
+            y_offsets,
+            z_offsets,
+        ) = ctx.saved_tensors
+
+        grad_x = torch.zeros(
+            (x.size()), device=grad_output.device, dtype=grad_output.dtype
+        )
+
+        # gradients of the emb vectors beyond max_seq_len is set to zeros
+        grad_y = torch.zeros(
+            y.size(), device=grad_output.device, dtype=grad_output.dtype
+        )
+        return (
+            grad_x,
+            grad_y,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+# pyre-fixme[3]: Return type must be annotated.
+def meta_array_jagged_bmm_jagged_out(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_lengths: torch.Tensor,
+    x_offsets: torch.Tensor,
+    y_lengths: torch.Tensor,
+    y_offsets: torch.Tensor,
+    z_lengths: torch.Tensor,
+    z_offsets: torch.Tensor,
+    max_seq_len: int,
+    allow_tf32: bool = True,
+):
+    return ArrayJaggedBmmNopadding.apply(
+        x,
+        y,
+        x_lengths,
+        x_offsets,
+        y_lengths,
+        y_offsets,
+        z_lengths,
+        z_offsets,
+        max_seq_len,
+        allow_tf32,
+    )
+
+
+class JaggedJaggedBmmNoPaddingMeta(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        # pyre-fixme[2]: Parameter must be annotated.
+        ctx,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        x_lengths: torch.Tensor,
+        x_offsets: torch.Tensor,
+        y_lengths: torch.Tensor,
+        y_offsets: torch.Tensor,
+        z_lengths: torch.Tensor,
+        z_offsets: torch.Tensor,
+        max_seq_len: int,
+        # pyre-fixme[2]: Parameter must be annotated.
+        allow_tf32,
+    ):
+        assert x.size(1) == y.size(0), "incompatible dimensions"
+
+        ctx.allow_tf32 = allow_tf32
+        ctx.max_seq_len = max_seq_len
+
+        ctx.save_for_backward(
+            x,
+            y,
+            x_lengths,
+            y_lengths,
+            z_lengths,
+            x_offsets,
+            y_offsets,
+            z_offsets,
+        )
+
+        # pyre-fixme[6]: For 1st argument expected `Sequence[Union[int, SymInt]]`
+        # but got `Tensor`.
+        c = torch.rand((z_lengths.sum()), device=x.device, dtype=x.dtype)
+        return c
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        (
+            x,
+            y,
+            x_lengths,
+            y_lengths,
+            z_lengths,
+            x_offsets,
+            y_offsets,
+            z_offsets,
+        ) = ctx.saved_tensors
+
+        grad_x = torch.rand(x.size(), device=x.device, dtype=x.dtype)
+        grad_y = torch.rand(y.size(), device=y.device, dtype=y.dtype)
+        return grad_x, grad_y, None, None, None, None, None, None, None, None
+
+
+# pyre-fixme[3]: Return type must be annotated.
+def meta_jagged_jagged_bmm_jagged_out(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_lengths: torch.Tensor,
+    x_offsets: torch.Tensor,
+    y_lengths: torch.Tensor,
+    y_offsets: torch.Tensor,
+    z_lengths: torch.Tensor,
+    z_offsets: torch.Tensor,
+    max_seq_len: int,
+    allow_tf32: bool = True,
+):
+    return JaggedJaggedBmmNoPaddingMeta.apply(
+        x,
+        y,
+        x_lengths,
+        x_offsets,
+        y_lengths,
+        y_offsets,
+        z_lengths,
+        z_offsets,
+        max_seq_len,
+        allow_tf32,
+    )
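
The functions in this file are shape-only "meta" implementations: they allocate outputs with the right shape, dtype, and device but compute no values, which is exactly what PyTorch's meta device and compile-time tracing require. A minimal sketch of that behavior (the tensor shapes below are illustrative, not taken from the package):

import torch

from fbgemm_gpu.sll.meta.meta_sll import (
    meta_jagged_dense_elementwise_mul_jagged_out,
)

# Inputs on the meta device carry shapes only; no real memory is touched.
x = torch.empty(10, device="meta")            # jagged values, total_L = 10
y = torch.empty(4, 3, device="meta")          # dense operand (illustrative shape)
lengths = torch.empty(4, dtype=torch.int64, device="meta")
offsets = torch.empty(5, dtype=torch.int64, device="meta")

out = meta_jagged_dense_elementwise_mul_jagged_out(x, y, lengths, offsets, max_seq_len=3)
print(out.shape, out.device)  # torch.Size([10]) meta: shape propagated, nothing computed
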
fbgemm_gpu/sll/triton/__init__.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
+    dense_jagged_cat_jagged_out,
+)
+
+from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import (  # noqa F401
+    jagged2_to_padded_dense,
+    Jagged2ToPaddedDense,  # noqa F401
+)
+
+from fbgemm_gpu.sll.triton.triton_jagged_bmm import (  # noqa F401
+    jagged_dense_bmm,
+    jagged_jagged_bmm,
+    JaggedDenseBmm,  # noqa F401
+    JaggedJaggedBmm,  # noqa F401
+)
+
+from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # noqa F401
+    array_jagged_bmm_jagged_out,
+    ArrayJaggedBmmNopadding,  # noqa F401
+    jagged_jagged_bmm_jagged_out,
+    JaggedJaggedBmmNoPadding,  # noqa F401
+    triton_array_jagged_bmm_jagged_out,  # noqa F401
+    triton_jagged_jagged_bmm_jagged_out,  # noqa F401
+)
+
+from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import (  # noqa F401
+    jagged_dense_elementwise_add,
+    JaggedDenseAdd,  # noqa F401
+)
+
+from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import (  # noqa F401
+    jagged_dense_elementwise_mul_jagged_out,
+    JaggedDenseElementwiseMul,  # noqa F401
+)
+
+from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import (  # noqa F401
+    jagged_dense_flash_attention,
+    JaggedDenseFlashAttention,  # noqa F401
+)
+
+from fbgemm_gpu.sll.triton.triton_jagged_flash_attention_basic import (  # noqa F401
+    jagged_flash_attention_basic,
+    JaggedFlashAttentionBasic,  # noqa F401
+)
+
+from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import (
+    triton_jagged_self_substraction_jagged_out,
+)
+
+from fbgemm_gpu.sll.triton.triton_jagged_softmax import (  # noqa F401
+    jagged2_softmax,
+    Jagged2Softmax,  # noqa F401
+    jagged_softmax,
+    JaggedSoftmax,  # noqa F401
+)
+
+from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import (  # noqa F401
+    multi_head_jagged_flash_attention,
+    MultiHeadJaggedFlashAttention,  # noqa F401
+)
+
+# pyre-ignore[5]
+op_registrations = {
+    "sll_dense_jagged_cat_jagged_out": {
+        "CUDA": dense_jagged_cat_jagged_out,
+    },
+    "sll_jagged_dense_bmm": {
+        "CUDA": jagged_dense_bmm,
+        "AutogradCUDA": jagged_dense_bmm,
+    },
+    "sll_jagged_jagged_bmm": {
+        "CUDA": jagged_jagged_bmm,
+        "AutogradCUDA": jagged_jagged_bmm,
+    },
+    "sll_jagged2_to_padded_dense": {
+        "CUDA": jagged2_to_padded_dense,
+        "AutogradCUDA": jagged2_to_padded_dense,
+    },
+    "sll_array_jagged_bmm_jagged_out": {
+        "CUDA": array_jagged_bmm_jagged_out,
+        "AutogradCUDA": array_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_jagged_bmm_jagged_out": {
+        "CUDA": jagged_jagged_bmm_jagged_out,
+        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_softmax": {
+        "CUDA": jagged_softmax,
+        "AutogradCUDA": jagged_softmax,
+    },
+    "sll_jagged2_softmax": {
+        "CUDA": jagged2_softmax,
+        "AutogradCUDA": jagged2_softmax,
+    },
+    "sll_jagged_dense_elementwise_add": {
+        "CUDA": jagged_dense_elementwise_add,
+        "AutogradCUDA": jagged_dense_elementwise_add,
+    },
+    "sll_jagged_dense_flash_attention": {
+        "CUDA": jagged_dense_flash_attention,
+        "AutogradCUDA": jagged_dense_flash_attention,
+    },
+    "sll_jagged_flash_attention_basic": {
+        "CUDA": jagged_flash_attention_basic,
+        "AutogradCUDA": jagged_flash_attention_basic,
+    },
+    "sll_multi_head_jagged_flash_attention": {
+        "CUDA": multi_head_jagged_flash_attention,
+        "AutogradCUDA": multi_head_jagged_flash_attention,
+    },
+    "sll_jagged_self_substraction_jagged_out": {
+        "CUDA": triton_jagged_self_substraction_jagged_out,
+    },
+    "sll_jagged_dense_elementwise_mul_jagged_out": {
+        "CUDA": jagged_dense_elementwise_mul_jagged_out,
+        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
+    },
+}
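
This table mirrors op_registrations in fbgemm_gpu/sll/meta/__init__.py above: the same op names, now bound to CUDA and AutogradCUDA implementations. A hypothetical merge of the per-backend tables into a single op-to-dispatch map (a sketch only; the package's real wiring in fbgemm_gpu/sll/__init__.py may differ):

from collections import defaultdict
from typing import Callable, Dict

from fbgemm_gpu.sll.meta import op_registrations as meta_registrations
from fbgemm_gpu.sll.triton import op_registrations as triton_registrations

# Hypothetical merge: one {dispatch key -> impl} map per op name.
merged: Dict[str, Dict[str, Callable]] = defaultdict(dict)
for table in (meta_registrations, triton_registrations):
    for op_name, dispatch_map in table.items():
        merged[op_name].update(dispatch_map)

# e.g. merged["sll_jagged2_softmax"] now holds CUDA, AutogradCUDA, and AutogradMeta entries.
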
fbgemm_gpu/sll/triton/common.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+
+
+def next_power_of_two(N: int) -> int:
+    if N > 4096:
+        raise Exception(f"{N} is too large that is not supported yet")
+
+    if N > 2048:
+        return 4096
+    elif N > 1024:
+        return 2048
+    elif N > 512:
+        return 1024
+    elif N > 256:
+        return 512
+    elif N > 128:
+        return 256
+    elif N > 64:
+        return 128
+    elif N > 32:
+        return 64
+    else:
+        return 32
+
+
+def expect_contiguous(x: torch.Tensor) -> torch.Tensor:
+    if not x.is_contiguous():
+        return x.contiguous()
+    else:
+        return x
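
next_power_of_two rounds a requested Triton block size up through a fixed ladder, clamped to the 32-4096 range, and raises for anything larger. The ladder is equivalent to a bit-length one-liner floored at 32; a reference sketch (next_power_of_two_ref is not part of the package):

from fbgemm_gpu.sll.triton.common import next_power_of_two

def next_power_of_two_ref(n: int) -> int:
    # Round up to the next power of two, floored at 32 to match the ladder
    # above; no 4096 cap here, whereas the package version raises instead.
    return max(32, 1 << max(0, n - 1).bit_length())

assert all(next_power_of_two_ref(n) == next_power_of_two(n) for n in range(1, 4097))
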
fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def dense_jagged_cat_jagged_out_kernel(
+    a_ptr,  # dense
+    b_ptr,  # jagged
+    c_ptr,  # jagged
+    b_offsets_ptr,
+    c_offsets_ptr,
+    max_seq_len,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_batch = tl.program_id(0)
+    b_start = tl.load(b_offsets_ptr + pid_batch)
+    b_end = tl.load(b_offsets_ptr + pid_batch + 1)
+    c_start = b_start + pid_batch
+    N = b_end - b_start
+    N = tl.minimum(N, max_seq_len)
+
+    a = tl.load(a_ptr + pid_batch)
+    tl.store(c_ptr + c_start, a)
+
+    offs_k = tl.arange(0, BLOCK_SIZE)
+    for k in range(0, N, BLOCK_SIZE):
+        b_offset = k + offs_k
+        b_ptrs = b_ptr + b_start + b_offset
+        b = tl.load(b_ptrs, mask=b_offset < N, other=0.0)
+        tl.store(c_ptr + c_start + 1 + b_offset, b, mask=b_offset < N)
+    tl.store(c_offsets_ptr + pid_batch, b_start + pid_batch)
+
+
+def dense_jagged_cat_jagged_out(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    b_offsets: torch.Tensor,
+    max_seq_len: int,
+):
+    assert a.is_contiguous()
+    assert b.is_contiguous()
+    assert b_offsets.is_contiguous()
+    B = a.size(0)
+    BLOCK_SIZE = 128
+    c = torch.zeros(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
+    c_offsets = torch.empty(
+        b_offsets.size(0), dtype=b_offsets.dtype, device=b_offsets.device
+    )  # B + 1
+
+    dense_jagged_cat_jagged_out_kernel[(B,)](
+        a,
+        b,
+        c,
+        b_offsets,
+        c_offsets,
+        max_seq_len,
+        # pyre-fixme[6]: For 7th argument expected `constexpr` but got `int`.
+        BLOCK_SIZE,
+    )
+
+    c_offsets[-1] = b_offsets[-1] + B
+
+    return c, c_offsets
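
The kernel prepends one dense scalar per batch row (a[i]) to the corresponding jagged row of b, so every output row gains one element and c_offsets[i] = b_offsets[i] + i. A minimal eager-mode reference of the same layout, useful for checking the Triton path on small inputs (a sketch that ignores the max_seq_len truncation the kernel applies to rows longer than the cap):

import torch

def dense_jagged_cat_jagged_out_ref(a, b, b_offsets):
    rows = []
    for i in range(a.size(0)):
        start, end = int(b_offsets[i]), int(b_offsets[i + 1])
        rows.append(torch.cat([a[i : i + 1], b[start:end]]))  # prepend a[i]
    c = torch.cat(rows)
    c_offsets = b_offsets + torch.arange(
        b_offsets.size(0), dtype=b_offsets.dtype, device=b_offsets.device
    )
    return c, c_offsets

a = torch.tensor([10.0, 20.0])
b = torch.tensor([1.0, 2.0, 3.0])      # two jagged rows: [1, 2] and [3]
b_offsets = torch.tensor([0, 2, 3])
c, c_offsets = dense_jagged_cat_jagged_out_ref(a, b, b_offsets)
# c == [10, 1, 2, 20, 3]; c_offsets == [0, 3, 5]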