fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,1192 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
|
|
9
|
+
import functools
|
|
10
|
+
import inspect
|
|
11
|
+
import warnings
|
|
12
|
+
|
|
13
|
+
from typing import Optional
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
import triton
|
|
18
|
+
import triton.language as tl
|
|
19
|
+
|
|
20
|
+
from fbgemm_gpu.experimental.gemm.triton_gemm import utils
|
|
21
|
+
from triton.runtime import driver # @manual
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Autotuning search space for NVIDIA GPUs (non-warp-specialized kernels):
# the full cross product of tile shapes, pipeline depths and warp counts.
_NV_CONFIGS = [
    triton.Config(
        dict(
            BLOCK_SIZE_M=bm,
            BLOCK_SIZE_N=bn,
            BLOCK_SIZE_K=bk,
            NUM_CONSUMER_GROUPS=1,
        ),
        num_stages=stages,
        num_warps=warps,
        num_ctas=ctas,
    )
    for bm in [64, 128]
    for bn in [64, 128, 256]
    for bk in [64, 128, 256]
    for stages in [3, 4]
    for warps in [4, 8]
    for ctas in [1]
]
|
|
43
|
+
|
|
44
|
+
# Cached result of the warp-specialization capability probe. Stays None until
# _set_ws_support() runs, which replaces it with the bool returned by
# _check_ws_support().
_HAS_WS_SUPPORT = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _check_ws_support():
    """Probe whether the installed Triton can build warp-specialized configs.

    Returns True only when all of the following hold, checked in this order:
      * ``triton.language`` exposes ``async_task``;
      * ``triton.Config`` accepts both the ``num_consumer_groups`` and
        ``num_buffers_warp_spec`` keyword arguments;
      * ``utils.HAS_TMA_DESC`` is truthy.
    """
    required_config_params = ("num_consumer_groups", "num_buffers_warp_spec")

    if not hasattr(tl, "async_task"):
        return False

    accepted_params = inspect.signature(triton.Config).parameters
    if not all(name in accepted_params for name in required_config_params):
        return False

    return bool(utils.HAS_TMA_DESC)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _set_ws_support():
    """Fill in the module-level ``_HAS_WS_SUPPORT`` flag exactly once.

    The probe ``_check_ws_support()`` runs only while the flag is still
    ``None``; once a value is cached, later calls leave it untouched.
    """
    global _HAS_WS_SUPPORT
    if _HAS_WS_SUPPORT is not None:
        return
    _HAS_WS_SUPPORT = _check_ws_support()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
_set_ws_support()

if not _HAS_WS_SUPPORT:
    # This Triton cannot build warp-specialized configs; reuse the plain
    # NVIDIA search space.
    _NV_WS_CONFIGS = _NV_CONFIGS
else:
    # Warp-specialized search space. The Config-level num_consumer_groups may
    # be 0 (treated as "no warp specialization" by early_config_prune_ws),
    # while the NUM_CONSUMER_GROUPS kernel kwarg is clamped to at least 1.
    # num_buffers_warp_spec mirrors num_stages.
    _NV_WS_CONFIGS = [
        triton.Config(
            dict(
                BLOCK_SIZE_M=bm,
                BLOCK_SIZE_N=bn,
                BLOCK_SIZE_K=bk,
                NUM_CONSUMER_GROUPS=max(1, consumer_groups),
                USE_TMA_LOAD_ON_SCALES=tma_on_scales,
                USE_TMA_STORE=tma_store,
            ),
            num_stages=stages,
            num_warps=warps,
            num_ctas=ctas,
            num_consumer_groups=consumer_groups,
            num_buffers_warp_spec=stages,
        )
        for bm in [64, 128, 256]
        for bn in [64, 128, 256]
        for bk in [64, 128, 256]
        for stages in [2, 3, 4]
        for warps in [4, 8, 16]
        # TODO(shikaili): Resolve LLVM error.
        for ctas in [1]
        for consumer_groups in [0, 2]
        for tma_on_scales in [True, False]
        # TODO(shikaili): Resolve compatibility with ws.
        for tma_store in [False]
    ]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Autotuning search space for ROCm builds. Warp count and the
