sglang-0.5.0rc2-py3-none-any.whl → sglang-0.5.1-py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/mxfp4.py

@@ -16,14 +16,13 @@
 
 from __future__ import annotations
 
-import importlib.util
 import logging
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
-import triton.language as tl
 from torch.nn.parameter import Parameter
 
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -40,6 +39,7 @@ from sglang.srt.utils import (
     is_hip,
     is_triton_kernels_available,
     log_info_on_rank0,
+    mxfp_supported,
     next_power_of_2,
     round_up,
     set_weight_attrs,
@@ -60,9 +60,17 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
-
+_is_hip = is_hip()
+
+if _is_hip:
+    # import aiter
+    from aiter import ActivationType, QuantType, dtypes
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.triton.quant import dynamic_mxfp4_quant
+    from aiter.utility.fp4_utils import e8m0_shuffle
 
 
 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -163,13 +171,34 @@ except AttributeError as error:
 
 class Mxfp4Config(QuantizationConfig):
 
-    def __init__(
+    def __init__(
+        self,
+        ignored_layers: Optional[list[str]] = None,
+        is_checkpoint_mxfp4_serialized: bool = False,
+    ):
         super().__init__()
+        self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized
         self.ignored_layers = ignored_layers
 
     @classmethod
     def from_config(cls, config):
-
+
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_mxfp4_serialized = "mxfp4" in quant_method
+
+        if _is_hip:
+            if mxfp_supported():
+                return cls(
+                    is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized
+                )
+            else:
+
+                platform = torch.cuda.get_device_properties(0).gcnArchName
+                raise ValueError(
+                    f"Current platform {platform} not support mxfp4 computation"
+                )
+
+        return cls(is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -187,6 +216,9 @@ class Mxfp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return []
 
+    def is_static_cfg(self):
+        return self.is_checkpoint_mxfp4_serialized
+
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
@@ -202,10 +234,16 @@ class Mxfp4Config(QuantizationConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return UnquantizedLinearMethod()
+            elif _is_hip:
+                return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-
+            if self.is_checkpoint_mxfp4_serialized:
+                return Mxfp4MoEMethod(prefix=prefix)
+            else:
+                return Mxfp4DynamicQuantMoEMethod()
         else:
-
+            if self.is_checkpoint_mxfp4_serialized:
+                raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -218,15 +256,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self,
         prefix: str,
     ):
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
         super().__init__()
 
         self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels =
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
-        self.use_flashinfer =
+        self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
+        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
+            "flashinfer_mxfp4_moe_precision"
+        ]
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -270,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size, 64
             )
+        elif has_triton_kernels:
+            # TODO: this is a hack to make
+            # intermediate_size_per_partition_after_pad the same as the
+            # per_rank_intermediate_size during weight loading
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size, mxfp4_block
+            )
 
         self.intermediate_size = intermediate_size_per_partition_after_pad
 
@@ -348,6 +394,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 logger,
                 f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
             )
+            # TODO: these values are hardcoded for now, we need to get them from the model
            layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                 requires_grad=False,
@@ -573,24 +620,40 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
         if self.use_flashinfer:
-            #
-
-
-
-
+            # When bf16 mode is enabled, we don't need to quantize the input,
+            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
+            # which can theoretically improve performance
+            if self.flashinfer_mxfp4_moe_precision == "bf16":
+                assert x.dtype == torch.bfloat16
+                x_quant = x
+                x_scale = None
+
+                # May be fused later if this code branch is frequently needed
+                origin_hidden_states_dim = x_quant.shape[-1]
+                if self.hidden_size != origin_hidden_states_dim:
+                    x_quant = torch.nn.functional.pad(
+                        x_quant,
+                        (0, self.hidden_size - origin_hidden_states_dim),
+                        mode="constant",
+                        value=0.0,
+                    )
+            elif self.flashinfer_mxfp4_moe_precision == "default":
+                x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size)
+                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            else:
+                raise NotImplementedError
+
             assert x_quant.shape[-1] == self.hidden_size
+            assert TopKOutputChecker.format_is_bypassed(topk_output)
 
-            top_k
+            top_k = topk_output.topk_config.top_k
+            router_logits = topk_output.router_logits
 
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
@@ -611,8 +674,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None,  # output2_scale_scalar
                 layer.num_experts,
                 top_k,
-                None,  # n_group
-                None,  # topk_group
+                None,  # n_group # TODO: support n_group
+                None,  # topk_group # TODO: support topk_group
                 self.intermediate_size,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
@@ -637,9 +700,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     b1=layer.w13_weight_bias,
                     b2=layer.w2_weight_bias,
                     topk_output=topk_output,
-
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                 )
             else:
                 return self.triton_kernel_moe_forward(
@@ -647,6 +708,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
+                    moe_runner_config=moe_runner_config,
                 )
         else:
             from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -656,13 +718,120 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
                 b1=layer.w13_weight_bias,
                 b2=layer.w2_weight_bias,
-                inplace=inplace,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
-                routed_scaling_factor=routed_scaling_factor,
-                activation_alpha=activation_alpha,
-                swiglu_limit=swiglu_limit,
             )
+
+
+class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+
+    def mxfp4_quantize(self, w):
+        w_shape = w.shape
+        w_need_reshape = True if w.dim() != 2 else False
+
+        if w_need_reshape:
+            w_last_dim_size = w_shape[-1]
+            w = w.view(-1, w_last_dim_size)
+
+        w, mx_scales = dynamic_mxfp4_quant(w)
+
+        if w_need_reshape:
+            w_new_shape = w_shape[:-1] + (w.shape[-1],)
+            w = w.view(w_new_shape)
+
+        mx_scales = e8m0_shuffle(mx_scales)
+
+        return w, mx_scales
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
+        w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
+
+        layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
+
+        layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids, _ = topk_output
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            quant_type=QuantType.per_1x32,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            activation=(
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            doweight_stage1=False,
+        )