sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/mxfp4.py (new file, +654 lines)
@@ -0,0 +1,654 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import importlib.util
import logging
from typing import TYPE_CHECKING, List, Optional

import torch
import triton.language as tl
from torch.nn.parameter import Parameter

from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
    direct_register_custom_op,
    get_bool_env_var,
    is_cuda,
    is_flashinfer_available,
    is_hip,
    is_triton_kernels_available,
    log_info_on_rank0,
    next_power_of_2,
    round_up,
    set_weight_attrs,
)

_is_sm100_supported = is_cuda() and is_sm100_supported()
has_triton_kernels = is_triton_kernels_available()


if is_flashinfer_available():
    from flashinfer import (
        mxfp8_quantize,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        trtllm_fp4_block_scale_moe,
    )

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.layers.moe.topk import TopKOutput

OCP_MX_BLOCK_SIZE = 32


def _swizzle_mxfp4(quant_tensor, scale, num_warps):
    """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
    import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
    from triton_kernels.numerics import InFlexData
    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
    from triton_kernels.tensor_details import layout

    value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
        mx_axis=1
    )
    scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
        mx_axis=1, num_warps=num_warps
    )
    if _is_sm100_supported:
        constraints = {
            "is_persistent": True,
            "epilogue_subtile": 1,
        }
        opt_flags.update_opt_flags_constraints(constraints)
    # transpose the tensor so that the quantization axis is on dim1
    quant_tensor = quant_tensor.transpose(-2, -1)
    scale = scale.transpose(-2, -1)
    quant_tensor = convert_layout(
        wrap_torch_tensor(quant_tensor, dtype=FP4), value_layout, **value_layout_opts
    )
    scale = convert_layout(wrap_torch_tensor(scale), scale_layout, **scale_layout_opts)
    return quant_tensor, InFlexData(), scale


def _dequant_mxfp4(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor:
    try:
        from quark.torch.kernel import mx
    except ImportError as err:
        raise ImportError(
            "The package `amd-quark` is required to use "
            "MX-FP4 models. Please install it with `pip install "
            "amd-quark`."
        ) from err

    return mx.dq_mxfp4(x, scale, float_dtype)


def _dequant_mxfp4_fake(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor:
    return torch.empty(
        (*x.shape[:-1], x.shape[-1] * 2), dtype=float_dtype, device=x.device
    )


def _quant_dequant_mxfp4(
    x: torch.Tensor, scale_calculation_mode: str = "even"
) -> torch.Tensor:
    try:
        from quark.torch.kernel import mx
    except ImportError as err:
        raise ImportError(
            "The package `amd-quark` is required to use "
            "MX-FP4 models. Please install it with `pip install "
            "amd-quark`."
        ) from err

    return mx.qdq_mxfp4(x, scale_calculation_mode)


def _quant_dequant_mxfp4_fake(
    x: torch.Tensor, scale_calculation_mode: str = "even"
) -> torch.Tensor:
    return torch.empty_like(x)


try:
    direct_register_custom_op(
        op_name="dequant_mxfp4",
        op_func=_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_dequant_mxfp4_fake,
    )
    dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
except AttributeError as error:
    raise error

try:
    direct_register_custom_op(
        op_name="quant_dequant_mxfp4",
        op_func=_quant_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_quant_dequant_mxfp4_fake,
    )
    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
except AttributeError as error:
    raise error


class Mxfp4Config(QuantizationConfig):

    def __init__(self, ignored_layers: Optional[list[str]] = None):
        super().__init__()
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_name(cls) -> str:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:

        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

        if isinstance(layer, LinearBase):
            if self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(prefix)
        else:
            raise NotImplementedError("Mxfp4 attention layer is not implemented")
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Mxfp4MoEMethod(FusedMoEMethodBase):

    def __init__(
        self,
        prefix: str,
    ):
        from sglang.srt.managers.schedule_batch import global_server_args_dict

        super().__init__()

        self.topk_indices_dtype = None
        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
        self.with_bias = False
        self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]

        self.triton_kernel_moe_forward = None
        self.triton_kernel_moe_with_bias_forward = None
        if torch.cuda.is_available() and has_triton_kernels:
            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                triton_kernel_moe_forward as _tk_forward,
            )
            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
            )

            self.triton_kernel_moe_forward = _tk_forward
            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        with_bias: bool = False,
        **extra_weight_attrs,
    ):
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8
        self.with_bias = with_bias
        mxfp4_block = 32

        # pad the intermediate size to be a multiple of 2 * mxfp4_block
        # for to hold non-uniform sharded tensor as well as swizzling
        intermediate_size_per_partition_after_pad = intermediate_size
        if _is_sm100_supported:
            if self.use_flashinfer:
                intermediate_size_per_partition_after_pad = round_up(
                    intermediate_size, 256
                )
                hidden_size = round_up(hidden_size, 256)
            else:
                intermediate_size_per_partition_after_pad = round_up(
                    intermediate_size, 64
                )

        self.intermediate_size = intermediate_size_per_partition_after_pad

        self.hidden_size = hidden_size
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.zeros(
                layer.num_local_experts,
                2 * intermediate_size_per_partition_after_pad,
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                layer.num_local_experts,
                2 * intermediate_size_per_partition_after_pad,
                hidden_size // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w13_weight_bias = torch.nn.Parameter(
            torch.zeros(
                layer.num_local_experts,
                2 * intermediate_size_per_partition_after_pad,
                dtype=torch.bfloat16,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_bias", w13_weight_bias)
        set_weight_attrs(w13_weight_bias, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.zeros(
                layer.num_local_experts,
                hidden_size,
                intermediate_size_per_partition_after_pad // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.zeros(
                layer.num_local_experts,
                hidden_size,
                intermediate_size_per_partition_after_pad // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        w2_weight_bias = torch.nn.Parameter(
            torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_bias", w2_weight_bias)
        set_weight_attrs(w2_weight_bias, extra_weight_attrs)

    def process_weights_after_loading(self, layer):
        if self.use_flashinfer:
            log_info_on_rank0(
                logger,
                "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
            )
            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            sf_block_size = 32  # mxfp4 block size

            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_weight_bias.dim() == 2
                and layer.w13_weight_bias.shape[0] == self.num_experts
                and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_weight_bias.dim() == 2
                and layer.w2_weight_bias.shape[0] == self.num_experts
                and layer.w2_weight_bias.shape[1] == self.hidden_size
            )

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_weight_bias.data.to(torch.float32)
            w2_bias = layer.w2_weight_bias.data.to(torch.float32)

            # Swap w1 and w3 as the definition of
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Create a new shape with pairs swapped along specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
                gemm1_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
                )
                gemm1_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(
                        w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
                    )
                )
                gemm1_bias_shuffled.append(
                    shuffle_matrix_a(
                        w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m
                    )
                )

                gemm2_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
                )
                gemm2_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(
                        w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
                    )
                )
                gemm2_bias_shuffled.append(
                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m)
                )

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
            layer.w13_weight_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
            layer.w2_weight_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
            return

        if self.use_triton_kernels:

            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
            w2_weight_bias = layer.w2_weight_bias.to(torch.float32)

            layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
            layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)

            num_warps = 8

            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
                layer.w13_weight, layer.w13_weight_scale, num_warps
            )
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
                layer.w2_weight, layer.w2_weight_scale, num_warps
            )

            self.w13_precision_config = PrecisionConfig(
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
            )
            self.w2_precision_config = PrecisionConfig(
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
            )

            self.w13_weight_triton_tensor = w13_weight
            self.w2_weight_triton_tensor = w2_weight
            del layer.w13_weight
            del layer.w2_weight
        else:
            from triton_kernels.numerics_details.mxfp import upcast_from_mxfp

            w13_weight = upcast_from_mxfp(
                layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
            )
            w2_weight = upcast_from_mxfp(
                layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
            )
            del layer.w13_weight
            del layer.w2_weight
            del layer.w13_weight_scale
            del layer.w2_weight_scale
            layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
            layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
            torch.cuda.empty_cache()

    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
        # Number of tokens in the input tensor.
        num_tokens = x.shape[0]
        # Factor to account for the imbalance of the experts.
        # factor equals to the
        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
        # - 1.0 means perfect expert distribution.
        # - > 1.0 means some experts have more
        # tokens than the perfect distribution.
        # - < 1.0 does not make sense.
        imbalance_factor = 1.3
        # Calculate the number of tokens per expert
        # assuming perfect distribution.
        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
        # Apply the imbalance factor.
        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
        # And pad the number to the next power of 2.
        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
        # Cap to 8-64 tokens per CTA tile
        # as it's the range supported by the kernel.
        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

        return tile_tokens_dim

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
        activation_alpha: Optional[float] = None,
        swiglu_limit: Optional[float] = None,
    ) -> torch.Tensor:
        if self.use_flashinfer:
            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
            x_quant, x_scale = mxfp8_quantize(
                x, False, alignment=self.hidden_size
            )  # to mxfp8
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
            assert x_quant.shape[-1] == self.hidden_size

            top_k, router_logits = topk_output

            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
                x_quant,
                x_scale,
                layer.w13_weight,  # uint8 (e2m1 x 2)
                layer.w13_weight_scale,  # uint8 (e4m3 x 2)
                layer.w13_weight_bias,  # fp32 per expert per channel
                layer.gemm1_alpha,  # fp32 per expert
                layer.gemm1_beta,  # fp32 per expert
                layer.gemm1_clamp_limit,  # fp32 per expert
                layer.w2_weight,  # uint8 (e2m1 x 2)
                layer.w2_weight_scale,  # ue8m0
                layer.w2_weight_bias,  # fp32 per expert per channel
                None,  # output1_scale_scalar
                None,  # output1_scale_gate_scalar
                None,  # output2_scale_scalar
                layer.num_experts,
                top_k,
                None,  # n_group
                None,  # topk_group
                self.intermediate_size,  # padded to multiple of 256
                layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                layer.num_local_experts,  # local num experts
                None,
                self._get_tile_tokens_dim(x, top_k),
                1,  # routing_method_type, renormalize
                True,  # do finalize
            )[0]
            return trtllm_gen_output

        if self.use_triton_kernels:
            assert (
                layer.moe_ep_size == 1
            ), "Expert parallel is not supported when using triton kernels"
            if self.with_bias:
                return self.triton_kernel_moe_with_bias_forward(
                    hidden_states=x,
                    w1=self.w13_weight_triton_tensor,
                    w1_pcg=self.w13_precision_config,
                    w2=self.w2_weight_triton_tensor,
                    w2_pcg=self.w2_precision_config,
                    b1=layer.w13_weight_bias,
                    b2=layer.w2_weight_bias,
                    topk_output=topk_output,
                    activation=activation,
                    activation_alpha=activation_alpha,
                    swiglu_limit=swiglu_limit,
                )
            else:
                return self.triton_kernel_moe_forward(
                    hidden_states=x,
                    w1=layer.w13_weight,
                    w2=layer.w2_weight,
                    topk_output=topk_output,
                )
        else:
            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_output=topk_output,
                b1=layer.w13_weight_bias,
                b2=layer.w2_weight_bias,
                inplace=inplace,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                no_combine=no_combine,
                routed_scaling_factor=routed_scaling_factor,
                activation_alpha=activation_alpha,
                swiglu_limit=swiglu_limit,
            )
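
Aside (not part of the diff above): the new module stores MXFP4 weights in the OCP MX layout, i.e. packed uint8 tensors holding two e2m1 values per byte plus one e8m0 scale byte per 32-element block (OCP_MX_BLOCK_SIZE), which is why _dequant_mxfp4_fake reports an output width of x.shape[-1] * 2. The snippet below is a minimal plain-PyTorch reference dequantizer sketching that layout; it is not shipped in the wheel, the name ref_dequant_mxfp4 is invented for illustration, and the nibble order is an assumption that may differ from the real kernels (quark's mx.dq_mxfp4, the triton_kernels path, and the FlashInfer path all do this in fused form).

import torch

# The 16 e2m1 code points (1 sign bit, 2 exponent bits, 1 mantissa bit).
_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)


def ref_dequant_mxfp4(
    packed: torch.Tensor,   # uint8, shape (..., K // 2): two e2m1 codes per byte
    scale: torch.Tensor,    # uint8, shape (..., K // 32): one e8m0 exponent per block
    float_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    # Assumed nibble order: low nibble holds the first value of each pair.
    lo = (packed & 0x0F).long()
    hi = (packed >> 4).long()
    codes = torch.stack((lo, hi), dim=-1).flatten(-2)        # (..., K)
    values = _E2M1_LUT.to(packed.device)[codes]
    # e8m0 scale is a raw biased exponent: 2 ** (byte - 127); the NaN encoding (255)
    # is ignored in this sketch.
    block_scale = torch.exp2(scale.float() - 127.0)          # (..., K // 32)
    values = values.unflatten(-1, (scale.shape[-1], 32)) * block_scale.unsqueeze(-1)
    return values.flatten(-2).to(float_dtype)

Applied to a (num_local_experts, hidden_size, intermediate_size // 2) w2_weight together with its (num_local_experts, hidden_size, intermediate_size // 32) w2_weight_scale, this yields a (num_local_experts, hidden_size, intermediate_size) bf16 tensor, i.e. the same doubling of the last dimension that _dequant_mxfp4_fake advertises.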