sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -113,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = RMSNorm(dim, eps=1e-6)
-        self.norm2 = RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -517,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -587,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -643,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
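Note on the hunks above: they wire EAGLE3 speculative-decoding support into Qwen2.5-VL by letting the forward pass hand auxiliary hidden states to the logits processor, and by exposing set_eagle3_layers_to_capture, which by default captures three layers spread across the network. A minimal sketch of that default selection (standalone illustration, not sglang code; the 28-layer count is a hypothetical example):

from typing import List, Optional

def eagle3_capture_layers(num_hidden_layers: int, layer_ids: Optional[List[int]] = None) -> List[int]:
    # Mirrors the default in set_eagle3_layers_to_capture: an early, a middle,
    # and a late layer; explicit layer ids are shifted by one.
    if layer_ids is None:
        return [2, num_hidden_layers // 2, num_hidden_layers - 3]
    return [val + 1 for val in layer_ids]

print(eagle3_capture_layers(28))  # -> [2, 14, 25]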
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -62,13 +62,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -122,11 +125,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -138,7 +143,7 @@
             renormalize=config.norm_topk_prob,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts,
@@ -168,14 +173,7 @@
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -183,11 +181,52 @@
             shared_output = (
                 F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
             )
+        return shared_output
 
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states.clone())
+
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+
+        current_stream.wait_stream(self.alt_stream)
+
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
@@ -346,6 +385,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -528,8 +568,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config, quant_config, prefix=add_prefix("model", prefix)
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
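Note on the Qwen2MoeSparseMoeBlock changes above: for small token counts during CUDA graph capture, the shared-expert MLP and the routed-expert computation are overlapped on two CUDA streams instead of running sequentially. A minimal sketch of that stream-overlap pattern, with shared_fn and routed_fn as hypothetical stand-ins for the two computations (not sglang APIs):

import torch

def dual_stream_forward(hidden_states, shared_fn, routed_fn, alt_stream):
    current_stream = torch.cuda.current_stream()
    # Let the alternate stream see all work queued so far (e.g. the layernorm
    # that produced hidden_states).
    alt_stream.wait_stream(current_stream)
    # Shared experts run on the current stream; clone() keeps the two streams
    # from reading a buffer that later in-place ops might reuse.
    shared_out = shared_fn(hidden_states.clone())
    # Routed experts run concurrently on the alternate stream.
    with torch.cuda.stream(alt_stream):
        routed_out = routed_fn(hidden_states)
    # Rejoin before combining the results on the current stream.
    current_stream.wait_stream(alt_stream)
    return routed_out + shared_out

if torch.cuda.is_available():
    alt = torch.cuda.Stream()
    x = torch.randn(64, 1024, device="cuda")
    w_shared = torch.randn(1024, 1024, device="cuda")
    w_routed = torch.randn(1024, 1024, device="cuda")
    y = dual_stream_forward(x, lambda h: h @ w_shared, lambda h: h @ w_routed, alt)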
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
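Note: in both MoE models the expert implementation is now obtained via get_moe_impl_class(quant_config), so the choice of MoE backend can depend on the quantization configuration. A hedged sketch of that kind of dispatch, using placeholder class names rather than the actual sglang implementations:

from typing import Optional

class UnquantizedMoEImpl: ...
class QuantizedMoEImpl: ...

def pick_moe_impl_class(quant_config: Optional[object]):
    # Placeholder logic only: select a quantization-aware implementation when a
    # quant config is present, otherwise fall back to the unquantized path.
    return QuantizedMoEImpl if quant_config is not None else UnquantizedMoEImpl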