sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/mxfp4.py

@@ -1,16 +1,28 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/mxfp4.py

 from __future__ import annotations

-import importlib.util
 import logging
 from typing import TYPE_CHECKING, List, Optional

 import torch
-import triton.language as tl
 from torch.nn.parameter import Parameter

+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -27,6 +39,7 @@ from sglang.srt.utils import (
     is_hip,
     is_triton_kernels_available,
     log_info_on_rank0,
+    mxfp_supported,
     next_power_of_2,
     round_up,
     set_weight_attrs,
@@ -47,9 +60,17 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

-
+_is_hip = is_hip()
+
+if _is_hip:
+    # import aiter
+    from aiter import ActivationType, QuantType, dtypes
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.triton.quant import dynamic_mxfp4_quant
+    from aiter.utility.fp4_utils import e8m0_shuffle


 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -150,13 +171,34 @@ except AttributeError as error:

 class Mxfp4Config(QuantizationConfig):

-    def __init__(
+    def __init__(
+        self,
+        ignored_layers: Optional[list[str]] = None,
+        is_checkpoint_mxfp4_serialized: bool = False,
+    ):
         super().__init__()
+        self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized
         self.ignored_layers = ignored_layers

     @classmethod
     def from_config(cls, config):
-
+
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_mxfp4_serialized = "mxfp4" in quant_method
+
+        if _is_hip:
+            if mxfp_supported():
+                return cls(
+                    is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized
+                )
+            else:
+
+                platform = torch.cuda.get_device_properties(0).gcnArchName
+                raise ValueError(
+                    f"Current platform {platform} not support mxfp4 computation"
+                )
+
+        return cls(is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized)

     @classmethod
     def get_min_capability(cls) -> int:
@@ -174,6 +216,9 @@ class Mxfp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return []

+    def is_static_cfg(self):
+        return self.is_checkpoint_mxfp4_serialized
+
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
@@ -189,10 +234,16 @@ class Mxfp4Config(QuantizationConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return UnquantizedLinearMethod()
+            elif _is_hip:
+                return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-
+            if self.is_checkpoint_mxfp4_serialized:
+                return Mxfp4MoEMethod(prefix=prefix)
+            else:
+                return Mxfp4DynamicQuantMoEMethod()
         else:
-
+            if self.is_checkpoint_mxfp4_serialized:
+                raise NotImplementedError("Mxfp4 attention layer is not implemented")
             return None

     def get_scaled_act_names(self) -> List[str]:
@@ -205,14 +256,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self,
         prefix: str,
     ):
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
         super().__init__()

+        self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels =
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
-        self.use_flashinfer =
+        self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
+        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
+            "flashinfer_mxfp4_moe_precision"
+        ]

         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -256,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size, 64
             )
+        elif has_triton_kernels:
+            # TODO: this is a hack to make
+            # intermediate_size_per_partition_after_pad the same as the
+            # per_rank_intermediate_size during weight loading
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size, mxfp4_block
+            )

         self.intermediate_size = intermediate_size_per_partition_after_pad

@@ -332,8 +392,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.use_flashinfer:
             log_info_on_rank0(
                 logger,
-                "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
+                f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
             )
+            # TODO: these values are hardcoded for now, we need to get them from the model
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                 requires_grad=False,
@@ -559,24 +620,40 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
         if self.use_flashinfer:
-            #
-
-
-
-
+            # When bf16 mode is enabled, we don't need to quantize the input,
+            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
+            # which can theoretically improve performance
+            if self.flashinfer_mxfp4_moe_precision == "bf16":
+                assert x.dtype == torch.bfloat16
+                x_quant = x
+                x_scale = None
+
+                # May be fused later if this code branch is frequently needed
+                origin_hidden_states_dim = x_quant.shape[-1]
+                if self.hidden_size != origin_hidden_states_dim:
+                    x_quant = torch.nn.functional.pad(
+                        x_quant,
+                        (0, self.hidden_size - origin_hidden_states_dim),
+                        mode="constant",
+                        value=0.0,
+                    )
+            elif self.flashinfer_mxfp4_moe_precision == "default":
+                x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size)
+                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            else:
+                raise NotImplementedError
+
             assert x_quant.shape[-1] == self.hidden_size
+            assert TopKOutputChecker.format_is_bypassed(topk_output)

-            top_k
+            top_k = topk_output.topk_config.top_k
+            router_logits = topk_output.router_logits

             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
@@ -597,8 +674,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None, # output2_scale_scalar
                 layer.num_experts,
                 top_k,
-                None, # n_group
-                None, # topk_group
+                None, # n_group # TODO: support n_group
+                None, # topk_group # TODO: support topk_group
                 self.intermediate_size, # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
                 layer.num_local_experts, # local num experts
@@ -623,9 +700,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 b1=layer.w13_weight_bias,
                 b2=layer.w2_weight_bias,
                 topk_output=topk_output,
-
-                activation_alpha=activation_alpha,
-                swiglu_limit=swiglu_limit,
+                moe_runner_config=moe_runner_config,
             )
         else:
             return self.triton_kernel_moe_forward(
@@ -633,6 +708,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
             )
         else:
             from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -642,13 +718,120 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
                 b1=layer.w13_weight_bias,
                 b2=layer.w2_weight_bias,
-                inplace=inplace,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
-                routed_scaling_factor=routed_scaling_factor,
-                activation_alpha=activation_alpha,
-                swiglu_limit=swiglu_limit,
             )
+
+
+class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+
+    def mxfp4_quantize(self, w):
+        w_shape = w.shape
+        w_need_reshape = True if w.dim() != 2 else False
+
+        if w_need_reshape:
+            w_last_dim_size = w_shape[-1]
+            w = w.view(-1, w_last_dim_size)
+
+        w, mx_scales = dynamic_mxfp4_quant(w)
+
+        if w_need_reshape:
+            w_new_shape = w_shape[:-1] + (w.shape[-1],)
+            w = w.view(w_new_shape)
+
+        mx_scales = e8m0_shuffle(mx_scales)
+
+        return w, mx_scales
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
+        w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
+
+        layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
+
+        layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids, _ = topk_output
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            quant_type=QuantType.per_1x32,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            activation=(
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            doweight_stage1=False,
+        )
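
Note on the `apply()` signature change above: the per-call keyword arguments removed in this release (`activation`, `apply_router_weight_on_input`, `inplace`, `no_combine`, `routed_scaling_factor`, `activation_alpha`, `swiglu_limit`) are consolidated into a single `moe_runner_config: MoeRunnerConfig` parameter, with the config type added in `sglang/srt/layers/moe/moe_runner/`. The sketch below only illustrates the call-site shape; `MoeRunnerConfigSketch` and its field list are hypothetical, inferred from the removed keyword arguments and from the `moe_runner_config.activation` access in `Mxfp4DynamicQuantMoEMethod.apply`, not taken from the actual 0.5.1 definition.

```python
# Hypothetical sketch only. The real MoeRunnerConfig lives under
# sglang/srt/layers/moe/moe_runner/ (added in 0.5.1); the fields below are
# inferred from the keyword arguments this diff removes from Mxfp4MoEMethod.apply().
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfigSketch:
    activation: str = "silu"                       # observed via moe_runner_config.activation
    apply_router_weight_on_input: bool = False     # assumed, mirrors a removed kwarg
    inplace: bool = True                           # assumed, mirrors a removed kwarg
    no_combine: bool = False                       # assumed, mirrors a removed kwarg
    routed_scaling_factor: Optional[float] = None  # assumed, mirrors a removed kwarg
    activation_alpha: Optional[float] = None       # assumed, mirrors a removed kwarg
    swiglu_limit: Optional[float] = None           # assumed, mirrors a removed kwarg


def call_apply_pre_0_5_1(method, layer, x, topk_output):
    # Old call shape: each runner option passed as its own keyword argument.
    return method.apply(
        layer,
        x,
        topk_output,
        activation="silu",
        inplace=True,
        no_combine=False,
    )


def call_apply_0_5_1(method, layer, x, topk_output):
    # New call shape: one config object carries the runner options.
    cfg = MoeRunnerConfigSketch(activation="silu")
    return method.apply(layer, x, topk_output, moe_runner_config=cfg)
```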