# "waves_per_eu" kwarg are varied together as fixed pairs.
_AMD_CONFIGS = [
    triton.Config(
        dict(
            BLOCK_SIZE_M=bm,
            BLOCK_SIZE_N=bn,
            BLOCK_SIZE_K=bk,
            waves_per_eu=waves,
            matrix_instr_nonkdim=nonkdim,
            NUM_CONSUMER_GROUPS=1,
        ),
        num_stages=stages,
        num_warps=warps,
    )
    for bm in [32, 64, 128]
    for bn in [32, 64, 128, 256]
    for bk in [128, 256]
    for stages in [1, 2]
    for warps, waves in [(4, 1), (8, 2), (16, 4)]
    for nonkdim in [16]
]
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Drop autotuning configs that do not fit the device or the problem shape.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel arguments by name. Reads ``G``, ``M_BUCKET`` and
            ``N``, plus ``c_ptr`` when ``dtsize`` is not given.
        dtsize: element size in bytes used for the shared-memory estimate.
            Defaults to ``named_args["c_ptr"].element_size()``.
        dtype: accepted for interface compatibility (callers bind it with
            ``functools.partial``); none of the pruning rules use it.
        **kwargs: ignored; accepted for compatibility with Triton's pruning
            hook.

    Returns:
        The configs that pass every rule below, in their original order.
        An empty ``configs`` yields ``[]`` without touching the device.
    """
    if not configs:
        return []
    if dtsize is None:
        dtsize = named_args["c_ptr"].element_size()

    # Device properties and problem sizes are identical for every config, so
    # query them once rather than once (or twice) per config.
    device = torch.cuda.current_device()
    props = driver.active.utils.get_device_properties(device)
    max_shared_memory = props["max_shared_mem"]
    num_sm = props["multiprocessor_count"]

    G, M, N = named_args["G"], named_args["M_BUCKET"], named_args["N"]
    M_PER_GROUP = M // G
    is_hip = bool(torch.version.hip)
    MIN_M_TILES = 32 if is_hip else 64
    MIN_N_TILES = 32 if is_hip else 64

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M = kw["BLOCK_SIZE_M"]
        BLOCK_N = kw["BLOCK_SIZE_N"]
        BLOCK_K = kw["BLOCK_SIZE_K"]
        num_stages = config.num_stages
        use_tma_load_on_scales = kw.get("USE_TMA_LOAD_ON_SCALES", False)

        # 1. make sure we have enough smem
        # On ROCm the estimate counts only the BLOCK_N x BLOCK_K operand tile.
        if is_hip:
            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
        else:
            required_shared_memory = (
                (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
            )
        if required_shared_memory > max_shared_memory:
            continue

        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        N_TILES = (N + BLOCK_N - 1) // BLOCK_N
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue

        # 6. TMA loads on scales are only kept for element sizes below 2
        # bytes (the FP8 kernel binds dtsize=1).
        if dtsize >= 2 and use_tma_load_on_scales:
            continue

        pruned_configs.append(config)

    return pruned_configs
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def early_config_prune_ws(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Prune autotuning configs for the warp-specialized grouped GEMM kernel.

    Same rules as ``early_config_prune``, except:
      * rules 2 and 4 (tile-too-big checks) are skipped for warp-specialized
        configs (``num_consumer_groups >= 1``);
      * warp-specialized configs must use exactly 4 warps and must split into
        per-consumer-group slices with ``BLOCK_M // groups >= 64`` or
        ``BLOCK_N // groups >= 256``.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel arguments by name. Reads ``G``, ``M_BUCKET`` and
            ``N``, plus ``c_ptr`` when ``dtsize`` is not given.
        dtsize: element size in bytes used for the shared-memory estimate.
            Defaults to ``named_args["c_ptr"].element_size()``.
        dtype: accepted for interface compatibility; not used by any rule.
        **kwargs: ignored; accepted for compatibility with Triton's pruning
            hook.

    Returns:
        The configs that pass every rule, in their original order. An empty
        ``configs`` yields ``[]`` without touching the device.
    """
    if not configs:
        return []
    if dtsize is None:
        dtsize = named_args["c_ptr"].element_size()

    # Device properties and problem sizes are identical for every config, so
    # query them once rather than once (or twice) per config.
    device = torch.cuda.current_device()
    props = driver.active.utils.get_device_properties(device)
    max_shared_memory = props["max_shared_mem"]
    num_sm = props["multiprocessor_count"]

    G, M, N = named_args["G"], named_args["M_BUCKET"], named_args["N"]
    M_PER_GROUP = M // G
    is_hip = bool(torch.version.hip)
    MIN_M_TILES = 32 if is_hip else 64
    MIN_N_TILES = 32 if is_hip else 64

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M = kw["BLOCK_SIZE_M"]
        BLOCK_N = kw["BLOCK_SIZE_N"]
        BLOCK_K = kw["BLOCK_SIZE_K"]
        num_stages = config.num_stages
        num_warps = config.num_warps
        # When warp specialization is unavailable, _NV_WS_CONFIGS falls back
        # to _NV_CONFIGS, which are built without num_consumer_groups; such
        # Config objects may lack the attribute. Treat them as
        # non-warp-specialized instead of raising AttributeError.
        num_consumer_groups = getattr(config, "num_consumer_groups", 0)
        use_tma_load_on_scales = kw.get("USE_TMA_LOAD_ON_SCALES", False)

        # 1. make sure we have enough smem
        # On ROCm the estimate counts only the BLOCK_N x BLOCK_K operand tile.
        if is_hip:
            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
        else:
            required_shared_memory = (
                (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
            )
        if required_shared_memory > max_shared_memory:
            continue

        use_warp_specialization = num_consumer_groups >= 1

        # 2. make sure we don't load M tiles that are too big
        if (
            not use_warp_specialization
            and BLOCK_M > MIN_M_TILES
            and BLOCK_M > (M_PER_GROUP * 2)
        ):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        N_TILES = (N + BLOCK_N - 1) // BLOCK_N
        # 4. make sure we don't load N tiles that are too big
        if (
            not use_warp_specialization
            and BLOCK_N > MIN_N_TILES
            and M * N_TILES < num_sm
        ):
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue

        # 6. make sure we can partition for ws
        if use_warp_specialization:
            if num_warps != 4:
                continue

            # "tritongpu-warp-spec-data-partition"
            m_slice = BLOCK_M // num_consumer_groups
            n_slice = BLOCK_N // num_consumer_groups
            if m_slice < 64 and n_slice < 256:
                continue

        # 7. TMA loads on scales are only kept for element sizes below 2 bytes.
        if dtsize >= 2 and use_tma_load_on_scales:
            continue

        pruned_configs.append(config)

    return pruned_configs
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    scatter_add_indices,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    """Grouped GEMM: for every group g, C_g = A_g @ B_g^T.

    Group g owns the next ``m_sizes[g]`` rows of A (consecutive across groups)
    and rows ``[g * N, (g + 1) * N)`` of B. Groups with ``m_sizes[g] <= 0``
    are skipped and consume neither rows nor tiles.

    Output tiles of size BLOCK_SIZE_M x BLOCK_SIZE_N are numbered
    consecutively across groups (M index varying fastest within a group).
    This program starts at tile ``program_id(0)`` and advances by ``NUM_SMS``
    tiles at a time.

    The non-TMA load path addresses A and B as row-major with row stride K,
    and the plain store addresses C with row stride N. With USE_TMA_LOAD,
    ``a_desc_ptr``/``b_desc_ptr`` are instead passed to
    ``tl._experimental_descriptor_load``; presumably they describe the same
    layouts (TODO confirm against the host-side launcher). A and B are
    loaded with C's element type.

    Output modes (compile-time):
      * USE_TMA_STORE: store through a per-group tensor descriptor on C.
      * FUSE_SCATTER_ADD: atomically add each output row r into row
        ``scatter_add_indices[r]`` of C; incompatible with USE_TMA_STORE.
      * otherwise: masked ``tl.store`` into C.

    ``M_BUCKET`` and ``NUM_CONSUMER_GROUPS`` are not read in the body;
    ``M_BUCKET`` is an autotune key and is read by the pruning hook.
    """
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    # Global index of the next output tile this program will compute.
    tidx = tl.program_id(0)

    # C's element type; also used as the element type of the A/B descriptor
    # loads below.
    dtype: tl.dtype = c_ptr.dtype.element_ty

    # Running row offset into A/C, widened to int64 (presumably so that
    # row * stride products cannot overflow 32 bits).
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
    # Total number of tiles in the groups visited so far.
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g)

        # Empty groups advance neither the row offset nor the tile count.
        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            # Group g's slice of B starts at row g * N.
            N_start_offset = g.to(tl.int64) * N
            n_size = N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # Descriptor over this group's m_size x N block of C; store
                # coordinates below are relative to its first row.
                c_desc_ptr = tl.make_tensor_descriptor(
                    c_ptr + M_start_offset * N,
                    shape=[m_size, n_size],
                    # pyre-ignore
                    strides=[n_size, 1],
                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                )

            # Move across tiles
            # Handle every tile of group g that lies on this program's
            # stride-NUM_SMS sequence.
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                if USE_TMA_LOAD:
                    # This path does no K-tail masking, so K must divide evenly.
                    tl.static_assert(K % BLOCK_SIZE_K == 0)
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        if USE_FAST_ACCUM:
                            # Accumulate inside tl.dot.
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            # Add each partial product as a separate step.
                            accumulator += tl.dot(a, b.T)
                else:
                    # Pointer-based loads with masking on group rows, on N and
                    # on the K tail. USE_FAST_ACCUM is not consulted here.
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        updated_k_offset = k_offset + offs_k
                        updated_k_offset_mask = updated_k_offset[None, :] < K  # type: ignore[16]
                        a = tl.load(
                            a_ptrs,
                            mask=((offs_am[:, None] < m_size) & updated_k_offset_mask),
                            other=0.0,
                        )
                        b = tl.load(
                            b_ptrs,
                            mask=((offs_bn[:, None] < n_size) & updated_k_offset_mask),
                            other=0.0,
                        )
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                if USE_TMA_STORE:
                    # Coordinates are relative to this group's C descriptor.
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    # pyre-ignore
                    c_desc_ptr.store(
                        [m_offset, n_offset], accumulator.to(c_ptr.dtype.element_ty)
                    )
                elif FUSE_SCATTER_ADD:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    mask = offs_am < m_size
                    # Destination row of C for each valid output row.
                    m_offsets = tl.load(
                        scatter_add_indices + M_start_offset + offs_am,
                        mask=mask,
                        cache_modifier=".ca",
                    )
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    c = accumulator.to(c_ptr.dtype.element_ty)
                    # NOTE(review): the masks here and in the plain store
                    # combine tensors with Python `and`; this relies on
                    # Triton lowering it to an elementwise logical_and.
                    # `&` would be the conventional form.
                    tl.atomic_add(
                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                        c,
                        mask=mask[:, None] and offs_bn[None, :] < n_size,
                        sem="relaxed",
                    )
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    c = accumulator.to(c_ptr.dtype.element_ty)
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                # Step to this program's next tile (possibly in a later group).
                tidx += NUM_SMS

            iterated_tiles += num_tiles
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
# TODO(shikaili): Too much code duplication. Need to refactor.
@triton.autotune(
    configs=_NV_WS_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune_ws},
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_ws(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    scatter_add_indices,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
    USE_TMA_LOAD_ON_SCALES: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
) -> None:
    """Warp-specialization variant of the grouped GEMM: C_g = A_g @ B_g^T.

    Group g owns the next ``m_sizes[g]`` rows of A and C (consecutive across
    groups) and rows ``[g * N, (g + 1) * N)`` of B. Groups with
    ``m_sizes[g] <= 0`` are skipped. Output tiles are numbered consecutively
    across groups (M index varying fastest within a group). This program
    starts at tile ``program_id(0)`` and steps by ``NUM_SMS``.

    Differences from ``_fbgemm_grouped_gemm``:
      * USE_TMA_LOAD must be True and USE_TMA_LOAD_ON_SCALES must be False
        (compile-time asserts);
      * N must be a multiple of BLOCK_SIZE_N and K a multiple of
        BLOCK_SIZE_K, so no N or K masking is performed;
      * a group's tiles are visited with ``for i in range(...)`` instead of a
        ``while`` loop;
      * autotuned over ``_NV_WS_CONFIGS`` and pruned by
        ``early_config_prune_ws``.

    ``M_BUCKET`` and ``NUM_CONSUMER_GROUPS`` are not read in the body.
    """
    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialziation!")
    tl.static_assert(not USE_TMA_LOAD_ON_SCALES, "Not supported!")
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    # Global index of the next output tile this program will compute.
    tidx = tl.program_id(0)

    # C's element type; also used as the element type of the A/B loads.
    dtype: tl.dtype = c_ptr.dtype.element_ty

    # Running row offset into A/C, widened to int64 (presumably to avoid
    # 32-bit overflow in row * stride products).
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
    # Total number of tiles in the groups visited so far.
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g, cache_modifier=".ca")

        # Empty groups advance neither the row offset nor the tile count.
        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            # Group g's slice of B starts at row g * N.
            N_start_offset = g.to(tl.int64) * N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            # N tiles are exact, which is why the stores below need no N mask.
            tl.static_assert(N % BLOCK_SIZE_N == 0, f"{N=} {BLOCK_SIZE_N=}")
            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
            num_tiles = num_m_tiles * NUM_N_TILES

            if USE_TMA_STORE:
                # Descriptor over this group's m_size x N block of C; store
                # coordinates below are relative to its first row.
                c_desc_ptr = tl.make_tensor_descriptor(
                    c_ptr + M_start_offset * N,
                    shape=[m_size, N],
                    # pyre-ignore
                    strides=[N, 1],
                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                )

            # Move across tiles
            next_iterated_tiles = iterated_tiles + num_tiles
            # Enter only when this program's next tile belongs to group g.
            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
                for i in range(tidx, next_iterated_tiles, NUM_SMS):
                    gidx = i - iterated_tiles
                    # Split M first and N second.
                    tile_m_idx = gidx % num_m_tiles
                    tile_n_idx = gidx // num_m_tiles

                    accumulator = tl.zeros(
                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
                    )
                    # No K-tail masking on the TMA path.
                    tl.static_assert(K % BLOCK_SIZE_K == 0)
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        if USE_FAST_ACCUM:
                            # Accumulate inside tl.dot.
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            # Add each partial product as a separate step.
                            accumulator += tl.dot(a, b.T)

                    if USE_TMA_STORE:
                        # Coordinates are relative to this group's C descriptor.
                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                        # pyre-ignore
                        c_desc_ptr.store(
                            [m_offset, n_offset],
                            accumulator.to(c_ptr.dtype.element_ty),
                        )
                    elif FUSE_SCATTER_ADD:
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        mask = offs_am < m_size
                        # Destination row of C for each valid output row.
                        m_offsets = tl.load(
                            scatter_add_indices + M_start_offset + offs_am,
                            mask=mask,
                            cache_modifier=".ca",
                        )
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        c = accumulator.to(c_ptr.dtype.element_ty)
                        tl.atomic_add(
                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                            c,
                            mask=mask[:, None],
                            sem="relaxed",
                        )
                    else:
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        c = accumulator.to(c_ptr.dtype.element_ty)
                        tl.store(
                            c_ptr
                            + (M_start_offset + offs_am[:, None]) * N
                            + offs_bn[None, :],
                            c,
                            mask=offs_am[:, None] < m_size,
                            cache_modifier=".cs",
                        )
                    # Keep tidx in step with i so that, after the loop, it is
                    # this program's first tile beyond group g.
                    tidx += NUM_SMS

            iterated_tiles += num_tiles
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
# Triton FP8 element type used by the FP8 grouped GEMM kernel:
# tl.float8e4b8 on ROCm builds, tl.float8e4nv otherwise.
if torch.version.hip:
    TT_FP8_DTYPE = tl.float8e4b8
