fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
Binary file
fbgemm_gpu/experimental/gen_ai/moe/README.md
@@ -0,0 +1,15 @@
+# FBGEMM GenAI MoE Support
+
+MetaShuffling MoE kernel support in the FBGEMM GenAI kernel library.
+
+# **Overview**
+
+Mixture-of-Experts (MoE) is a popular model architecture for large language models (LLMs). Although it reduces computation in training and inference by activating fewer parameters per token, it imposes additional challenges in achieving optimal computation efficiency under high memory and communication pressure, as well as the complexity of handling the dynamic and sparse nature of the model. Here we introduce a new MoE inference solution, MetaShuffling, which enables us to efficiently deploy Llama 4 models for real-world inference.
+
+[Technical design blog](https://pytorch.org/blog/metashuffling-accelerating-llama-4-moe-inference/).
+
+# **Updates**
+
+- 2025-05-01: Initial release of MetaShuffling MoE PyTorch examples.
+
+- 2025-04-17: Initial release of MetaShuffling MoE GPU kernels.
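
The overview above is narrative, so here is a minimal, self-contained sketch of the top-k routing and token-grouping step that MoE layers perform and that the shuffling/gather/scatter kernels in this package are meant to accelerate. All names and sizes are illustrative; this is plain PyTorch, not the FBGEMM API.

```python
import torch

# Illustrative toy sizes: tokens, hidden dim, experts, active experts per token.
T, D, E, K = 8, 16, 4, 2

tokens = torch.randn(T, D)
router_logits = torch.randn(T, E)

# Each token activates only K of E experts, which is why MoE uses fewer
# parameters per token than a dense model of the same total size.
scores, expert_ids = router_logits.softmax(dim=-1).topk(K, dim=-1)

# Group (token, expert) assignments by expert so every expert can run a dense
# GEMM over its tokens; this grouping is the "shuffling" step that the GPU
# kernels in this package (index shuffling, gather/scatter ops) speed up.
flat_experts = expert_ids.flatten()                      # (T * K,)
flat_tokens = torch.arange(T).repeat_interleave(K)       # token id per assignment
order = torch.argsort(flat_experts, stable=True)
tokens_per_expert = torch.bincount(flat_experts, minlength=E)
shuffled_token_ids = flat_tokens[order]                  # token ids grouped by expert
expert_inputs = tokens[shuffled_token_ids]               # rows gathered per expert

print(tokens_per_expert.tolist(), shuffled_token_ids.tolist())
```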
fbgemm_gpu/experimental/gen_ai/moe/__init__.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import os
+
+import torch
+
+try:
+    # pyre-ignore[21]
+    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
+    from fbgemm_gpu import open_source
+except Exception:
+    open_source: bool = False
+
+# pyre-ignore[16]
+if open_source:
+    torch.ops.load_library(
+        os.path.join(
+            os.path.dirname(os.path.dirname(__file__)),
+            "fbgemm_gpu_experimental_gen_ai.so",
+        )
+    )
+    torch.classes.load_library(
+        os.path.join(
+            os.path.dirname(os.path.dirname(__file__)),
+            "fbgemm_gpu_experimental_gen_ai.so",
+        )
+    )
+else:
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:index_shuffling_ops"
+    )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:gather_scatter_ops"
+    )
+
+index_shuffling = None
+gather_along_first_dim = None
+scatter_add_along_first_dim = None
+
+if torch.cuda.is_available():
+    index_shuffling = torch.ops.fbgemm.index_shuffling # noqa F401
+    if not torch.version.hip:
+        # SM90 support
+        gather_along_first_dim = torch.ops.fbgemm.gather_along_first_dim # noqa F401
+        scatter_add_along_first_dim = torch.ops.fbgemm.scatter_add_along_first_dim # noqa F401
+
+from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import ( # noqa F401
+    grouped_gemm,
+    grouped_gemm_fp8_rowwise,
+)
+
+from .activation import silu_mul, silu_mul_quant # noqa F401
+
+from .gather_scatter import ( # noqa F401
+    gather_scale_dense_tokens,
+    gather_scale_quant_dense_tokens,
+    scatter_add_dense_tokens,
+    scatter_add_padded_tokens,
+)
+from .shuffling import combine_shuffling, split_shuffling # noqa F401
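
As the `__init__.py` above shows, `index_shuffling`, `gather_along_first_dim`, and `scatter_add_along_first_dim` are left as `None` when the corresponding kernels are not loaded (no CUDA device, or a HIP build for the gather/scatter ops). A minimal sketch of checking what actually resolved after import; the import path assumes the wheel layout listed above.

```python
import torch

from fbgemm_gpu.experimental.gen_ai import moe

# These names stay None when the kernel is unavailable on the current build
# (no CUDA device; the gather/scatter ops additionally require a non-HIP build).
for name in ("index_shuffling", "gather_along_first_dim", "scatter_add_along_first_dim"):
    op = getattr(moe, name)
    status = "available" if op is not None else "unavailable on this build"
    print(f"{name}: {status}")

print("CUDA available:", torch.cuda.is_available())
```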
fbgemm_gpu/experimental/gen_ai/moe/activation.py
@@ -0,0 +1,292 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import get_fp8_constants
+
+
+# Function APIs
+def silu_mul(
+    x0: torch.Tensor,
+    x1: torch.Tensor,
+    valid_token_count: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    Fused silu and mul operations.
+
+    y = x0 * sigmoid(x0) * x1
+
+    Args:
+        x0: input tensor of shape (T, D)
+        x1: input tensor of shape (T, D)
+        valid_token_count: tensor of shape (1,) to indicate the number of valid tokens.
+
+    Returns:
+        output tensor of shape (T, D)
+    """
+
+    assert x0.ndim == 2 and x0.stride(1) == 1
+    assert x1.ndim == 2 and x1.stride(1) == 1
+    assert x0.shape == x1.shape
+    assert x0.dtype == x1.dtype
+
+    T, D = x0.shape
+    stride_0 = x0.stride(0)
+    stride_1 = x1.stride(0)
+
+    out = torch.empty((T, D), device="cuda", dtype=x0.dtype)
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    if T >= NUM_SMS:
+        BLOCK_D_OUTER = D
+        BLOCK_D_INNER = 1024
+        assert D % BLOCK_D_INNER == 0
+    else:
+        BLOCK_D_OUTER = 512
+        BLOCK_D_INNER = 256
+        assert D % BLOCK_D_OUTER == 0
+    grid = (T, D // BLOCK_D_OUTER)
+    _fbgemm_silu_mul[grid](
+        out,
+        x0,
+        x1,
+        stride_0,
+        stride_1,
+        valid_token_count,
+        D, # pyre-ignore
+        BLOCK_D_OUTER, # pyre-ignore
+        BLOCK_D_INNER, # pyre-ignore
+    )
+    return out
+
+
+def silu_mul_quant(
+    x0: torch.Tensor,
+    x1: torch.Tensor,
+    scale_ub: Optional[torch.Tensor] = None,
+    valid_token_count: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Fused silu, mul, and FP8 rowwise quantization operations.
+
+    y, y_scale = quantize(x0 * sigmoid(x0) * x1)
+
+    Args:
+        x0: input tensor of shape (T, D)
+        x1: input tensor of shape (T, D)
+        scale_ub: tensor of shape (1,) to indicate the upper bound of the scale.
+        valid_token_count: tensor of shape (1,) to indicate the number of valid tokens.
+
+    Returns:
+        output quantized tensor of shape (T, D) and its inverse scale of shape (T,)
+    """
+
+    assert x0.ndim == 2 and x0.stride(1) == 1
+    assert x1.ndim == 2 and x1.stride(1) == 1
+    assert x0.shape == x1.shape
+    assert x0.dtype == x1.dtype
+
+    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+
+    T, D = x0.shape
+    stride_0 = x0.stride(0)
+    stride_1 = x1.stride(0)
+
+    out = torch.empty((T, D), device="cuda", dtype=pt_dtype)
+    out_inv_scale = torch.empty((T,), device="cuda", dtype=torch.float32)
+    if T == 0:
+        return out, out_inv_scale
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    BLOCK_T = triton.cdiv(T, NUM_SMS)
+
+    NUM_CTAS = triton.cdiv(T, BLOCK_T)
+
+    grid = (NUM_CTAS,)
+    _fbgemm_silu_mul_quant[grid](
+        out,
+        out_inv_scale,
+        x0,
+        x1,
+        scale_ub,
+        stride_0,
+        stride_1,
+        valid_token_count,
+        T,
+        D, # pyre-ignore
+        BLOCK_T,
+        TL_FP8_DTYPE=tl_dtype, # pyre-ignore
+        MAX_FP8=max_fp8, # pyre-ignore
+        EPS=eps, # pyre-ignore
+        CLAMP_MAX=scale_ub is not None, # pyre-ignore
+    )
+    return out, out_inv_scale
+
+
+# Torch Custom Op Registrations
+_SILU_MUL_OP_NAME = "fbgemm::silu_mul"
+
+torch.library.define(
+    "fbgemm::silu_mul",
+    "(Tensor x0, Tensor x1, Tensor? valid_token_count=None) -> Tensor",
+)
+
+
+@torch.library.impl(_SILU_MUL_OP_NAME, "Meta")
+def silu_mul_meta(x0, x1, valid_token_count):
+    return x0.new_empty(x0.shape)
+
+
+@torch.library.impl(_SILU_MUL_OP_NAME, "CUDA")
+def silu_mul_cuda(x0, x1, valid_token_count):
+    return silu_mul(x0, x1, valid_token_count)
+
+
+_SILU_MUL_OP_QUANT_NAME = "fbgemm::silu_mul_quant"
+
+torch.library.define(
+    "fbgemm::silu_mul_quant",
+    "(Tensor x0, Tensor x1, Tensor? scale_ub=None, Tensor? valid_token_count=None) -> (Tensor, Tensor)",
+)
+
+
+@torch.library.impl(_SILU_MUL_OP_QUANT_NAME, "Meta")
+def silu_mul_quant_meta(x0, x1, scale_ub, valid_token_count):
+    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+    return torch.empty(x0.shape, device=x0.device, dtype=pt_dtype)
+
+
+@torch.library.impl(_SILU_MUL_OP_QUANT_NAME, "CUDA")
+def silu_mul_quant_cuda(x0, x1, scale_ub=None, valid_token_count=None):
+    return silu_mul_quant(x0, x1, scale_ub, valid_token_count)
+
+
+# Kernel Implementations
+@triton.jit
+def _fbgemm_silu_mul(
+    y_ptr,
+    x0_ptr,
+    x1_ptr,
+    stride_0,
+    stride_1,
+    valid_token_count,
+    D: tl.constexpr,
+    BLOCK_D_OUTER: tl.constexpr,
+    BLOCK_D_INNER: tl.constexpr,
+) -> None:
+    token_index = tl.program_id(0)
+    feature_offset = tl.program_id(1) * BLOCK_D_OUTER + tl.arange(0, BLOCK_D_INNER)[:]
+
+    if valid_token_count is not None:
+        valid_token_count = tl.load(
+            valid_token_count, None, eviction_policy="evict_last"
+        )
+        if token_index >= valid_token_count:
+            return
+
+    for _ in tl.range(0, BLOCK_D_OUTER // BLOCK_D_INNER, num_stages=3):
+        x0 = tl.load(
+            x0_ptr + token_index * stride_0 + feature_offset,
+            None,
+            eviction_policy="evict_first",
+        ).to(tl.float32)
+        x1 = tl.load(
+            x1_ptr + token_index * stride_1 + feature_offset,
+            None,
+            eviction_policy="evict_first",
+        ).to(tl.float32)
+
+        y = x0 * tl.sigmoid(x0) * x1
+
+        tl.store(
+            y_ptr + token_index * D + feature_offset,
+            y,
+            None,
+        )
+        feature_offset += BLOCK_D_INNER
+
+
+@triton.jit
+def _fbgemm_silu_mul_quant(
+    y_ptr,
+    y_inv_scale_ptr,
+    x0_ptr,
+    x1_ptr,
+    scale_ub_ptr,
+    stride_0,
+    stride_1,
+    valid_token_count,
+    T,
+    D: tl.constexpr,
+    BLOCK_T: tl.constexpr,
+    TL_FP8_DTYPE: tl.constexpr,
+    MAX_FP8: tl.constexpr,
+    EPS: tl.constexpr,
+    CLAMP_MAX: tl.constexpr,
+) -> None:
+    PADDED_D: tl.constexpr = triton.next_power_of_2(D) # pyre-ignore
+
+    tidx = tl.program_id(0)
+    start_idx = tidx * BLOCK_T
+    end_idx = tl.minimum(start_idx + BLOCK_T, T)
+
+    if valid_token_count is not None:
+        valid_token_count = tl.load(
+            valid_token_count, None, eviction_policy="evict_last"
+        )
+        if start_idx >= valid_token_count:
+            return
+
+    offsets = tl.arange(0, PADDED_D)[:]
+    mask = offsets < D
+
+    if CLAMP_MAX:
+        ub = tl.load(scale_ub_ptr, eviction_policy="evict_last")
+    else:
+        ub = float("inf")
+
+    for token_index in tl.range(start_idx, end_idx, 1, num_stages=2):
+        x0 = tl.load(
+            x0_ptr + token_index * stride_0 + offsets,
+            mask,
+            eviction_policy="evict_first",
+        ).to(tl.float32)
+        x1 = tl.load(
+            x1_ptr + token_index * stride_1 + offsets,
+            mask,
+            eviction_policy="evict_first",
+        ).to(tl.float32)
+
+        y = x0 * tl.sigmoid(x0) * x1
+
+        # Masked values are set to 0.0.
+        row_max = tl.max(tl.where(mask, tl.abs(y), 0.0))
+        if CLAMP_MAX:
+            row_max = tl.clamp(row_max, EPS, ub)
+        else:
+            row_max = tl.maximum(row_max, EPS)
+
+        y_scale = MAX_FP8 / row_max
+        tl.store(y_inv_scale_ptr + token_index, 1.0 / y_scale)
+
+        y = y * y_scale
+        # Clamp A to fp8 range to make sure there's no overflow.
+        # This is required for AMD. Nvidia's default saturation
+        # handles it, but it's nice to have anyway.
+        y_fp8 = tl.clamp(y, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+        tl.store(
+            y_ptr + token_index * D + offsets,
+            y_fp8,
+            mask,
+        )
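
A short usage sketch for the two function APIs in activation.py above. The shapes, dtype, and tolerances are illustrative assumptions (D = 1024 satisfies the divisibility asserts in both branches of silu_mul, and the kernels allocate their outputs on CUDA); they are not requirements stated by the package.

```python
import torch

from fbgemm_gpu.experimental.gen_ai.moe import silu_mul, silu_mul_quant

# Illustrative sizes: D = 1024 is divisible by both 1024 and 512, so it
# passes the block-size asserts regardless of how T compares to the SM count.
T, D = 128, 1024
x0 = torch.randn(T, D, device="cuda", dtype=torch.bfloat16)
x1 = torch.randn(T, D, device="cuda", dtype=torch.bfloat16)

# Fused kernel vs. eager reference: y = x0 * sigmoid(x0) * x1.
y = silu_mul(x0, x1)
y_ref = x0.float() * torch.sigmoid(x0.float()) * x1.float()
torch.testing.assert_close(y.float(), y_ref, atol=1e-2, rtol=1e-2)

# FP8 rowwise-quantized variant: returns the quantized tensor and a per-row
# inverse scale, so y is roughly y_q.float() * y_inv_scale[:, None].
# Tolerances are loose to account for FP8 rounding.
y_q, y_inv_scale = silu_mul_quant(x0, x1)
y_deq = y_q.float() * y_inv_scale[:, None]
torch.testing.assert_close(y_deq, y_ref, atol=1e-1, rtol=1e-1)
```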