sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl
This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/model_config.py +2 -1
- sglang/srt/disaggregation/mini_lb.py +2 -2
- sglang/srt/distributed/parallel_state.py +46 -41
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +5 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +3 -3
- sglang/srt/entrypoints/openai/serving_completions.py +3 -1
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
- sglang/srt/entrypoints/openai/serving_responses.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/layer.py +2 -7
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/utils.py +0 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
- sglang/srt/layers/quantization/modelopt_quant.py +35 -2
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/managers/cache_controller.py +42 -39
- sglang/srt/managers/detokenizer_manager.py +0 -34
- sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
- sglang/srt/managers/schedule_policy.py +3 -2
- sglang/srt/managers/scheduler.py +7 -100
- sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +15 -10
- sglang/srt/mem_cache/hiradix_cache.py +16 -0
- sglang/srt/mem_cache/memory_pool_host.py +18 -11
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/metrics/collector.py +12 -4
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/forward_batch_info.py +16 -17
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +245 -36
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/gpt_oss.py +5 -4
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/longcat_flash.py +26 -15
- sglang/srt/models/longcat_flash_nextn.py +23 -15
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/qwen2_moe.py +4 -1
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/server_args.py +79 -2
- sglang/srt/speculative/eagle_worker.py +158 -112
- sglang/srt/utils.py +12 -10
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
@@ -651,9 +651,6 @@ class LongcatFlashForCausalLM(nn.Module):
                     ).T
                 else:
                     w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
-                # Fix deepseek v3 blockwise bmm by using deep_gemm
                 use_deep_gemm_bmm = False

                 if w.dtype in (
@@ -790,6 +787,9 @@ class LongcatFlashForCausalLM(nn.Module):
                 self.config.hidden_size / self.config.kv_lora_rank
             ) ** 0.5

+        # TODO(linguoyuan) EPMoE not support DEEPGEMM_BLACKWELL, DeepEP needs to be supported in the future
+        deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False
+
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
@@ -804,24 +804,35 @@ class LongcatFlashForCausalLM(nn.Module):
         for layer_id in range(self.config.num_hidden_layers):
             layer = self.model.layers[layer_id]
             for i in range(2):
-
-
-
-
-
-
-
-
-                )
+                self_attn = layer.self_attn[i]
+                module_list = [
+                    self_attn.kv_b_proj,
+                    self_attn.o_proj,
+                ]
+
+                if self.config.q_lora_rank is not None:
+                    module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+                    module_list.append(self_attn.q_b_proj)
+                else:
+                    module_list.append(self_attn.kv_a_proj_with_mqa)
+                    module_list.append(self_attn.q_proj)
+
+                for module in module_list:
+                    if hasattr(module, "weight_scale_inv"):
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )
+
                 mlp = layer.mlps[i]
                 assert isinstance(mlp, LongcatFlashMLP)
                 for module in [
                     mlp.gate_up_proj,
                     mlp.down_proj,
                 ]:
-
-
-
+                    if hasattr(module, "weight_scale_inv"):
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )

         for layer_id in range(self.config.num_hidden_layers):
             experts = layer.mlp.experts
sglang/srt/models/longcat_flash_nextn.py
CHANGED
@@ -344,9 +344,6 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
                 ).T
             else:
                 w = self_attn.kv_b_proj.weight
-            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-            # This may affect the accuracy of fp8 model.
-            # Fix deepseek v3 blockwise bmm by using deep_gemm
             use_deep_gemm_bmm = False
             if w.dtype in (
                 torch.float8_e4m3fn,
@@ -480,24 +477,35 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
     def _weight_requant_ue8m0(self):
         weight_block_size = self.quant_config.weight_block_size
         layer = self.model.decoder
-
-
-
-
-
-
-
-
-            )
+        self_attn = layer.self_attn
+        module_list = [
+            self_attn.kv_b_proj,
+            self_attn.o_proj,
+        ]
+
+        if self.config.q_lora_rank is not None:
+            module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+            module_list.append(self_attn.q_b_proj)
+        else:
+            module_list.append(self_attn.kv_a_proj_with_mqa)
+            module_list.append(self_attn.q_proj)
+
+        for module in module_list:
+            if hasattr(module, "weight_scale_inv"):
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
+
         mlp = layer.mlps
         assert isinstance(mlp, LongcatFlashMLP)
         for module in [
             mlp.gate_up_proj,
             mlp.down_proj,
         ]:
-
-
-
+            if hasattr(module, "weight_scale_inv"):
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/minicpmv.py
CHANGED
@@ -54,6 +54,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
+from sglang.srt.models.llama import LlamaConfig, LlamaForCausalLM
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix, flatten_nested_list

@@ -581,7 +582,7 @@ class MiniCPMBaseModel(nn.Module):

     def init_llm(
         self,
-        config:
+        config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> nn.Module:
@@ -774,7 +775,168 @@ class MiniCPMV2_6(MiniCPMBaseModel):
         return pattern.pad_input_tokens(input_ids, image_inputs)


-
+class MiniCPMV4_0(MiniCPMBaseModel):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+        assert self.version == (4, 0)
+
+    def init_llm(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        return LlamaForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
+
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+    ) -> nn.Module:
+        model = Idefics2VisionTransformer(
+            config=config.vision_config, quant_config=quant_config, prefix=prefix
+        )
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+
+        setattr(model, "embed_dim", model.embeddings.embed_dim)
+        setattr(model, "patch_size", model.embeddings.patch_size)
+        return model
+
+    def init_resampler(
+        self,
+        embed_dim: int,
+        vision_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 2.6 remains consistent with the one in 2.5.
+            resampler = Resampler2_5(
+                num_queries=self.config.query_num,
+                embed_dim=embed_dim,
+                num_heads=embed_dim // 128,
+                kv_dim=vision_dim,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+
+    def get_vision_embedding(
+        self,
+        pixel_values: List[torch.Tensor],
+        patch_attn_mask: Optional[torch.Tensor] = None,
+        tgt_sizes: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        vision_embedding = self.vpm(
+            pixel_values,
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return vision_embedding
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.feature for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
+
+        device = self.vpm.embeddings.position_embedding.weight.device
+        dtype = self.vpm.embeddings.position_embedding.weight.dtype
+        all_pixel_values_lst = [
+            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+        ]
+
+        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+        assert isinstance(max_patches, int)
+        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+            all_pixel_values_lst, batch_first=True, padding_value=0.0
+        )
+
+        B, L, _ = all_pixel_values.shape
+        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+        patch_attn_mask = torch.zeros(
+            (B, 1, max_patches), dtype=torch.bool, device=device
+        )
+
+        tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+        mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+        patch_attn_mask[:, 0, :] = torch.arange(
+            patch_attn_mask.size(2), device=patch_attn_mask.device
+        ).unsqueeze(0) < mask_shapes.unsqueeze(1)
+
+        vision_embedding = self.vpm(
+            all_pixel_values.type(dtype),
+            patch_attention_mask=patch_attn_mask,
+            tgt_sizes=tgt_sizes,
+        )
+        return self.resampler(vision_embedding, tgt_sizes)
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+        slice_start_id: int = image_inputs.slice_start_id
+        slice_end_id: int = image_inputs.slice_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
+
+_SUPPORT_VERSION = {
+    (2, 6): MiniCPMV2_6,
+    (4, 0): MiniCPMV4_0,
+}


 class MiniCPMV:
@@ -809,7 +971,7 @@ class MiniCPMV:
         # Dispatch class based on version
         instance_class = _SUPPORT_VERSION.get(version)
         if instance_class is None:
-            raise ValueError("Currently, MiniCPMV only supports versions 2.6")
+            raise ValueError("Currently, MiniCPMV only supports versions 2.6 and 4.0")

         try:
             minicpmv = instance_class(
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -105,11 +105,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x


sglang/srt/models/qwen3.py
CHANGED
@@ -24,7 +24,10 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix, is_cuda
@@ -458,7 +461,10 @@ class Qwen3ForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
-
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)

 Qwen3MoeConfig = None

+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()

@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)

@@ -137,6 +150,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +160,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )

     def forward(
@@ -525,17 +545,28 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )

-        hidden_states = self.mlp(
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )

+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual

     def op_comm_prepare_attn(
sglang/srt/models/torch_native_llama.py
CHANGED
@@ -22,7 +22,7 @@ Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html

 Here is a quick example to enable TP:
 ```python
-from sglang.srt.model_parallel import tensor_parallel
+from sglang.srt.layers.model_parallel import tensor_parallel

 device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
 tensor_parallel(model, device_mesh)
sglang/srt/server_args.py
CHANGED
@@ -26,7 +26,7 @@ from typing import List, Literal, Optional, Union
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.lora.lora_registry import LoRARef
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     LORA_TARGET_ALL_MODULES,
     SUPPORTED_LORA_TARGET_MODULES,
@@ -195,6 +195,8 @@ class ServerArgs:
     bucket_inter_token_latency: Optional[List[float]] = None
     bucket_e2e_request_latency: Optional[List[float]] = None
     collect_tokens_histogram: bool = False
+    prompt_tokens_buckets: Optional[List[str]] = None
+    generation_tokens_buckets: Optional[List[str]] = None
     decode_log_interval: int = 40
     enable_request_time_stats_logging: bool = False
     kv_events_config: Optional[str] = None
@@ -1234,6 +1236,26 @@ class ServerArgs:
             default=ServerArgs.collect_tokens_histogram,
             help="Collect prompt/generation tokens histogram.",
         )
+        bucket_rule = (
+            "Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' "
+            "generates two sides exponential distributed buckets (e.g., 'tse 1000 2 8' generates buckets "
+            "[984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'customer <value1> "
+            "<value2> ...' uses custom bucket values (e.g., 'customer 10 50 100 500')."
+        )
+        parser.add_argument(
+            "--prompt-tokens-buckets",
+            type=str,
+            nargs="+",
+            default=ServerArgs.prompt_tokens_buckets,
+            help=f"The buckets rule of prompt tokens. {bucket_rule}",
+        )
+        parser.add_argument(
+            "--generation-tokens-buckets",
+            type=str,
+            nargs="+",
+            default=ServerArgs.generation_tokens_buckets,
+            help=f"The buckets rule for generation tokens histogram. {bucket_rule}",
+        )
         parser.add_argument(
             "--gc-warning-threshold-secs",
             type=float,
@@ -2185,6 +2207,12 @@ class ServerArgs:

         # Check multi tokenizer
         assert self.tokenizer_worker_num > 0, "Tokenizer worker num must >= 1"
+        self.validate_buckets_rule(
+            "--prompt-tokens-buckets", self.prompt_tokens_buckets
+        )
+        self.validate_buckets_rule(
+            "--generation-tokens-buckets", self.generation_tokens_buckets
+        )

     def check_lora_server_args(self):
         assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
@@ -2277,6 +2305,54 @@ class ServerArgs:
                 f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
             )

+    def validate_buckets_rule(self, arg_name: str, buckets_rule: List[str]):
+        if not buckets_rule:
+            return
+
+        assert len(buckets_rule) > 0, f"{arg_name} cannot be empty list"
+        rule = buckets_rule[0]
+        assert rule in [
+            "tse",
+            "default",
+            "customer",
+        ], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'customer'"
+
+        if rule == "tse":
+            assert (
+                len(buckets_rule) == 4
+            ), f"{arg_name} TSE rule requires exactly 4 parameters: ['tse', middle, base, count], got {len(buckets_rule)}"
+            try:
+                middle = float(buckets_rule[1])
+                base = float(buckets_rule[2])
+                count = int(buckets_rule[3])
+            except (ValueError, IndexError):
+                assert (
+                    False
+                ), f"{arg_name} TSE rule parameters must be: ['tse', <float:middle>, <float:base>, <int:count>]"
+            assert base > 1, f"{arg_name} TSE base must be larger than 1, got: {base}"
+            assert count > 0, f"{arg_name} TSE count must be positive, got: {count}"
+            assert middle > 0, f"{arg_name} TSE middle must be positive, got: {middle}"
+
+        elif rule == "default":
+            assert (
+                len(buckets_rule) == 1
+            ), f"{arg_name} default rule should only have one parameter: ['default'], got {len(buckets_rule)}"
+
+        elif rule == "customer":
+            assert (
+                len(buckets_rule) >= 2
+            ), f"{arg_name} customer rule requires at least one bucket value: ['customer', value1, ...]"
+            try:
+                bucket_values = [float(x) for x in buckets_rule[1:]]
+            except ValueError:
+                assert False, f"{arg_name} customer rule bucket values must be numeric"
+            assert len(set(bucket_values)) == len(
+                bucket_values
+            ), f"{arg_name} customer rule bucket values should not contain duplicates"
+            assert all(
+                val >= 0 for val in bucket_values
+            ), f"{arg_name} customer rule bucket values should be non-negative"
+
     def model_specific_adjustments(self):
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
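Taken together with the parser changes above, the new flags accept the rule word followed by its parameters, for example `--prompt-tokens-buckets default`, `--prompt-tokens-buckets tse 1000 2 8`, or `--generation-tokens-buckets customer 10 50 100 500`; any other rule word, a 'tse' rule without exactly three numeric parameters, or duplicate or negative 'customer' values fails these assertions at startup.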
@@ -2336,7 +2412,8 @@ class ServerArgs:
             assert self.attention_backend in {
                 "fa3",
                 "aiter",
-
+                "triton",
+            }, "fa3, aiter, or triton is required for Llama4 model"
         elif model_arch in [
             "Gemma2ForCausalLM",
             "Gemma3ForCausalLM",