fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.
Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py
@@ -0,0 +1,1192 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import functools
10
+ import inspect
11
+ import warnings
12
+
13
+ from typing import Optional
14
+
15
+ import torch
16
+
17
+ import triton
18
+ import triton.language as tl
19
+
20
+ from fbgemm_gpu.experimental.gemm.triton_gemm import utils
21
+ from triton.runtime import driver # @manual
22
+
23
+
24
+ _NV_CONFIGS = [
25
+ triton.Config(
26
+ {
27
+ "BLOCK_SIZE_M": block_size_m,
28
+ "BLOCK_SIZE_N": block_size_n,
29
+ "BLOCK_SIZE_K": block_size_k,
30
+ "NUM_CONSUMER_GROUPS": 1,
31
+ },
32
+ num_stages=num_stages,
33
+ num_warps=num_warps,
34
+ num_ctas=num_ctas,
35
+ )
36
+ for block_size_m in [64, 128]
37
+ for block_size_n in [64, 128, 256]
38
+ for block_size_k in [64, 128, 256]
39
+ for num_stages in [3, 4]
40
+ for num_warps in [4, 8]
41
+ for num_ctas in [1]
42
+ ]
43
+
44
+ _HAS_WS_SUPPORT = None
45
+
46
+
47
+ def _check_ws_support():
48
+ if not hasattr(tl, "async_task"):
49
+ return False
50
+ config_signature = inspect.signature(triton.Config).parameters
51
+ if (
52
+ "num_consumer_groups" not in config_signature
53
+ or "num_buffers_warp_spec" not in config_signature
54
+ ):
55
+ return False
56
+ if not utils.HAS_TMA_DESC:
57
+ return False
58
+ return True
59
+
60
+
61
+ def _set_ws_support():
62
+ global _HAS_WS_SUPPORT
63
+ if _HAS_WS_SUPPORT is None:
64
+ _HAS_WS_SUPPORT = _check_ws_support()
65
+
66
+
67
+ _set_ws_support()
68
+
69
+ if _HAS_WS_SUPPORT:
70
+ _NV_WS_CONFIGS = [
71
+ triton.Config(
72
+ {
73
+ "BLOCK_SIZE_M": block_size_m,
74
+ "BLOCK_SIZE_N": block_size_n,
75
+ "BLOCK_SIZE_K": block_size_k,
76
+ "NUM_CONSUMER_GROUPS": max(1, num_consumer_groups),
77
+ "USE_TMA_LOAD_ON_SCALES": use_tma_load_on_scales,
78
+ "USE_TMA_STORE": use_tma_store,
79
+ },
80
+ num_stages=num_stages,
81
+ num_warps=num_warps,
82
+ num_ctas=num_ctas,
83
+ num_consumer_groups=num_consumer_groups,
84
+ num_buffers_warp_spec=num_stages,
85
+ )
86
+ for block_size_m in [64, 128, 256]
87
+ for block_size_n in [64, 128, 256]
88
+ for block_size_k in [64, 128, 256]
89
+ for num_stages in [2, 3, 4]
90
+ for num_warps in [4, 8, 16]
91
+ # TODO(shikaili): Resolve LLVM error.
92
+ for num_ctas in [1]
93
+ for num_consumer_groups in [0, 2]
94
+ for use_tma_load_on_scales in [True, False]
95
+ # TODO(shikaili): Resolve compatibility with ws.
96
+ for use_tma_store in [False]
97
+ ]
98
+ else:
99
+ _NV_WS_CONFIGS = _NV_CONFIGS
100
+
101
+
102
+ _AMD_CONFIGS = [
103
+ triton.Config(
104
+ {
105
+ "BLOCK_SIZE_M": block_size_m,
106
+ "BLOCK_SIZE_N": block_size_n,
107
+ "BLOCK_SIZE_K": block_size_k,
108
+ "waves_per_eu": waves_per_cu,
109
+ "matrix_instr_nonkdim": matrix_instr_nonkdim,
110
+ "NUM_CONSUMER_GROUPS": 1,
111
+ },
112
+ num_stages=num_stages,
113
+ num_warps=num_warps,
114
+ )
115
+ for block_size_m in [32, 64, 128]
116
+ for block_size_n in [32, 64, 128, 256]
117
+ for block_size_k in [128, 256]
118
+ for num_stages in [1, 2]
119
+ for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]
120
+ for matrix_instr_nonkdim in [16]
121
+ ]
122
+
123
+
124
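+ # Prune autotuning configs that exceed shared memory or are a poor fit for the
+ # problem shape (tile sizes vs. rows per group, tile count vs. SM count).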
+ def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
125
+ device = torch.cuda.current_device()
126
+ # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
127
+ if dtsize is None:
128
+ dtsize = named_args["c_ptr"].element_size()
129
+ if dtype is None:
130
+ dtype = named_args["c_ptr"].dtype
131
+
132
+ pruned_configs = []
133
+ for config in configs:
134
+ kw = config.kwargs
135
+ (
136
+ BLOCK_M,
137
+ BLOCK_N,
138
+ BLOCK_K,
139
+ num_stages,
140
+ use_tma_load_on_scales,
141
+ ) = (
142
+ kw["BLOCK_SIZE_M"],
143
+ kw["BLOCK_SIZE_N"],
144
+ kw["BLOCK_SIZE_K"],
145
+ config.num_stages,
146
+ kw.get("USE_TMA_LOAD_ON_SCALES", False),
147
+ )
148
+ G, M, N = (
149
+ named_args["G"],
150
+ named_args["M_BUCKET"],
151
+ named_args["N"],
152
+ )
153
+
154
+ # 1. make sure we have enough smem
155
+ max_shared_memory = driver.active.utils.get_device_properties(device)[
156
+ "max_shared_mem"
157
+ ]
158
+ if torch.version.hip:
159
+ required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
160
+ else:
161
+ required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
162
+ if required_shared_memory > max_shared_memory:
163
+ continue
164
+
165
+ M_PER_GROUP = M // G
166
+ MIN_M_TILES = 32 if torch.version.hip else 64
167
+ # 2. make sure we don't load M tiles that are too big
168
+ if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
169
+ continue
170
+ # 3. make sure we don't load N tiles that are too small
171
+ if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
172
+ continue
173
+
174
+ num_sm = driver.active.utils.get_device_properties(device)[
175
+ "multiprocessor_count"
176
+ ]
177
+ N_TILES = (N + BLOCK_N - 1) // BLOCK_N
178
+ MIN_N_TILES = 32 if torch.version.hip else 64
179
+ # 4. make sure we don't load N tiles that are too big
180
+ if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
181
+ continue
182
+ # 5. make sure we don't load N tiles that are too small
183
+ if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
184
+ continue
185
+ if dtsize >= 2:
186
+ if use_tma_load_on_scales:
187
+ continue
188
+ pruned_configs.append(config)
189
+
190
+ return pruned_configs
191
+
192
+
193
+ def early_config_prune_ws(configs, named_args, dtsize=None, dtype=None, **kwargs):
194
+ device = torch.cuda.current_device()
195
+ # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
196
+ if dtsize is None:
197
+ dtsize = named_args["c_ptr"].element_size()
198
+ if dtype is None:
199
+ dtype = named_args["c_ptr"].dtype
200
+
201
+ pruned_configs = []
202
+ for config in configs:
203
+ kw = config.kwargs
204
+ (
205
+ BLOCK_M,
206
+ BLOCK_N,
207
+ BLOCK_K,
208
+ num_stages,
209
+ num_warps,
210
+ num_consumer_groups,
211
+ use_tma_load_on_scales,
212
+ ) = (
213
+ kw["BLOCK_SIZE_M"],
214
+ kw["BLOCK_SIZE_N"],
215
+ kw["BLOCK_SIZE_K"],
216
+ config.num_stages,
217
+ config.num_warps,
218
+ config.num_consumer_groups,
219
+ kw.get("USE_TMA_LOAD_ON_SCALES", False),
220
+ )
221
+ G, M, N = (
222
+ named_args["G"],
223
+ named_args["M_BUCKET"],
224
+ named_args["N"],
225
+ )
226
+
227
+ # 1. make sure we have enough smem
228
+ max_shared_memory = driver.active.utils.get_device_properties(device)[
229
+ "max_shared_mem"
230
+ ]
231
+ if torch.version.hip:
232
+ required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
233
+ else:
234
+ required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
235
+ if required_shared_memory > max_shared_memory:
236
+ continue
237
+
238
+ use_warp_specialization = num_consumer_groups >= 1
239
+
240
+ M_PER_GROUP = M // G
241
+ MIN_M_TILES = 32 if torch.version.hip else 64
242
+ # 2. make sure we don't load M tiles that are too big
243
+ if (
244
+ not use_warp_specialization
245
+ and BLOCK_M > MIN_M_TILES
246
+ and BLOCK_M > (M_PER_GROUP * 2)
247
+ ):
248
+ continue
249
+ # 3. make sure we don't load N tiles that are too small
250
+ if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
251
+ continue
252
+
253
+ num_sm = driver.active.utils.get_device_properties(device)[
254
+ "multiprocessor_count"
255
+ ]
256
+ N_TILES = (N + BLOCK_N - 1) // BLOCK_N
257
+ MIN_N_TILES = 32 if torch.version.hip else 64
258
+ # 4. make sure we don't load N tiles that are too big
259
+ if (
260
+ not use_warp_specialization
261
+ and BLOCK_N > MIN_N_TILES
262
+ and M * N_TILES < num_sm
263
+ ):
264
+ continue
265
+ # 5. make sure we don't load N tiles that are too small
266
+ if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
267
+ continue
268
+
269
+ # 6. make sure we can partition for ws
270
+ if use_warp_specialization:
271
+ if num_warps != 4:
272
+ continue
273
+
274
+ # "tritongpu-warp-spec-data-partition"
275
+ m_slice = BLOCK_M // num_consumer_groups
276
+ n_slice = BLOCK_N // num_consumer_groups
277
+ if m_slice < 64 and n_slice < 256:
278
+ continue
279
+
280
+ if dtsize >= 2:
281
+ if use_tma_load_on_scales:
282
+ continue
283
+ pruned_configs.append(config)
284
+
285
+ return pruned_configs
286
+
287
+
288
+ @triton.autotune(
289
+ configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
290
+ key=["G", "M_BUCKET", "N", "K"],
291
+ prune_configs_by={"early_config_prune": early_config_prune},
292
+ restore_value=["c_ptr"], # restore for scatter_add fusion
293
+ )
294
+ @triton.jit
295
+ def _fbgemm_grouped_gemm(
296
+ a_desc_ptr,
297
+ b_desc_ptr,
298
+ c_ptr,
299
+ scatter_add_indices,
300
+ m_sizes,
301
+ # problem sizes
302
+ G: tl.constexpr,
303
+ M_BUCKET,
304
+ N: tl.constexpr,
305
+ K: tl.constexpr,
306
+ NUM_SMS: tl.constexpr,
307
+ FUSE_SCATTER_ADD: tl.constexpr,
308
+ USE_TMA_LOAD: tl.constexpr,
309
+ USE_TMA_STORE: tl.constexpr,
310
+ USE_FAST_ACCUM: tl.constexpr,
311
+ # tile sizes
312
+ BLOCK_SIZE_M: tl.constexpr,
313
+ BLOCK_SIZE_N: tl.constexpr,
314
+ BLOCK_SIZE_K: tl.constexpr,
315
+ NUM_CONSUMER_GROUPS: tl.constexpr,
316
+ ) -> None:
317
+ tl.static_assert(
318
+ not (FUSE_SCATTER_ADD and USE_TMA_STORE),
319
+ "Cannot fuse scatter add with TMA store!",
320
+ )
321
+
322
+ tidx = tl.program_id(0)
323
+
324
+ dtype: tl.dtype = c_ptr.dtype.element_ty
325
+
326
+ M_end_offset = 0
327
+ M_end_offset = M_end_offset.to(tl.int64) # pyre-ignore
328
+ iterated_tiles = 0
329
+ for g in tl.range(G):
330
+ # Move across groups
331
+ m_size = tl.load(m_sizes + g)
332
+
333
+ if m_size > 0:
334
+ M_start_offset = M_end_offset
335
+ M_end_offset = M_start_offset + m_size
336
+ N_start_offset = g.to(tl.int64) * N
337
+ n_size = N
338
+
339
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
340
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
341
+ num_tiles = num_m_tiles * num_n_tiles
342
+
343
+ if USE_TMA_STORE:
344
+ c_desc_ptr = tl.make_tensor_descriptor(
345
+ c_ptr + M_start_offset * N,
346
+ shape=[m_size, n_size],
347
+ # pyre-ignore
348
+ strides=[n_size, 1],
349
+ block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
350
+ )
351
+
352
+ # Move across tiles
353
+ while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
354
+ gidx = tidx - iterated_tiles
355
+ # Split M first and N second.
356
+ tile_m_idx = gidx % num_m_tiles
357
+ tile_n_idx = gidx // num_m_tiles
358
+
359
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
360
+
361
+ if USE_TMA_LOAD:
362
+ tl.static_assert(K % BLOCK_SIZE_K == 0)
363
+ m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
364
+ n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
365
+ for k_offset in range(0, K, BLOCK_SIZE_K):
366
+ a = tl._experimental_descriptor_load(
367
+ a_desc_ptr,
368
+ [m_offset, k_offset],
369
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
370
+ dtype,
371
+ )
372
+ b = tl._experimental_descriptor_load(
373
+ b_desc_ptr,
374
+ [n_offset, k_offset],
375
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
376
+ dtype,
377
+ )
378
+ if USE_FAST_ACCUM:
379
+ accumulator = tl.dot(a, b.T, accumulator)
380
+ else:
381
+ accumulator += tl.dot(a, b.T)
382
+ else:
383
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
384
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
385
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
386
+ a_ptrs = (
387
+ a_desc_ptr
388
+ + (M_start_offset + offs_am[:, None]) * K
389
+ + offs_k[None, :]
390
+ )
391
+ b_ptrs = (
392
+ b_desc_ptr
393
+ + (N_start_offset + offs_bn[:, None]) * K
394
+ + offs_k[None, :]
395
+ )
396
+ for k_offset in range(0, K, BLOCK_SIZE_K):
397
+ updated_k_offset = k_offset + offs_k
398
+ updated_k_offset_mask = updated_k_offset[None, :] < K # type: ignore[16]
399
+ a = tl.load(
400
+ a_ptrs,
401
+ mask=((offs_am[:, None] < m_size) & updated_k_offset_mask),
402
+ other=0.0,
403
+ )
404
+ b = tl.load(
405
+ b_ptrs,
406
+ mask=((offs_bn[:, None] < n_size) & updated_k_offset_mask),
407
+ other=0.0,
408
+ )
409
+ accumulator += tl.dot(a, b.T)
410
+ a_ptrs += BLOCK_SIZE_K
411
+ b_ptrs += BLOCK_SIZE_K
412
+
413
+ if USE_TMA_STORE:
414
+ m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
415
+ n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
416
+ # pyre-ignore
417
+ c_desc_ptr.store(
418
+ [m_offset, n_offset], accumulator.to(c_ptr.dtype.element_ty)
419
+ )
420
+ elif FUSE_SCATTER_ADD:
421
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
422
+ mask = offs_am < m_size
423
+ m_offsets = tl.load(
424
+ scatter_add_indices + M_start_offset + offs_am,
425
+ mask=mask,
426
+ cache_modifier=".ca",
427
+ )
428
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
429
+ c = accumulator.to(c_ptr.dtype.element_ty)
430
+ tl.atomic_add(
431
+ c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
432
+ c,
433
+ mask=mask[:, None] and offs_bn[None, :] < n_size,
434
+ sem="relaxed",
435
+ )
436
+ else:
437
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
438
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
439
+ c = accumulator.to(c_ptr.dtype.element_ty)
440
+ tl.store(
441
+ c_ptr
442
+ + (M_start_offset + offs_am[:, None]) * N
443
+ + offs_bn[None, :],
444
+ c,
445
+ mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
446
+ )
447
+ tidx += NUM_SMS
448
+
449
+ iterated_tiles += num_tiles
450
+
451
+
452
+ # TODO(shikaili): Too much code duplication. Need to refactor.
453
+ @triton.autotune(
454
+ configs=_NV_WS_CONFIGS,
455
+ key=["G", "M_BUCKET", "N", "K"],
456
+ prune_configs_by={"early_config_prune": early_config_prune_ws},
457
+ restore_value=["c_ptr"], # restore for scatter_add fusion
458
+ )
459
+ @triton.jit
460
+ def _fbgemm_grouped_gemm_ws(
461
+ a_desc_ptr,
462
+ b_desc_ptr,
463
+ c_ptr,
464
+ scatter_add_indices,
465
+ m_sizes,
466
+ # problem sizes
467
+ G: tl.constexpr,
468
+ M_BUCKET: tl.constexpr,
469
+ N: tl.constexpr,
470
+ K: tl.constexpr,
471
+ NUM_SMS: tl.constexpr,
472
+ FUSE_SCATTER_ADD: tl.constexpr,
473
+ USE_TMA_LOAD: tl.constexpr,
474
+ USE_FAST_ACCUM: tl.constexpr,
475
+ # tile sizes
476
+ BLOCK_SIZE_M: tl.constexpr,
477
+ BLOCK_SIZE_N: tl.constexpr,
478
+ BLOCK_SIZE_K: tl.constexpr,
479
+ NUM_CONSUMER_GROUPS: tl.constexpr,
480
+ USE_TMA_LOAD_ON_SCALES: tl.constexpr,
481
+ USE_TMA_STORE: tl.constexpr,
482
+ ) -> None:
483
+ tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialization!")
484
+ tl.static_assert(not USE_TMA_LOAD_ON_SCALES, "Not supported!")
485
+ tl.static_assert(
486
+ not (FUSE_SCATTER_ADD and USE_TMA_STORE),
487
+ "Cannot fuse scatter add with TMA store!",
488
+ )
489
+
490
+ tidx = tl.program_id(0)
491
+
492
+ dtype: tl.dtype = c_ptr.dtype.element_ty
493
+
494
+ M_end_offset = 0
495
+ M_end_offset = M_end_offset.to(tl.int64) # pyre-ignore
496
+ iterated_tiles = 0
497
+ for g in tl.range(G):
498
+ # Move across groups
499
+ m_size = tl.load(m_sizes + g, cache_modifier=".ca")
500
+
501
+ if m_size > 0:
502
+ M_start_offset = M_end_offset
503
+ M_end_offset = M_start_offset + m_size
504
+ N_start_offset = g.to(tl.int64) * N
505
+
506
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
507
+ tl.static_assert(N % BLOCK_SIZE_N == 0, f"{N=} {BLOCK_SIZE_N=}")
508
+ NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
509
+ num_tiles = num_m_tiles * NUM_N_TILES
510
+
511
+ if USE_TMA_STORE:
512
+ c_desc_ptr = tl.make_tensor_descriptor(
513
+ c_ptr + M_start_offset * N,
514
+ shape=[m_size, N],
515
+ # pyre-ignore
516
+ strides=[N, 1],
517
+ block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
518
+ )
519
+
520
+ # Move across tiles
521
+ next_iterated_tiles = iterated_tiles + num_tiles
522
+ if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
523
+ for i in range(tidx, next_iterated_tiles, NUM_SMS):
524
+ gidx = i - iterated_tiles
525
+ # Split M first and N second.
526
+ tile_m_idx = gidx % num_m_tiles
527
+ tile_n_idx = gidx // num_m_tiles
528
+
529
+ accumulator = tl.zeros(
530
+ (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
531
+ )
532
+ tl.static_assert(K % BLOCK_SIZE_K == 0)
533
+ m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
534
+ n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
535
+ for k_offset in range(0, K, BLOCK_SIZE_K):
536
+ a = tl._experimental_descriptor_load(
537
+ a_desc_ptr,
538
+ [m_offset, k_offset],
539
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
540
+ dtype,
541
+ )
542
+ b = tl._experimental_descriptor_load(
543
+ b_desc_ptr,
544
+ [n_offset, k_offset],
545
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
546
+ dtype,
547
+ )
548
+ if USE_FAST_ACCUM:
549
+ accumulator = tl.dot(a, b.T, accumulator)
550
+ else:
551
+ accumulator += tl.dot(a, b.T)
552
+
553
+ if USE_TMA_STORE:
554
+ m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
555
+ n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
556
+ # pyre-ignore
557
+ c_desc_ptr.store(
558
+ [m_offset, n_offset],
559
+ accumulator.to(c_ptr.dtype.element_ty),
560
+ )
561
+ elif FUSE_SCATTER_ADD:
562
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
563
+ mask = offs_am < m_size
564
+ m_offsets = tl.load(
565
+ scatter_add_indices + M_start_offset + offs_am,
566
+ mask=mask,
567
+ cache_modifier=".ca",
568
+ )
569
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
570
+ c = accumulator.to(c_ptr.dtype.element_ty)
571
+ tl.atomic_add(
572
+ c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
573
+ c,
574
+ mask=mask[:, None],
575
+ sem="relaxed",
576
+ )
577
+ else:
578
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
579
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
580
+ c = accumulator.to(c_ptr.dtype.element_ty)
581
+ tl.store(
582
+ c_ptr
583
+ + (M_start_offset + offs_am[:, None]) * N
584
+ + offs_bn[None, :],
585
+ c,
586
+ mask=offs_am[:, None] < m_size,
587
+ cache_modifier=".cs",
588
+ )
589
+ tidx += NUM_SMS
590
+
591
+ iterated_tiles += num_tiles
592
+
593
+
594
+ TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv
595
+
596
+
597
+ # TODO(shikaili): clean up redundant 'b_scale_desc_ptr' argument.
598
+ @triton.autotune(
599
+ configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
600
+ key=["G", "M_BUCKET", "N", "K"],
601
+ prune_configs_by={
602
+ "early_config_prune": functools.partial(
603
+ early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
604
+ )
605
+ },
606
+ restore_value=["c_ptr"], # restore for scatter_add fusion
607
+ )
608
+ @triton.jit
609
+ def _fbgemm_grouped_gemm_fp8_rowwise(
610
+ a_desc_ptr,
611
+ a_scale_ptr,
612
+ b_desc_ptr,
613
+ b_scale_ptr,
614
+ b_scale_desc_ptr,
615
+ c_ptr,
616
+ scatter_add_indices,
617
+ m_sizes,
618
+ # problem sizes
619
+ G: tl.constexpr,
620
+ M_BUCKET,
621
+ N: tl.constexpr,
622
+ K: tl.constexpr,
623
+ NUM_SMS: tl.constexpr,
624
+ FUSE_SCATTER_ADD: tl.constexpr,
625
+ USE_TMA_LOAD: tl.constexpr,
626
+ USE_TMA_STORE: tl.constexpr,
627
+ USE_FAST_ACCUM: tl.constexpr,
628
+ # tile sizes
629
+ BLOCK_SIZE_M: tl.constexpr,
630
+ BLOCK_SIZE_N: tl.constexpr,
631
+ BLOCK_SIZE_K: tl.constexpr,
632
+ NUM_CONSUMER_GROUPS: tl.constexpr,
633
+ ) -> None:
634
+ tl.static_assert(
635
+ not (FUSE_SCATTER_ADD and USE_TMA_STORE),
636
+ "Cannot fuse scatter add with TMA store!",
637
+ )
638
+
639
+ tidx = tl.program_id(0)
640
+
641
+ dtype = TT_FP8_DTYPE
642
+
643
+ M_end_offset = 0
644
+ M_end_offset = M_end_offset.to(tl.int64) # pyre-ignore
645
+ iterated_tiles = 0
646
+ for g in tl.range(G):
647
+ # Move across groups
648
+ m_size = tl.load(m_sizes + g)
649
+
650
+ if m_size > 0:
651
+ M_start_offset = M_end_offset
652
+ M_end_offset = M_start_offset + m_size
653
+ N_start_offset = g.to(tl.int64) * N
654
+ n_size = N
655
+
656
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
657
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
658
+ num_tiles = num_m_tiles * num_n_tiles
659
+
660
+ if USE_TMA_STORE:
661
+ c_desc_ptr = tl.make_tensor_descriptor(
662
+ c_ptr + M_start_offset * N,
663
+ shape=[m_size, n_size],
664
+ # pyre-ignore
665
+ strides=[n_size, 1],
666
+ block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
667
+ )
668
+
669
+ # Move across tiles
670
+ while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
671
+ gidx = tidx - iterated_tiles
672
+ # Split M first and N second.
673
+ tile_m_idx = gidx % num_m_tiles
674
+ tile_n_idx = gidx // num_m_tiles
675
+
676
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
677
+ tl.static_assert(K % BLOCK_SIZE_K == 0)
678
+ if USE_TMA_LOAD:
679
+ m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
680
+ n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
681
+ for k_offset in range(0, K, BLOCK_SIZE_K):
682
+ a = tl._experimental_descriptor_load(
683
+ a_desc_ptr,
684
+ [m_offset, k_offset],
685
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
686
+ dtype,
687
+ )
688
+ b = tl._experimental_descriptor_load(
689
+ b_desc_ptr,
690
+ [n_offset, k_offset],
691
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
692
+ dtype,
693
+ )
694
+ if USE_FAST_ACCUM:
695
+ accumulator = tl.dot(a, b.T, accumulator)
696
+ else:
697
+ accumulator += tl.dot(a, b.T)
698
+ else:
699
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
700
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
701
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
702
+ a_ptrs = (
703
+ a_desc_ptr
704
+ + (M_start_offset + offs_am[:, None]) * K
705
+ + offs_k[None, :]
706
+ )
707
+ b_ptrs = (
708
+ b_desc_ptr
709
+ + (N_start_offset + offs_bn[:, None]) * K
710
+ + offs_k[None, :]
711
+ )
712
+ for k_offset in range(0, K, BLOCK_SIZE_K):
713
+ a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
714
+ b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
715
+ accumulator += tl.dot(a, b.T)
716
+ a_ptrs += BLOCK_SIZE_K
717
+ b_ptrs += BLOCK_SIZE_K
718
+
719
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
720
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
721
+ a_scale = tl.load(
722
+ a_scale_ptr + M_start_offset + offs_am[:, None],
723
+ mask=offs_am[:, None] < m_size,
724
+ )
725
+ b_scale = tl.load(
726
+ b_scale_ptr + N_start_offset + offs_bn[None, :],
727
+ mask=offs_bn[None, :] < n_size,
728
+ )
729
+ c = accumulator.to(tl.float32) * a_scale * b_scale
730
+
731
+ if USE_TMA_STORE:
732
+ m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
733
+ n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
734
+ # pyre-ignore
735
+ c_desc_ptr.store([m_offset, n_offset], c.to(c_ptr.dtype.element_ty))
736
+ elif FUSE_SCATTER_ADD:
737
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
738
+ mask = offs_am < m_size
739
+ m_offsets = tl.load(
740
+ scatter_add_indices + M_start_offset + offs_am,
741
+ mask=mask,
742
+ cache_modifier=".ca",
743
+ )
744
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
745
+ tl.atomic_add(
746
+ c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
747
+ c.to(c_ptr.dtype.element_ty),
748
+ mask=mask[:, None] and offs_bn[None, :] < n_size,
749
+ sem="relaxed",
750
+ )
751
+ else:
752
+ tl.store(
753
+ c_ptr
754
+ + (M_start_offset + offs_am[:, None]) * N
755
+ + offs_bn[None, :],
756
+ c,
757
+ mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
758
+ )
759
+ tidx += NUM_SMS
760
+
761
+ iterated_tiles += num_tiles
762
+
763
+
764
+ # TODO(shikaili): Too much code duplication. Need to refactor.
765
+ @triton.autotune(
766
+ configs=_NV_WS_CONFIGS,
767
+ key=["G", "M_BUCKET", "N", "K"],
768
+ prune_configs_by={
769
+ "early_config_prune": functools.partial(
770
+ early_config_prune_ws, dtype=TT_FP8_DTYPE, dtsize=1
771
+ )
772
+ },
773
+ restore_value=["c_ptr"], # restore for scatter_add fusion
774
+ )
775
+ @triton.jit
776
+ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
777
+ a_desc_ptr,
778
+ a_scale_ptr,
779
+ b_desc_ptr,
780
+ b_scale_ptr,
781
+ b_scale_desc_ptr,
782
+ c_ptr,
783
+ scatter_add_indices,
784
+ m_sizes,
785
+ # problem sizes
786
+ G: tl.constexpr,
787
+ M_BUCKET: tl.constexpr,
788
+ N: tl.constexpr,
789
+ K: tl.constexpr,
790
+ NUM_SMS: tl.constexpr,
791
+ FUSE_SCATTER_ADD: tl.constexpr,
792
+ USE_TMA_LOAD: tl.constexpr,
793
+ USE_FAST_ACCUM: tl.constexpr,
794
+ # tile sizes
795
+ BLOCK_SIZE_M: tl.constexpr,
796
+ BLOCK_SIZE_N: tl.constexpr,
797
+ BLOCK_SIZE_K: tl.constexpr,
798
+ NUM_CONSUMER_GROUPS: tl.constexpr,
799
+ USE_TMA_LOAD_ON_SCALES: tl.constexpr,
800
+ USE_TMA_STORE: tl.constexpr,
801
+ ) -> None:
802
+ tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialization!")
803
+ tl.static_assert(
804
+ not (FUSE_SCATTER_ADD and USE_TMA_STORE),
805
+ "Cannot fuse scatter add with TMA store!",
806
+ )
807
+
808
+ tidx = tl.program_id(0)
809
+
810
+ dtype = TT_FP8_DTYPE
811
+
812
+ M_end_offset = 0
813
+ M_end_offset = M_end_offset.to(tl.int64) # pyre-ignore
814
+ iterated_tiles = 0
815
+ for g in tl.range(G):
816
+ # Move across groups
817
+ m_size = tl.load(m_sizes + g, cache_modifier=".ca")
818
+
819
+ if m_size > 0:
820
+ M_start_offset = M_end_offset
821
+ M_end_offset = M_start_offset + m_size
822
+ N_start_offset = g.to(tl.int64) * N
823
+
824
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
825
+ tl.static_assert(N % BLOCK_SIZE_N == 0)
826
+ NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
827
+ num_tiles = num_m_tiles * NUM_N_TILES
828
+
829
+ if USE_TMA_STORE:
830
+ c_desc_ptr = tl.make_tensor_descriptor(
831
+ c_ptr + M_start_offset * N,
832
+ shape=[m_size, N],
833
+ # pyre-ignore
834
+ strides=[N, 1],
835
+ block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
836
+ )
837
+
838
+ # Move across tiles
839
+ next_iterated_tiles = iterated_tiles + num_tiles
840
+ if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
841
+ for i in range(tidx, next_iterated_tiles, NUM_SMS):
842
+ gidx = i - iterated_tiles
843
+ # Split M first and N second.
844
+ tile_m_idx = gidx % num_m_tiles
845
+ tile_n_idx = gidx // num_m_tiles
846
+
847
+ accumulator = tl.zeros(
848
+ (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
849
+ )
850
+ tl.static_assert(K % BLOCK_SIZE_K == 0)
851
+
852
+ m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
853
+ n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
854
+ for k_offset in range(0, K, BLOCK_SIZE_K):
855
+ a = tl._experimental_descriptor_load(
856
+ a_desc_ptr,
857
+ [m_offset, k_offset],
858
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
859
+ dtype,
860
+ )
861
+ b = tl._experimental_descriptor_load(
862
+ b_desc_ptr,
863
+ [n_offset, k_offset],
864
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
865
+ dtype,
866
+ )
867
+ if USE_FAST_ACCUM:
868
+ accumulator = tl.dot(a, b.T, accumulator)
869
+ else:
870
+ accumulator += tl.dot(a, b.T)
871
+
872
+ if USE_TMA_LOAD_ON_SCALES:
873
+ b_scale = tl._experimental_descriptor_load(
874
+ b_scale_desc_ptr,
875
+ [n_offset],
876
+ [BLOCK_SIZE_N],
877
+ tl.float32,
878
+ )
879
+
880
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
881
+ a_scale = tl.load(
882
+ a_scale_ptr + M_start_offset + offs_am[:, None],
883
+ mask=offs_am[:, None] < m_size,
884
+ cache_modifier=".ca",
885
+ )
886
+ c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
887
+ else:
888
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
889
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
890
+ a_scale = tl.load(
891
+ a_scale_ptr + M_start_offset + offs_am[:, None],
892
+ mask=offs_am[:, None] < m_size,
893
+ cache_modifier=".ca",
894
+ )
895
+ b_scale = tl.load(
896
+ b_scale_ptr + N_start_offset + offs_bn[None, :],
897
+ cache_modifier=".ca",
898
+ )
899
+ c = accumulator.to(tl.float32) * a_scale * b_scale
900
+
901
+ if USE_TMA_STORE:
902
+ m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
903
+ n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
904
+ # pyre-ignore
905
+ c_desc_ptr.store(
906
+ [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
907
+ )
908
+ elif FUSE_SCATTER_ADD:
909
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
910
+ mask = offs_am < m_size
911
+ m_offsets = tl.load(
912
+ scatter_add_indices + M_start_offset + offs_am,
913
+ mask=mask,
914
+ cache_modifier=".ca",
915
+ )
916
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
917
+ tl.atomic_add(
918
+ c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
919
+ c,
920
+ mask=mask[:, None],
921
+ sem="relaxed",
922
+ )
923
+ else:
924
+ offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
925
+ offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
926
+ tl.store(
927
+ c_ptr
928
+ + (M_start_offset + offs_am[:, None]) * N
929
+ + offs_bn[None, :],
930
+ c,
931
+ mask=offs_am[:, None] < m_size,
932
+ cache_modifier=".cs",
933
+ )
934
+ tidx += NUM_SMS
935
+
936
+ iterated_tiles += num_tiles
937
+
938
+
939
+ warnings.simplefilter("once")
940
+
941
+
942
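+ # Shared driver for the kernels above. Expects contiguous x of shape (M, K),
+ # w of shape (G * N, K) holding one (N, K) weight block per group, and m_sizes
+ # of length G giving how many rows of x belong to each group. Returns an
+ # (M, N) bf16 tensor, or scatter-adds into output_tensor when it is provided
+ # together with scatter_add_indices.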
+ def _grouped_gemm(
943
+ *,
944
+ x: torch.Tensor,
945
+ w: torch.Tensor,
946
+ m_sizes: torch.Tensor,
947
+ x_scale: Optional[torch.Tensor],
948
+ w_scale: Optional[torch.Tensor],
949
+ use_fast_accum: bool,
950
+ use_warp_specialization: bool,
951
+ output_tensor: Optional[torch.Tensor],
952
+ scatter_add_indices: Optional[torch.Tensor],
953
+ ) -> torch.Tensor:
954
+
955
+ USE_TMA_LOAD = not torch.version.hip
956
+ USE_TMA_STORE = False
957
+
958
+ if USE_TMA_LOAD and not utils.HAS_TMA_DESC:
959
+ USE_TMA_LOAD = False
960
+ warnings.warn(
961
+ "TMA load is disabled as there is no TMA descriptor support!", stacklevel=2
962
+ )
963
+
964
+ if USE_TMA_STORE and not utils.HAS_TMA_DESC:
965
+ USE_TMA_STORE = False
966
+ warnings.warn(
967
+ "TMA store is disabled as there is no TMA descriptor support!", stacklevel=2
968
+ )
969
+
970
+ # TODO(shikaili): Check the readiness of WS on ROCm side in Meta's Triton.
971
+ if use_warp_specialization and torch.version.hip:
972
+ warnings.warn(
973
+ "Warp specialization is disabled as it is not supported on ROCm.",
974
+ stacklevel=2,
975
+ )
976
+ use_warp_specialization = False
977
+
978
+ if use_warp_specialization and not _HAS_WS_SUPPORT:
979
+ warnings.warn(
980
+ "Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.",
981
+ stacklevel=2,
982
+ )
983
+ use_warp_specialization = False
984
+
985
+ if use_warp_specialization:
986
+ assert utils.HAS_TMA_DESC
987
+ USE_TMA_STORE = True # Tuning decision
988
+
989
+ G = m_sizes.shape[0]
990
+
991
+ assert x.is_contiguous()
992
+ assert w.is_contiguous()
993
+ assert m_sizes.is_contiguous()
994
+
995
+ M, K = x.shape
996
+ N = w.shape[0] // G
997
+ assert K == w.shape[1]
998
+
999
+ if K % 8 != 0 or N % 8 != 0:
1000
+ use_warp_specialization = False
1001
+ USE_TMA_LOAD = False
1002
+ USE_TMA_STORE = False
1003
+ warnings.warn(
1004
+ f"TMA load and warp specialization are disabled since K or N is not a multiple of 8: {K=}, {N=}.",
1005
+ stacklevel=2,
1006
+ )
1007
+ assert (
1008
+ x_scale is None
1009
+ ), f"Quantisation is not supported yet when K or N is not a multiple of 8: {K=}, {N=}."
1010
+
1011
+ assert (
1012
+ output_tensor is None
1013
+ ), f"Fused scatter add has large rounding error when K or N is not a multiple of 8: {K=}, {N=}."
1014
+
1015
+ if output_tensor is None:
1016
+ FUSE_SCATTER_ADD = False
1017
+ assert scatter_add_indices is None
1018
+ y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
1019
+ else:
1020
+ FUSE_SCATTER_ADD = True
1021
+ assert scatter_add_indices is not None
1022
+ assert scatter_add_indices.is_contiguous()
1023
+ assert scatter_add_indices.shape == (M,)
1024
+ y = output_tensor
1025
+ if M == 0 or N == 0:
1026
+ return y
1027
+
1028
+ NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
1029
+
1030
+ desc_helper = None
1031
+ desc_x = x
1032
+ desc_w = w
1033
+ desc_ws = w_scale
1034
+
1035
+ if USE_TMA_LOAD:
1036
+ desc_helper = utils.TmaAutoTuneHelper()
1037
+ desc_helper.init_tma_descriptor("x")
1038
+ desc_helper.init_tma_descriptor("w")
1039
+ desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
1040
+ desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
1041
+ if use_warp_specialization and w_scale is not None:
1042
+ desc_helper.init_tma_descriptor("ws")
1043
+ desc_ws = desc_helper.get_tma_descriptor_kernel_param("ws")
1044
+
1045
+ if USE_TMA_STORE:
1046
+
1047
+ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
1048
+ return torch.empty(size, device="cuda", dtype=torch.int8)
1049
+
1050
+ triton.set_allocator(alloc_fn)
1051
+
1052
+ def grid(META):
1053
+ if USE_TMA_LOAD:
1054
+ nonlocal desc_helper # noqa: F824
1055
+ desc_helper.fill_2d_tma_descriptor(
1056
+ "x",
1057
+ x.data_ptr(),
1058
+ M,
1059
+ K,
1060
+ META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"],
1061
+ META["BLOCK_SIZE_K"],
1062
+ x.element_size(),
1063
+ )
1064
+
1065
+ desc_helper.fill_2d_tma_descriptor(
1066
+ "w",
1067
+ w.data_ptr(),
1068
+ N * G,
1069
+ K,
1070
+ META["BLOCK_SIZE_N"],
1071
+ META["BLOCK_SIZE_K"],
1072
+ w.element_size(),
1073
+ )
1074
+
1075
+ if META.get("USE_TMA_LOAD_ON_SCALES", False):
1076
+ desc_helper.fill_1d_tma_descriptor(
1077
+ "ws",
1078
+ w_scale.data_ptr(),
1079
+ N * G,
1080
+ META["BLOCK_SIZE_N"],
1081
+ w_scale.element_size(),
1082
+ )
1083
+
1084
+ return (NUM_SMS,)
1085
+
1086
+ M_BUCKET_CAP = 16384
1087
+ M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP)
1088
+ if x_scale is not None and w_scale is not None:
1089
+ assert x_scale.is_contiguous()
1090
+ assert w_scale.is_contiguous()
1091
+ fn = (
1092
+ _fbgemm_grouped_gemm_fp8_rowwise_ws
1093
+ if use_warp_specialization
1094
+ else _fbgemm_grouped_gemm_fp8_rowwise
1095
+ )
1096
+ args = (
1097
+ desc_x,
1098
+ x_scale,
1099
+ desc_w,
1100
+ w_scale,
1101
+ desc_ws,
1102
+ y,
1103
+ scatter_add_indices,
1104
+ m_sizes,
1105
+ G,
1106
+ M_BUCKET,
1107
+ N,
1108
+ K,
1109
+ NUM_SMS,
1110
+ FUSE_SCATTER_ADD,
1111
+ USE_TMA_LOAD,
1112
+ )
1113
+ if use_warp_specialization:
1114
+ args += (use_fast_accum,)
1115
+ else:
1116
+ args += (USE_TMA_STORE, use_fast_accum)
1117
+ fn[grid](*args)
1118
+ else:
1119
+ assert x_scale is None
1120
+ assert w_scale is None
1121
+ fn = (
1122
+ _fbgemm_grouped_gemm_ws if use_warp_specialization else _fbgemm_grouped_gemm
1123
+ )
1124
+ args = (
1125
+ desc_x,
1126
+ desc_w,
1127
+ y,
1128
+ scatter_add_indices,
1129
+ m_sizes,
1130
+ G,
1131
+ M_BUCKET,
1132
+ N,
1133
+ K,
1134
+ NUM_SMS,
1135
+ FUSE_SCATTER_ADD,
1136
+ USE_TMA_LOAD,
1137
+ )
1138
+ if use_warp_specialization:
1139
+ args += (use_fast_accum,)
1140
+ else:
1141
+ args += (USE_TMA_STORE, use_fast_accum)
1142
+ fn[grid](*args)
1143
+
1144
+ return y
1145
+
1146
+
1147
+ def grouped_gemm(
1148
+ x: torch.Tensor,
1149
+ w: torch.Tensor,
1150
+ m_sizes: torch.Tensor,
1151
+ use_fast_accum: bool = True,
1152
+ *,
1153
+ _use_warp_specialization: bool = True,
1154
+ _output_tensor: Optional[torch.Tensor] = None,
1155
+ _scatter_add_indices: Optional[torch.Tensor] = None,
1156
+ ) -> torch.Tensor:
1157
+ return _grouped_gemm(
1158
+ x=x,
1159
+ w=w,
1160
+ m_sizes=m_sizes,
1161
+ x_scale=None,
1162
+ w_scale=None,
1163
+ use_fast_accum=use_fast_accum,
1164
+ use_warp_specialization=_use_warp_specialization,
1165
+ output_tensor=_output_tensor,
1166
+ scatter_add_indices=_scatter_add_indices,
1167
+ )
1168
+
1169
+
1170
+ def grouped_gemm_fp8_rowwise(
1171
+ x: torch.Tensor,
1172
+ w: torch.Tensor,
1173
+ m_sizes: torch.Tensor,
1174
+ x_scale: torch.Tensor,
1175
+ w_scale: torch.Tensor,
1176
+ use_fast_accum: bool = True,
1177
+ *,
1178
+ _use_warp_specialization: bool = True,
1179
+ _output_tensor: Optional[torch.Tensor] = None,
1180
+ _scatter_add_indices: Optional[torch.Tensor] = None,
1181
+ ) -> torch.Tensor:
1182
+ return _grouped_gemm(
1183
+ x=x,
1184
+ w=w,
1185
+ m_sizes=m_sizes,
1186
+ x_scale=x_scale,
1187
+ w_scale=w_scale,
1188
+ use_fast_accum=use_fast_accum,
1189
+ use_warp_specialization=_use_warp_specialization,
1190
+ output_tensor=_output_tensor,
1191
+ scatter_add_indices=_scatter_add_indices,
1192
+ )
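
For orientation, below is a minimal usage sketch of the two public entry points in this file, grouped_gemm and grouped_gemm_fp8_rowwise. The shapes follow the asserts in _grouped_gemm: x is (M, K), w stacks the G weight blocks as (G * N, K), m_sizes has length G, and the output is (M, N) bfloat16. The specific dtypes chosen for m_sizes and the fp8 tensors, and the all-ones scales, are illustrative assumptions rather than documented requirements; a CUDA device with a compatible Triton build is assumed, and torch.float8_e4m3fn is used to match the kernel's NVIDIA fp8 type (tl.float8e4nv), with ROCm using a different fp8 format.

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm,
    grouped_gemm_fp8_rowwise,
)

G, N, K = 4, 256, 512  # groups, per-group output columns, shared inner dimension
# Rows of x assigned to each group; empty groups (m_size == 0) are skipped by the kernel.
m_sizes = torch.tensor([8, 0, 32, 24], device="cuda", dtype=torch.int32)
M = int(m_sizes.sum())

# BF16 path: x is (M, K), w stacks the G weight matrices row-wise as (G * N, K).
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(G * N, K, device="cuda", dtype=torch.bfloat16)
y = grouped_gemm(x, w, m_sizes)  # (M, N) bfloat16

# FP8 rowwise path: one scale per row of x (length M) and per row of w (length G * N).
xq = x.to(torch.float8_e4m3fn)
wq = w.to(torch.float8_e4m3fn)
x_scale = torch.ones(M, device="cuda", dtype=torch.float32)
w_scale = torch.ones(G * N, device="cuda", dtype=torch.float32)
y_fp8 = grouped_gemm_fp8_rowwise(xq, wq, m_sizes, x_scale, w_scale)  # (M, N) bfloat16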