fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py
@@ -0,0 +1,740 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import get_fp8_constants
+
+
+# Function APIs
+def gather_scale_dense_tokens(
+    x: torch.Tensor,
+    token_indices: torch.Tensor,
+    expert_indices: torch.Tensor,
+    scores: torch.Tensor,
+    valid_token_count: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    Gather and scale dense tokens along 1D indices.
+
+    For each input token, token_indices[i] is the index of the token in the input sequence.
+    expert_indices[i] is the index of the expert that the token is assigned to.
+    scores[i] is the score of the token.
+
+    For each expert, the tokens assigned to this expert are gathered from the input sequence,
+    and each gathered token is then scaled element-wise by its score.
+
+    valid_token_count is an optional tensor that can be used to filter out some tokens.
+    If it is provided, the function will only consider the first valid_token_count tokens in the input sequence.
+
+    The function returns a tensor of shape (a, D), where a is the number of tokens and D is the input dimension.
+
+    Args:
+        x (torch.Tensor): input tensor of shape (T, D)
+        token_indices (torch.Tensor): token indices of shape (a,)
+        expert_indices (torch.Tensor): expert indices of shape (a,)
+        scores (torch.Tensor): scores of shape (T, E)
+        valid_token_count (torch.Tensor, optional): valid token count of shape (1,)
+
+    Returns:
+        torch.Tensor: output tensor of shape (a, D)
+    """
+    T, D = x.shape
+    E = scores.shape[1]
+    # a = K * T
+    a = token_indices.shape[0]
+
+    out = torch.empty((a, D), device=x.device, dtype=x.dtype)
+    if a == 0 or D == 0:
+        return out
+
+    assert x.is_contiguous()
+    assert token_indices.is_contiguous()
+    assert expert_indices.is_contiguous()
+
+    assert tuple(token_indices.shape) == (a,)
+    assert tuple(expert_indices.shape) == (a,)
+    assert tuple(scores.shape) == (T, E)
+
+    stride_t = scores.stride(0)
+    stride_e = scores.stride(1)
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    if a >= NUM_SMS:
+        BLOCK_D_OUTER = D
+        BLOCK_D_INNER = 1024
+        assert D % BLOCK_D_INNER == 0
+    else:
+        BLOCK_D_OUTER = 512
+        BLOCK_D_INNER = 256
+        assert D % BLOCK_D_OUTER == 0
+    grid = (a, D // BLOCK_D_OUTER)
+    _fbgemm_gather_scale_dense_tokens[grid](
+        out,
+        x,
+        token_indices,
+        expert_indices,
+        scores,
+        stride_t,
+        stride_e,
+        valid_token_count,
+        D,  # pyre-ignore
+        BLOCK_D_OUTER,  # pyre-ignore
+        BLOCK_D_INNER,  # pyre-ignore
+    )
+    return out
+
+
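# Minimal host-side usage sketch for gather_scale_dense_tokens. The shapes, the
# top-1 routing, and the tensors below are illustrative assumptions only; the
# module path follows the file layout listed above.
import torch
from fbgemm_gpu.experimental.gen_ai.moe.gather_scatter import gather_scale_dense_tokens

T, D, E = 8, 1024, 4  # tokens, hidden dim (multiple of 512, as asserted above), experts
x = torch.randn(T, D, device="cuda", dtype=torch.bfloat16)
scores = torch.rand(T, E, device="cuda", dtype=torch.bfloat16)
expert_indices = scores.argmax(dim=1).to(torch.int32).contiguous()  # top-1 expert per token
token_indices = torch.arange(T, device="cuda", dtype=torch.int32)   # flat (a,) = (K * T,) with K = 1
out = gather_scale_dense_tokens(x, token_indices, expert_indices, scores)
# out[i] == x[token_indices[i]] * scores[token_indices[i], expert_indices[i]]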
+def gather_scale_quant_dense_tokens(
+    x: torch.Tensor,
+    token_indices: torch.Tensor,
+    expert_indices: torch.Tensor,
+    scores: torch.Tensor,
+    scale_ub: Optional[torch.Tensor] = None,
+    valid_token_count: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Gather, scale, and quantize dense tokens along 1D indices.
+
+    For each input token, token_indices[i] is the index of the token in the input sequence.
+    expert_indices[i] is the index of the expert that the token is assigned to.
+    scores[i] is the score of the token.
+
+    For each expert, the tokens assigned to this expert are gathered from the input sequence,
+    each gathered token is then scaled element-wise by its score, and the result is quantized to FP8.
+
+    valid_token_count is an optional tensor that can be used to filter out some tokens.
+    If it is provided, the function will only consider the first valid_token_count tokens in the input sequence.
+
+    The function returns the quantized tokens of shape (a, D) and their row-wise scales of shape (a,),
+    where a is the number of tokens and D is the input dimension.
+
+    Args:
+        x (torch.Tensor): input tensor of shape (T, D)
+        token_indices (torch.Tensor): token indices of shape (a,)
+        expert_indices (torch.Tensor): expert indices of shape (a,)
+        scores (torch.Tensor): scores of shape (T, E)
+        scale_ub (torch.Tensor, optional): scale upper bound of shape (1,)
+        valid_token_count (torch.Tensor, optional): valid token count of shape (1,)
+
+    Returns:
+        torch.Tensor: FP8 output tensor of shape (a, D)
+        torch.Tensor: row-wise scales of shape (a,)
+    """
+    T, D = x.shape
+    E = scores.shape[1]
+    # a = K * T
+    a = token_indices.shape[0]
+
+    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+
+    assert x.is_contiguous()
+    assert token_indices.is_contiguous()
+    assert expert_indices.is_contiguous()
+
+    assert tuple(token_indices.shape) == (a,)
+    assert tuple(expert_indices.shape) == (a,)
+    assert tuple(scores.shape) == (T, E)
+
+    stride_t = scores.stride(0)
+    stride_e = scores.stride(1)
+
+    out = torch.empty((a, D), device="cuda", dtype=pt_dtype)
+    out_scale = torch.empty((a,), device="cuda", dtype=torch.float32)
+
+    grid = (a,)
+    _fbgemm_gather_scale_fp8_rowwise_quant_dense_tokens[grid](
+        out,
+        out_scale,
+        x,
+        token_indices,
+        expert_indices,
+        scores,
+        scale_ub,
+        stride_t,
+        stride_e,
+        valid_token_count,
+        D,
+        TL_FP8_DTYPE=tl_dtype,
+        MAX_FP8=max_fp8,
+        EPS=eps,
+        CLAMP_MAX=scale_ub is not None,
+    )
+    return out, out_scale
+
+
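# Usage sketch for the FP8 path, which also returns one float32 scale per output
# row. Shapes are illustrative; D is kept a multiple of 1024 so every autotuned
# BLOCK_D candidate (256/512/1024) divides it.
import torch
from fbgemm_gpu.experimental.gen_ai.moe.gather_scatter import gather_scale_quant_dense_tokens

T, D, E = 8, 1024, 4
x = torch.randn(T, D, device="cuda", dtype=torch.bfloat16)
scores = torch.rand(T, E, device="cuda", dtype=torch.bfloat16)
expert_indices = scores.argmax(dim=1).to(torch.int32).contiguous()
token_indices = torch.arange(T, device="cuda", dtype=torch.int32)
xq, x_scale = gather_scale_quant_dense_tokens(x, token_indices, expert_indices, scores)
# xq: FP8 tensor of shape (T, D); x_scale: float32 row-wise scales of shape (T,)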
+def scatter_add_dense_tokens(
+    out_tokens: torch.Tensor,  # [T, D]
+    in_tokens: torch.Tensor,  # [a, D]
+    token_indices: torch.Tensor,  # [a]
+    valid_token_count: Optional[torch.Tensor] = None,
+) -> None:
+    """
+    Scatter add dense tokens along 1D indices.
+
+    Args:
+        out_tokens (torch.Tensor): output tensor of shape (T, D)
+        in_tokens (torch.Tensor): input tensor of shape (a, D)
+        token_indices (torch.Tensor): token indices of shape (a,)
+        valid_token_count (torch.Tensor, optional): valid token count of shape (1,)
+
+    Returns:
+        None
+    """
+
+    assert torch.version.hip is not None or (
+        torch.version.cuda is not None and torch.version.cuda >= "12.4"
+    ), "Requires CUDA version 12.4 or later on Nvidia GPUs!"
+
+    assert in_tokens.is_contiguous()
+    assert token_indices.is_contiguous()
+    assert out_tokens.is_contiguous()
+
+    a, D = in_tokens.shape
+    if a == 0:
+        return
+    assert token_indices.shape == (a,)
+    assert out_tokens.ndim == 2 and out_tokens.shape[1] == D
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    if a >= NUM_SMS:
+        BLOCK_D_OUTER = D
+        BLOCK_D_INNER = 1024
+    else:
+        BLOCK_D_OUTER = 512
+        BLOCK_D_INNER = 256
+    while D % BLOCK_D_OUTER != 0:
+        BLOCK_D_OUTER //= 2
+    while D % BLOCK_D_INNER != 0:
+        BLOCK_D_INNER //= 2
+
+    grid = (a, D // BLOCK_D_OUTER)
+    _fbgemm_scatter_add_dense_tokens[grid](
+        out_tokens,
+        in_tokens,
+        token_indices,
+        valid_token_count,
+        D,  # pyre-ignore
+        BLOCK_D_OUTER,  # pyre-ignore
+        BLOCK_D_INNER,  # pyre-ignore
+    )
+
+
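# Usage sketch for scatter_add_dense_tokens: each row of in_tokens is added in
# place into out_tokens at the row named by token_indices. Shapes and dtypes are
# illustrative; float32 is used here to keep the atomic adds portable.
import torch
from fbgemm_gpu.experimental.gen_ai.moe.gather_scatter import scatter_add_dense_tokens

T, D = 8, 1024
expert_out = torch.randn(T, D, device="cuda", dtype=torch.float32)    # [a, D], rows in expert order
token_indices = torch.randperm(T, device="cuda").to(torch.int32)      # destination row per input row
combined = torch.zeros(T, D, device="cuda", dtype=torch.float32)      # [T, D], updated in place
scatter_add_dense_tokens(combined, expert_out, token_indices)
# combined[token_indices[i]] += expert_out[i]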
+def scatter_add_padded_tokens(
+    in_tokens: torch.Tensor,  # [EP, T_K, D]
+    token_counts: torch.Tensor,  # [E]
+    token_indices: torch.Tensor,  # [T_K]
+    out_tokens: torch.Tensor,  # [T, D]
+) -> None:
+    """
+    Scatter add valid tokens based on token counts metadata.
+
+    Args:
+        in_tokens (torch.Tensor): input tensor of shape (EP, T_K, D)
+        token_counts (torch.Tensor): token counts of shape (E,)
+        token_indices (torch.Tensor): token indices of shape (T_K,)
+        out_tokens (torch.Tensor): output tensor of shape (T, D)
+
+    Returns:
+        None
+    """
+    assert torch.version.hip is not None or (
+        torch.version.cuda is not None and torch.version.cuda >= "12.4"
+    ), "Requires CUDA version 12.4 or later on Nvidia GPUs!"
+
+    assert in_tokens.is_contiguous()
+    assert token_counts.is_contiguous()
+    assert token_indices.is_contiguous()
+    assert out_tokens.is_contiguous()
+
+    EP, T_K, D = in_tokens.shape
+    E = token_counts.shape[0]
+    assert tuple(token_indices.shape) == (T_K,)
+    assert T_K % out_tokens.shape[0] == 0 and out_tokens.shape[1] == D
+
+    def grid(META):
+        return (
+            E,
+            META["SPLIT_T"],
+        )
+
+    T_BUCKET_CAP = 16384
+    T_BUCKET = min(triton.next_power_of_2(T_K), T_BUCKET_CAP)
+    BLOCK_E = max(triton.next_power_of_2(E), 8)
+    _fbgemm_scatter_add_padded_tokens[grid](
+        in_tokens,
+        token_counts,
+        token_indices,
+        out_tokens,
+        EP,
+        E,
+        T_BUCKET,
+        T_K,
+        D,
+        BLOCK_E,
+    )
+
+
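# Usage sketch for scatter_add_padded_tokens with a single expert-parallel rank
# (EP = 1). For each expert, its token_counts[e] rows occupy a contiguous span of
# in_tokens[rank], offset by the counts of the preceding experts, and each row is
# atomically added into out_tokens at the row given by token_indices. All sizes
# below are illustrative; D is a multiple of 1024 to satisfy the autotuned BLOCK_D.
import torch
from fbgemm_gpu.experimental.gen_ai.moe.gather_scatter import scatter_add_padded_tokens

EP, E, T, D = 1, 2, 4, 1024
token_counts = torch.tensor([3, 1], device="cuda", dtype=torch.int32)         # per-expert rows, sums to T_K = 4
token_indices = torch.tensor([2, 0, 3, 1], device="cuda", dtype=torch.int32)  # destination row per slot
in_tokens = torch.randn(EP, T, D, device="cuda", dtype=torch.float32)         # [EP, T_K, D]
out_tokens = torch.zeros(T, D, device="cuda", dtype=torch.float32)            # [T, D]
scatter_add_padded_tokens(in_tokens, token_counts, token_indices, out_tokens)
# out_tokens[token_indices[j]] += in_tokens[0, j] for every occupied slot j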
+# Torch Custom Op Registrations
+_GATHER_SCALE_DENSE_TOKENS_OP_NAME = "fbgemm::gather_scale_dense_tokens"
+
+torch.library.define(
+    "fbgemm::gather_scale_dense_tokens",
+    "(Tensor x, Tensor token_indices, Tensor expert_indices, Tensor scores, Tensor? valid_token_count=None) -> Tensor",
+)
+
+
+@torch.library.impl(_GATHER_SCALE_DENSE_TOKENS_OP_NAME, "Meta")
+def gather_scale_dense_tokens_meta(
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    valid_token_count=None,
+):
+    D = x.shape[1]
+    a = token_indices.shape[0]
+    return x.new_empty((a, D))
+
+
+@torch.library.impl(_GATHER_SCALE_DENSE_TOKENS_OP_NAME, "CUDA")
+def gather_scale_dense_tokens_cuda(
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    valid_token_count=None,
+):
+    return gather_scale_dense_tokens(
+        x,
+        token_indices,
+        expert_indices,
+        scores,
+        valid_token_count,
+    )
+
+
+_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME = "fbgemm::gather_scale_quant_dense_tokens"
+
+torch.library.define(
+    "fbgemm::gather_scale_quant_dense_tokens",
+    "(Tensor x, Tensor token_indices, Tensor expert_indices, Tensor scores, Tensor? scale_ub=None, Tensor? valid_token_count=None) -> Tensor",
+)
+
+
+@torch.library.impl(_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME, "Meta")
+def gather_scale_quant_dense_tokens_meta(
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    scale_ub=None,
+    valid_token_count=None,
+):
+    D = x.shape[1]
+    a = token_indices.shape[0]
+    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+    return torch.empty((a, D), device=x.device, dtype=pt_dtype), torch.empty(
+        (a,), device=x.device, dtype=torch.float32
+    )
+
+
+@torch.library.impl(_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME, "CUDA")
+def gather_scale_quant_dense_tokens_cuda(
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    scale_ub=None,
+    valid_token_count=None,
+):
+    return gather_scale_quant_dense_tokens(
+        x,
+        token_indices,
+        expert_indices,
+        scores,
+        scale_ub,
+        valid_token_count,
+    )
+
+
+_SCATTER_ADD_DENSE_TOKENS_OP_NAME = "fbgemm::scatter_add_dense_tokens"
+
+torch.library.define(
+    "fbgemm::scatter_add_dense_tokens",
+    "(Tensor out_tokens, Tensor in_tokens, Tensor token_indices, Tensor? valid_token_count=None) -> None",
+)
+
+
+@torch.library.impl(_SCATTER_ADD_DENSE_TOKENS_OP_NAME, "Meta")
+def scatter_add_dense_tokens_meta(
+    out_tokens,
+    in_tokens,
+    token_indices,
+    valid_token_count=None,
+):
+    return None
+
+
+@torch.library.impl(_SCATTER_ADD_DENSE_TOKENS_OP_NAME, "CUDA")
+def scatter_add_dense_tokens_cuda(
+    out_tokens,
+    in_tokens,
+    token_indices,
+    valid_token_count=None,
+):
+    return scatter_add_dense_tokens(
+        out_tokens, in_tokens, token_indices, valid_token_count
+    )
+
+
+_SCATTER_ADD_PADDED_TOKENS_OP_NAME = "fbgemm::scatter_add_padded_tokens"
+
+torch.library.define(
+    "fbgemm::scatter_add_padded_tokens",
+    "(Tensor in_tokens, Tensor token_counts, Tensor token_indices, Tensor out_tokens) -> None",
+)
+
+
+@torch.library.impl(_SCATTER_ADD_PADDED_TOKENS_OP_NAME, "Meta")
+def scatter_add_padded_tokens_meta(
+    in_tokens,
+    token_counts,
+    token_indices,
+    out_tokens,
+):
+    return None
+
+
+@torch.library.impl(_SCATTER_ADD_PADDED_TOKENS_OP_NAME, "CUDA")
+def scatter_add_padded_tokens_cuda(
+    in_tokens,
+    token_counts,
+    token_indices,
+    out_tokens,
+):
+    return scatter_add_padded_tokens(
+        in_tokens,
+        token_counts,
+        token_indices,
+        out_tokens,
+    )
+
+
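# Because the ops above are registered through torch.library (with "Meta" impls
# for shape propagation), they can also be called via the dispatcher, e.g. under
# torch.compile. Tensors are reused from the earlier sketches; importing the
# module is what runs the registrations.
out = torch.ops.fbgemm.gather_scale_dense_tokens(x, token_indices, expert_indices, scores)
torch.ops.fbgemm.scatter_add_dense_tokens(combined, expert_out, token_indices)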
+# Kernel Implementations
+@triton.jit
+def _fbgemm_gather_scale_dense_tokens(
+    out,
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    stride_t,
+    stride_e,
+    valid_token_count,
+    D: tl.constexpr,
+    BLOCK_D_OUTER: tl.constexpr,
+    BLOCK_D_INNER: tl.constexpr,
+):
+    output_token_index = tl.program_id(0)
+    feature_offset = tl.program_id(1) * BLOCK_D_OUTER
+
+    if valid_token_count is not None:
+        valid_token_count = tl.load(
+            valid_token_count, None, eviction_policy="evict_last"
+        )
+        if output_token_index >= valid_token_count:
+            return
+
+    input_token_index = tl.load(
+        token_indices + output_token_index, None, eviction_policy="evict_last"
+    )
+    input_expert_index = tl.load(
+        expert_indices + output_token_index, None, eviction_policy="evict_last"
+    )
+
+    input_score = tl.load(
+        scores + input_token_index * stride_t + input_expert_index * stride_e,
+        None,
+        eviction_policy="evict_last",
+    ).to(tl.float32)
+
+    for _ in range(0, BLOCK_D_OUTER // BLOCK_D_INNER):
+        input_token_value = tl.load(
+            x
+            + input_token_index.to(tl.int64) * D
+            + feature_offset
+            + tl.arange(0, BLOCK_D_INNER)[:],
+            None,
+        ).to(tl.float32)
+        output_token_value = input_token_value * input_score
+
+        tl.store(
+            out
+            + output_token_index.to(tl.int64) * D
+            + feature_offset
+            + tl.arange(0, BLOCK_D_INNER)[:],
+            output_token_value,
+            None,
+        )
+        feature_offset += BLOCK_D_INNER
+
+
+@triton.jit
+def _fbgemm_scatter_add_dense_tokens(
+    out_tokens,
+    in_tokens,
+    token_indices,
+    valid_token_count,
+    D: tl.constexpr,
+    BLOCK_D_OUTER: tl.constexpr,
+    BLOCK_D_INNER: tl.constexpr,
+):
+    input_token_index = tl.program_id(0).to(tl.int64)
+    feature_offset = tl.program_id(1) * BLOCK_D_OUTER + tl.arange(0, BLOCK_D_INNER)[:]
+
+    if valid_token_count is not None:
+        valid_token_count = tl.load(
+            valid_token_count, None, eviction_policy="evict_last"
+        )
+        if input_token_index >= valid_token_count:
+            return
+
+    output_token_index = tl.load(
+        token_indices + input_token_index, None, eviction_policy="evict_last"
+    ).to(tl.int64)
+
+    for _ in range(0, BLOCK_D_OUTER // BLOCK_D_INNER):
+        input_token_value = tl.load(
+            in_tokens + input_token_index * D + feature_offset,
+            None,
+            eviction_policy="evict_first",
+        )
+
+        tl.atomic_add(
+            out_tokens + output_token_index * D + feature_offset,
+            input_token_value,
+            None,
+            sem="relaxed",
+        )
+        feature_offset += BLOCK_D_INNER
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_D": 256}),
+        triton.Config({"BLOCK_D": 512}),
+        triton.Config({"BLOCK_D": 1024}),
+    ],
+    key=["D"],
+)
+@triton.jit
+def _fbgemm_gather_scale_fp8_rowwise_quant_dense_tokens(
+    output_ptr,
+    output_scale_ptr,
+    input_ptr,
+    token_indices_ptr,
+    expert_indices_ptr,
+    scores_ptr,
+    scale_ub_ptr,
+    stride_t,
+    stride_e,
+    valid_token_count,
+    D: tl.constexpr,
+    TL_FP8_DTYPE: tl.constexpr,
+    MAX_FP8: tl.constexpr,
+    EPS: tl.constexpr,
+    CLAMP_MAX: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    tl.static_assert(D % BLOCK_D == 0, "D must be a multiple of BLOCK_D")
+
+    output_token_index = tl.program_id(0)
+
+    if valid_token_count is not None:
+        valid_token_count = tl.load(
+            valid_token_count, None, eviction_policy="evict_last"
+        )
+        if output_token_index >= valid_token_count:
+            return
+
+    input_token_index = tl.load(
+        token_indices_ptr + output_token_index, None, eviction_policy="evict_first"
+    )
+    input_expert_index = tl.load(
+        expert_indices_ptr + output_token_index, None, eviction_policy="evict_first"
+    )
+    input_score = tl.load(
+        scores_ptr + input_token_index * stride_t + input_expert_index * stride_e,
+        None,
+        eviction_policy="evict_first",
+    ).to(tl.float32)
+
+    row_max = 0.0
+    in_2d_ptr = (
+        input_ptr + input_token_index.to(tl.int64) * D + tl.arange(0, BLOCK_D)[:]
+    )
+    for _ in range(0, D, BLOCK_D):
+        input_token_value = tl.load(
+            in_2d_ptr,
+            None,
+            eviction_policy="evict_last",
+        ).to(tl.float32)
+        output_token_value = input_token_value * input_score
+
+        tile_max = tl.max(tl.abs(output_token_value))
+        row_max = tl.maximum(tile_max, row_max)
+        in_2d_ptr += BLOCK_D
+
+    # Clamp max value appropriately.
+    if CLAMP_MAX:
+        ub = tl.load(scale_ub_ptr, eviction_policy="evict_last")
+        row_max = tl.clamp(row_max, EPS, ub)
+    else:
+        row_max = tl.maximum(row_max, EPS)
+
+    # Scale and quantize.
+    output_scale = MAX_FP8 / row_max
+    tl.store(output_scale_ptr + output_token_index, 1.0 / output_scale)
+
+    in_2d_ptr = (
+        input_ptr + input_token_index.to(tl.int64) * D + tl.arange(0, BLOCK_D)[:]
+    )
+    out_2d_ptr = (
+        output_ptr + output_token_index.to(tl.int64) * D + tl.arange(0, BLOCK_D)[:]
+    )
+    for _ in range(0, D, BLOCK_D):
+        # Load from L2
+        input_token_value = tl.load(
+            in_2d_ptr,
+            None,
+            eviction_policy="evict_first",
+        ).to(tl.float32)
+        # Rematerialize
+        output_token_value_fp8 = (input_token_value * input_score) * output_scale
+
+        # Clamp to the fp8 range to make sure there's no overflow.
+        # This is required for AMD. Nvidia's default saturation
+        # handles it, but it's nice to have anyway.
+        output_token_value_fp8 = tl.clamp(output_token_value_fp8, -MAX_FP8, MAX_FP8).to(
+            TL_FP8_DTYPE
+        )
+        tl.store(
+            out_2d_ptr,
+            output_token_value_fp8,
+            None,
+            cache_modifier=".cg",
+        )
+        in_2d_ptr += BLOCK_D
+        out_2d_ptr += BLOCK_D
+
+
+_NV_CONFIGS = [
+    triton.Config(
+        {
+            "SPLIT_T": split_t,
+            "BLOCK_D": block_d,
+        },
+        num_stages=num_stages,
+        num_warps=num_warps,
+        num_ctas=num_ctas,
+    )
+    for split_t in [1, 4, 8, 16]
+    for block_d in [512, 1024]
+    for num_stages in [1, 3]
+    for num_warps in [8, 16]
+    for num_ctas in [1]
+]
+
+_AMD_CONFIGS = [
+    triton.Config(
+        {
+            "SPLIT_T": split_t,
+            "BLOCK_D": block_d,
+            "waves_per_eu": waves_per_eu,
+        },
+        num_stages=num_stages,
+        num_warps=num_warps,
+    )
+    for split_t in [2, 8, 16, 32]
+    for block_d in [512, 1024]
+    for num_stages in [1, 3]
+    for num_warps, waves_per_eu in [(8, 2), (16, 4)]
+]
+
+
+@triton.autotune(
+    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
+    restore_value=("out_tokens_ptr",),
+    key=["EP", "E", "T_BUCKET", "D"],
+)
+@triton.jit
+def _fbgemm_scatter_add_padded_tokens(
+    in_tokens_ptr,
+    token_counts_ptr,
+    token_indices_ptr,
+    out_tokens_ptr,
+    EP: tl.constexpr,
+    E: tl.constexpr,
+    T_BUCKET,
+    T_K,
+    D: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    SPLIT_T: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """
+    in_tokens: [EP, T_K, D]
+    token_counts: [E]
+    out_tokens: [T, D]
+    """
+    expert = tl.program_id(0)
+    t_tile = tl.program_id(1)
+
+    tl.static_assert(D % BLOCK_D == 0)
+    NUM_D_BLOCKS: tl.constexpr = D // BLOCK_D
+
+    num_tokens = tl.load(token_counts_ptr + expert)
+    if num_tokens == 0:
+        return
+
+    num_tokens_per_cta = tl.cdiv(num_tokens, SPLIT_T)
+    start_token = t_tile * num_tokens_per_cta
+    end_token = min(start_token + num_tokens_per_cta, num_tokens)
+
+    tl.static_assert(E % EP == 0)
+    EXPERT_PER_RANK: tl.constexpr = E // EP
+    rank = expert // EXPERT_PER_RANK
+
+    offs_e = tl.arange(0, BLOCK_E)
+    token_counts = tl.load(token_counts_ptr + offs_e, mask=(offs_e < E), other=0)
+    input_local_offset = (
+        tl.sum(tl.where(offs_e < expert, token_counts, 0)) + start_token
+    ).to(tl.int64)
+
+    for _t in range(start_token, end_token):
+        output_local_offset = tl.load(token_indices_ptr + input_local_offset).to(
+            tl.int64
+        )
+        output_global_offset = output_local_offset * D
+
+        d_ptr = tl.arange(0, BLOCK_D)
+        input_global_ptr = (
+            in_tokens_ptr + rank * T_K * D + input_local_offset * D + d_ptr
+        )
+        output_global_ptr = out_tokens_ptr + output_global_offset + d_ptr
+
+        for _d in range(NUM_D_BLOCKS):
+            vec = tl.load(input_global_ptr)
+            tl.atomic_add(output_global_ptr, vec, sem="relaxed")
+            input_global_ptr += BLOCK_D
+            output_global_ptr += BLOCK_D
+
+        input_local_offset += 1