else:
    TT_FP8_DTYPE = tl.float8e4nv
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
# TODO(shikaili): clean up redundant 'b_scale_desc_ptr' argument.
#
# Persistent grouped GEMM with FP8 inputs and row-wise scales (no warp
# specialization). `_grouped_gemm` launches it with NUM_SMS programs; program
# `p` handles global tiles p, p + NUM_SMS, p + 2 * NUM_SMS, ... of the tile
# space formed by concatenating, group after group, each group's
# cdiv(m_size, BLOCK_SIZE_M) x cdiv(N, BLOCK_SIZE_N) tile grid.
#
# For group g, rows [M_start_offset, M_start_offset + m_sizes[g]) of A are
# multiplied by rows [g * N, (g + 1) * N) of B (transposed), and every output
# element is rescaled by a_scale[row] * b_scale[g * N + col].
#
# `restore_value=["c_ptr"]` makes the autotuner restore C between benchmark
# runs, which matters because the fused scatter-add accumulates into C.
@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_fp8_rowwise(
    # A (FP8): TMA descriptor when USE_TMA_LOAD, else a pointer to row-major
    # data with row stride K.
    a_desc_ptr,
    # Per-row scales of A, indexed by global row.
    a_scale_ptr,
    # B (FP8): TMA descriptor when USE_TMA_LOAD, else a pointer to row-major
    # (G * N, K) data.
    b_desc_ptr,
    # Per-row scales of B, indexed by g * N + column.
    b_scale_ptr,
    # Not read by this kernel (see TODO above).
    b_scale_desc_ptr,
    # Output C with row stride N.
    c_ptr,
    # Destination row in C for each row of A; only read when FUSE_SCATTER_ADD.
    scatter_add_indices,
    # Number of rows of A belonging to each of the G groups.
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET,  # only used as an autotune key; not read in the body
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # Not read in the body. The host sizes the A TMA box as
    # BLOCK_SIZE_M // NUM_CONSUMER_GROUPS rows while the loads below request
    # BLOCK_SIZE_M rows. NOTE(review): presumably 1 in these configs; confirm.
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    # Scatter-add writes rows at arbitrary destinations, which a tiled TMA
    # store cannot express.
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    # Index of the first global tile this program handles.
    tidx = tl.program_id(0)

    # Element type requested from the TMA loads.
    dtype = TT_FP8_DTYPE

    # Keep the running row offset in int64 so row-offset products such as
    # M_start_offset * N and (M_start_offset + ...) * K are 64-bit.
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
    # Number of tiles belonging to all groups visited so far.
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g)

        # Empty groups contribute no rows and no tiles.
        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            # Group g uses rows [g * N, (g + 1) * N) of B.
            N_start_offset = g.to(tl.int64) * N
            n_size = N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # Per-group output descriptor. The base is shifted to this
                # group's first row, so store coordinates below are group-local
                # and the descriptor bounds clip rows >= m_size.
                c_desc_ptr = tl.make_tensor_descriptor(
                    c_ptr + M_start_offset * N,
                    shape=[m_size, n_size],
                    # pyre-ignore
                    strides=[n_size, 1],
                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                )

            # Move across tiles
            # Handle every tile of this group owned by this program. `tidx`
            # advances by NUM_SMS and carries over into later groups.
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                # No K-tail handling: K must be a multiple of BLOCK_SIZE_K.
                tl.static_assert(K % BLOCK_SIZE_K == 0)
                if USE_TMA_LOAD:
                    # Global coordinates into the whole-A / whole-B descriptors.
                    # Rows past this group's end may be read here; they only
                    # feed output rows/cols that are masked or clipped on store.
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        # Fast accumulation passes the accumulator into tl.dot;
                        # otherwise each partial product is added separately.
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)
                else:
                    # Pointer-based fallback (USE_FAST_ACCUM is not consulted).
                    # Masked-out rows get no `other=` value, so their contents
                    # are unspecified; they only reach outputs masked on store.
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                # Row-wise dequantization: scale row i by a_scale[i] and
                # column j by b_scale[g * N + j].
                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                a_scale = tl.load(
                    a_scale_ptr + M_start_offset + offs_am[:, None],
                    mask=offs_am[:, None] < m_size,
                )
                b_scale = tl.load(
                    b_scale_ptr + N_start_offset + offs_bn[None, :],
                    mask=offs_bn[None, :] < n_size,
                )
                c = accumulator.to(tl.float32) * a_scale * b_scale

                if USE_TMA_STORE:
                    # Group-local tile coordinates (descriptor base is offset).
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    # pyre-ignore
                    c_desc_ptr.store([m_offset, n_offset], c.to(c_ptr.dtype.element_ty))
                elif FUSE_SCATTER_ADD:
                    # Atomically add each valid row into C at the destination
                    # row given by scatter_add_indices.
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    mask = offs_am < m_size
                    m_offsets = tl.load(
                        scatter_add_indices + M_start_offset + offs_am,
                        mask=mask,
                        cache_modifier=".ca",
                    )
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    # NOTE(review): the mask combines tensors with Python `and`;
                    # this relies on Triton lowering it to an elementwise
                    # logical-and (`&` is the more conventional spelling).
                    tl.atomic_add(
                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                        c.to(c_ptr.dtype.element_ty),
                        mask=mask[:, None] and offs_bn[None, :] < n_size,
                        sem="relaxed",
                    )
                else:
                    # Plain masked store into this group's rows of C.
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                tidx += NUM_SMS

            iterated_tiles += num_tiles
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
# TODO(shikaili): Too much code duplication. Need to refactor.
#
# Warp-specialized variant of `_fbgemm_grouped_gemm_fp8_rowwise`. It uses the
# same persistent tile schedule and the same math, with these differences
# visible below:
#   * TMA loads are mandatory (USE_TMA_LOAD is static-asserted).
#   * N must be a multiple of BLOCK_SIZE_N (static-asserted), so no column
#     masks are applied to stores or B-scale loads.
#   * B scales can be read through a 1-D TMA descriptor
#     (USE_TMA_LOAD_ON_SCALES).
#   * The tile loop is a strided `for` over this group's tiles instead of a
#     `while`.
#   * USE_TMA_LOAD_ON_SCALES and USE_TMA_STORE are trailing constexprs that
#     `_grouped_gemm` does not pass positionally. NOTE(review): presumably
#     they are supplied by _NV_WS_CONFIGS; confirm.
@triton.autotune(
    configs=_NV_WS_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune_ws, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_fp8_rowwise_ws(
    a_desc_ptr,  # TMA descriptor over all rows of A (FP8)
    a_scale_ptr,  # per-row scales of A, indexed by global row
    b_desc_ptr,  # TMA descriptor over the (G * N, K) rows of B (FP8)
    b_scale_ptr,  # per-row scales of B, indexed by g * N + column
    # 1-D TMA descriptor over the B scales; only read when
    # USE_TMA_LOAD_ON_SCALES (loaded as tl.float32 -- assumes fp32 scales).
    b_scale_desc_ptr,
    c_ptr,  # output C with row stride N
    scatter_add_indices,  # destination rows in C (FUSE_SCATTER_ADD only)
    m_sizes,  # number of rows of A in each of the G groups
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # only used as an autotune key; not read here
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # Not read in the body. The host shrinks the A TMA box to
    # BLOCK_SIZE_M // NUM_CONSUMER_GROUPS rows. NOTE(review): presumably the
    # warp-specialized compiler splits each M tile across consumer groups;
    # confirm.
    NUM_CONSUMER_GROUPS: tl.constexpr,
    USE_TMA_LOAD_ON_SCALES: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
) -> None:
    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialziation!")
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    # Index of the first global tile this program handles.
    tidx = tl.program_id(0)

    # Element type requested from the TMA loads.
    dtype = TT_FP8_DTYPE

    # Keep the running row offset in int64 so offset products are 64-bit.
    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
    # Number of tiles belonging to all groups visited so far.
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g, cache_modifier=".ca")

        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            # Group g uses rows [g * N, (g + 1) * N) of B.
            N_start_offset = g.to(tl.int64) * N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            # Whole N tiles only, so no column masking is needed below.
            tl.static_assert(N % BLOCK_SIZE_N == 0)
            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
            num_tiles = num_m_tiles * NUM_N_TILES

            if USE_TMA_STORE:
                # Per-group output descriptor; store coordinates below are
                # group-local and the bounds clip rows >= m_size.
                c_desc_ptr = tl.make_tensor_descriptor(
                    c_ptr + M_start_offset * N,
                    shape=[m_size, N],
                    # pyre-ignore
                    strides=[N, 1],
                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                )

            # Move across tiles
            next_iterated_tiles = iterated_tiles + num_tiles
            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
                # Visit this program's tiles of group g. `tidx` is advanced in
                # lockstep with `i`, so after the loop it is this program's first
                # tile index at or beyond next_iterated_tiles, ready for the
                # next group.
                for i in range(tidx, next_iterated_tiles, NUM_SMS):
                    gidx = i - iterated_tiles
                    # Split M first and N second.
                    tile_m_idx = gidx % num_m_tiles
                    tile_n_idx = gidx // num_m_tiles

                    accumulator = tl.zeros(
                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
                    )
                    # No K-tail handling: K must be a multiple of BLOCK_SIZE_K.
                    tl.static_assert(K % BLOCK_SIZE_K == 0)

                    # Global coordinates into the whole-A / whole-B descriptors.
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)

                    # Row-wise dequantization: row i scaled by a_scale[i],
                    # column j by b_scale[g * N + j].
                    if USE_TMA_LOAD_ON_SCALES:
                        # 1-D TMA load of BLOCK_SIZE_N B scales starting at the
                        # same global column offset used for B above.
                        b_scale = tl._experimental_descriptor_load(
                            b_scale_desc_ptr,
                            [n_offset],
                            [BLOCK_SIZE_N],
                            tl.float32,
                        )

                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        a_scale = tl.load(
                            a_scale_ptr + M_start_offset + offs_am[:, None],
                            mask=offs_am[:, None] < m_size,
                            cache_modifier=".ca",
                        )
                        c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
                    else:
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        a_scale = tl.load(
                            a_scale_ptr + M_start_offset + offs_am[:, None],
                            mask=offs_am[:, None] < m_size,
                            cache_modifier=".ca",
                        )
                        # Unmasked: N % BLOCK_SIZE_N == 0 keeps columns in range.
                        b_scale = tl.load(
                            b_scale_ptr + N_start_offset + offs_bn[None, :],
                            cache_modifier=".ca",
                        )
                        c = accumulator.to(tl.float32) * a_scale * b_scale

                    if USE_TMA_STORE:
                        # Group-local tile coordinates (descriptor base is offset).
                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                        # pyre-ignore
                        c_desc_ptr.store(
                            [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
                        )
                    elif FUSE_SCATTER_ADD:
                        # Atomically add each valid row into C at the row given
                        # by scatter_add_indices. `c` is fp32 here and is not
                        # cast explicitly before the atomic add (unlike the
                        # non-WS kernel).
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        mask = offs_am < m_size
                        m_offsets = tl.load(
                            scatter_add_indices + M_start_offset + offs_am,
                            mask=mask,
                            cache_modifier=".ca",
                        )
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        tl.atomic_add(
                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                            c,
                            mask=mask[:, None],
                            sem="relaxed",
                        )
                    else:
                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                        tl.store(
                            c_ptr
                            + (M_start_offset + offs_am[:, None]) * N
                            + offs_bn[None, :],
                            c,
                            mask=offs_am[:, None] < m_size,
                            cache_modifier=".cs",
                        )
                    tidx += NUM_SMS

            iterated_tiles += num_tiles
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
# Report each distinct fallback warning (TMA / warp-specialization disabled,
# etc.) only once. NOTE: this installs a process-wide filter at import time,
# so it also affects warnings raised by every other module.
warnings.simplefilter(action="once")
|
|
940
|
+
|
|
941
|
+
|
|
942
|
+
def _grouped_gemm(
    *,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: Optional[torch.Tensor],
    w_scale: Optional[torch.Tensor],
    use_fast_accum: bool,
    use_warp_specialization: bool,
    output_tensor: Optional[torch.Tensor],
    scatter_add_indices: Optional[torch.Tensor],
) -> torch.Tensor:
    """Select, configure and launch the grouped-GEMM Triton kernel for the inputs.

    The rows of ``x`` are split into ``G = m_sizes.shape[0]`` consecutive
    groups. Group ``g`` (``m_sizes[g]`` rows) is multiplied by the transpose of
    rows ``[g * N, (g + 1) * N)`` of ``w``.

    Args:
        x: Contiguous ``(M, K)`` activations.
        w: Contiguous weights with ``w.shape[1] == K``; ``N`` is taken as
            ``w.shape[0] // G``.
        m_sizes: Contiguous ``(G,)`` per-group row counts (presumably summing
            to ``M``; not checked here).
        x_scale: Per-row scales of ``x``. Together with ``w_scale`` it selects
            the FP8 row-wise kernels; otherwise both must be ``None``.
        w_scale: Per-row scales of ``w`` (indexed by ``g * N + column``).
        use_fast_accum: Forwarded to the kernel as ``USE_FAST_ACCUM``.
        use_warp_specialization: Request the warp-specialized kernel. Downgraded,
            with a warning, when the platform or shapes cannot support it.
        output_tensor: If given, results are atomically scatter-added into it
            at the rows named by ``scatter_add_indices``.
        scatter_add_indices: Contiguous ``(M,)`` destination rows. Required if
            and only if ``output_tensor`` is given.

    Returns:
        A new ``(M, N)`` bfloat16 tensor, or ``output_tensor`` when fusing the
        scatter-add.

    Raises:
        AssertionError: On contiguity/shape violations or unsupported option
            combinations.
    """
    USE_TMA_LOAD = not torch.version.hip
    # Only the warp-specialized path turns TMA store on (see below).
    USE_TMA_STORE = False

    if USE_TMA_LOAD and not utils.HAS_TMA_DESC:
        USE_TMA_LOAD = False
        warnings.warn(
            "TMA load is disabled as there is no TMA descriptor support!", stacklevel=2
        )

    # TODO(shikaili): Check the readiness of WS on ROCm side in Meta's Triton.
    if use_warp_specialization and torch.version.hip:
        warnings.warn(
            "Warp specialization is disabled as it is not supported on ROCm.",
            stacklevel=2,
        )
        use_warp_specialization = False

    if use_warp_specialization and not _HAS_WS_SUPPORT:
        warnings.warn(
            "Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.",
            stacklevel=2,
        )
        use_warp_specialization = False

    # The warp-specialized kernels static_assert USE_TMA_LOAD. Fall back to the
    # regular kernels with a warning, like the other unsupported-feature
    # cases, instead of failing an assertion when TMA descriptors are missing.
    if use_warp_specialization and not USE_TMA_LOAD:
        warnings.warn(
            "Warp specialization is disabled as TMA load is not available.",
            stacklevel=2,
        )
        use_warp_specialization = False

    if use_warp_specialization:
        # Reaching here implies USE_TMA_LOAD, hence utils.HAS_TMA_DESC.
        USE_TMA_STORE = True  # Tuning decision

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M, K = x.shape
    N = w.shape[0] // G
    assert K == w.shape[1]

    if K % 8 != 0 or N % 8 != 0:
        # Fall back to the plain pointer-based kernels.
        use_warp_specialization = False
        USE_TMA_LOAD = False
        USE_TMA_STORE = False
        warnings.warn(
            f"TMA load and warp specialization are disabled since K or N is not a multiple of 8: {K=}, {N=}.",
            stacklevel=2,
        )
        assert (
            x_scale is None
        ), f"Quantisation is not supported yet when K or N is not a multiple of 8: {K=}, {N=}."

        assert (
            output_tensor is None
        ), f"Fused scatter add has large rounding error when K or N is not a multiple of 8: {K=}, {N=}."

    if output_tensor is None:
        FUSE_SCATTER_ADD = False
        assert scatter_add_indices is None
        y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
    else:
        FUSE_SCATTER_ADD = True
        assert scatter_add_indices is not None
        assert scatter_add_indices.is_contiguous()
        assert scatter_add_indices.shape == (M,)
        y = output_tensor
    if M == 0 or N == 0:
        return y

    # The kernels are persistent: one program per SM, striding over tiles.
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    desc_helper = None
    desc_x = x
    desc_w = w
    desc_ws = w_scale

    if USE_TMA_LOAD:
        # Descriptors are only allocated here; their contents depend on the
        # autotuned block sizes and are filled in `grid` below.
        desc_helper = utils.TmaAutoTuneHelper()
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
        if use_warp_specialization and w_scale is not None:
            desc_helper.init_tma_descriptor("ws")
            desc_ws = desc_helper.get_tma_descriptor_kernel_param("ws")

    if USE_TMA_STORE:
        # Device-side `tl.make_tensor_descriptor` needs Triton scratch memory.

        def alloc_fn(size: int, alignment: int, stream: Optional[int]):
            return torch.empty(size, device="cuda", dtype=torch.int8)

        triton.set_allocator(alloc_fn)

    def grid(META):
        if USE_TMA_LOAD:
            nonlocal desc_helper  # noqa: F824
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M,
                K,
                META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N * G,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )

            if META.get("USE_TMA_LOAD_ON_SCALES", False):
                desc_helper.fill_1d_tma_descriptor(
                    "ws",
                    w_scale.data_ptr(),
                    N * G,
                    META["BLOCK_SIZE_N"],
                    w_scale.element_size(),
                )

        return (NUM_SMS,)

    # Autotune key on a power-of-two bucket of M (capped) rather than exact M.
    M_BUCKET_CAP = 16384
    M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP)
    if x_scale is not None and w_scale is not None:
        assert x_scale.is_contiguous()
        assert w_scale.is_contiguous()
        fn = (
            _fbgemm_grouped_gemm_fp8_rowwise_ws
            if use_warp_specialization
            else _fbgemm_grouped_gemm_fp8_rowwise
        )
        # Leading pointer arguments of the FP8 row-wise kernels.
        args = (
            desc_x,
            x_scale,
            desc_w,
            w_scale,
            desc_ws,
            y,
            scatter_add_indices,
            m_sizes,
        )
    else:
        assert x_scale is None
        assert w_scale is None
        fn = (
            _fbgemm_grouped_gemm_ws if use_warp_specialization else _fbgemm_grouped_gemm
        )
        # Leading pointer arguments of the un-scaled kernels.
        args = (
            desc_x,
            desc_w,
            y,
            scatter_add_indices,
            m_sizes,
        )

    # Problem sizes and flags shared by all four kernels.
    args += (
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        FUSE_SCATTER_ADD,
        USE_TMA_LOAD,
    )
    # The warp-specialized kernels declare USE_TMA_STORE after the tile sizes
    # (not passed positionally), so only USE_FAST_ACCUM follows here.
    if use_warp_specialization:
        args += (use_fast_accum,)
    else:
        args += (USE_TMA_STORE, use_fast_accum)
    fn[grid](*args)

    return y
|
|
1145
|
+
|
|
1146
|
+
|
|
1147
|
+
def grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_fast_accum: bool = True,
    *,
    _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Grouped GEMM on un-quantized inputs.

    Each of the ``m_sizes.shape[0]`` consecutive row groups of ``x`` is
    multiplied by the transpose of its own ``N``-row slab of ``w``. This is a
    thin front-end over ``_grouped_gemm`` with no row-wise scales.

    Args:
        x: Contiguous ``(M, K)`` activations.
        w: Contiguous ``(G * N, K)`` weights.
        m_sizes: Contiguous ``(G,)`` per-group row counts.
        use_fast_accum: Forwarded as the kernel's ``USE_FAST_ACCUM`` flag.
        _use_warp_specialization: Request the warp-specialized kernel (may be
            downgraded with a warning).
        _output_tensor: Optional destination for a fused scatter-add.
        _scatter_add_indices: Destination rows for the fused scatter-add.

    Returns:
        The ``(M, N)`` result, or ``_output_tensor`` when fusing.
    """
    launch_options = dict(
        use_fast_accum=use_fast_accum,
        use_warp_specialization=_use_warp_specialization,
        output_tensor=_output_tensor,
        scatter_add_indices=_scatter_add_indices,
    )
    # No quantization scales on this entry point.
    return _grouped_gemm(
        x=x, w=w, m_sizes=m_sizes, x_scale=None, w_scale=None, **launch_options
    )
|
|
1168
|
+
|
|
1169
|
+
|
|
1170
|
+
def grouped_gemm_fp8_rowwise(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
    use_fast_accum: bool = True,
    *,
    _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Grouped GEMM on FP8 inputs with row-wise dequantization scales.

    Same grouping as ``grouped_gemm``. In addition, each output element
    ``(i, j)`` of group ``g`` is multiplied by ``x_scale[i] * w_scale[g * N + j]``
    inside the kernel.

    Args:
        x: Contiguous ``(M, K)`` FP8 activations.
        w: Contiguous ``(G * N, K)`` FP8 weights.
        m_sizes: Contiguous ``(G,)`` per-group row counts.
        x_scale: Contiguous per-row scales of ``x``.
        w_scale: Contiguous per-row scales of ``w``.
        use_fast_accum: Forwarded as the kernel's ``USE_FAST_ACCUM`` flag.
        _use_warp_specialization: Request the warp-specialized kernel (may be
            downgraded with a warning).
        _output_tensor: Optional destination for a fused scatter-add.
        _scatter_add_indices: Destination rows for the fused scatter-add.

    Returns:
        The ``(M, N)`` result, or ``_output_tensor`` when fusing.
    """
    launch_options = dict(
        use_fast_accum=use_fast_accum,
        use_warp_specialization=_use_warp_specialization,
        output_tensor=_output_tensor,
        scatter_add_indices=_scatter_add_indices,
    )
    return _grouped_gemm(
        x=x,
        w=w,
        m_sizes=m_sizes,
        x_scale=x_scale,
        w_scale=w_scale,
        **launch_options,
    )
|