sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,84 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ from typing import Iterable, Optional, Tuple
16
+
17
+ import torch
18
+ from torch import nn
19
+ from transformers import Qwen2Config # Qwen3 uses Qwen2Config
20
+
21
+ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
22
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
23
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
24
+ from sglang.srt.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
25
+ from sglang.srt.utils import add_prefix
26
+
27
+
28
class Qwen3ForSequenceClassification(nn.Module):
    """Qwen3 backbone topped with a linear classification ("score") head.

    Every token's hidden state is projected to ``config.num_labels`` logits by
    ``self.score``; the pooler then keeps the logits of the last token of each
    sequence (``PoolingType.LAST``). The pooled logits are returned through
    ``EmbeddingPoolerOutput`` so the model is served via the embedding path.
    """

    def __init__(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the backbone, the score head and the last-token pooler.

        Args:
            config: Model config (Qwen3 reuses ``Qwen2Config``); must provide
                ``hidden_size``, ``num_labels``, ``id2label``, ``label2id``
                and ``eos_token_id``.
            quant_config: Optional quantization config forwarded to the
                backbone only (the score head is a plain ``nn.Linear``).
            prefix: Parameter-name prefix; the backbone lives under
                ``"<prefix>.model"``.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3Model(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        # Hugging Face's Qwen3ForSequenceClassification creates this head with
        # bias=False, so checkpoints only carry ``score.weight``. A biased head
        # would keep its randomly initialised bias (a standard checkpoint has
        # nothing to load into it) and shift every logit by a per-run offset.
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Use normalize=True for qwen3 embedding based on official implementation
        # Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
        # Official code: output = F.normalize(output, p=2, dim=1)
        # We don't want to normalize the output when we have a classification
        # head, i.e. whenever the config declares any label mapping.
        # NOTE(review): transformers' PretrainedConfig usually fills in default
        # id2label/label2id, so normalize is likely False in practice — confirm.
        normalize = config.id2label is None and config.label2id is None
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=normalize)

        self.eos_token_id = config.eos_token_id

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
        """Score every token and return the pooled last-token logits.

        Args:
            input_ids: Token ids passed to the backbone.
            positions: Position ids passed to the backbone.
            forward_batch: Batch metadata used by the backbone and the pooler.
            input_embeds: Optional precomputed input embeddings for the backbone.
            get_embedding: Must be True; this model only supports the
                embedding path.

        Returns:
            ``EmbeddingPoolerOutput`` wrapping the pooled per-sequence logits.

        Raises:
            AssertionError: If ``get_embedding`` is False.
        """
        assert (
            get_embedding
        ), "Qwen3ForSequenceClassification is only used for embedding"

        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        logits = self.score(hidden_states)
        pooled_logits = self.pooler(logits, forward_batch).embeddings

        return EmbeddingPoolerOutput(pooled_logits)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, skipping the causal-LM ``lm_head``.

        Loading is delegated to ``Qwen3ForCausalLM.load_weights`` with this
        module as ``self``; only entries whose name starts with ``"lm_head"``
        are dropped beforehand.

        Args:
            weights: Iterable of ``(parameter_name, tensor)`` pairs.

        Returns:
            Whatever ``Qwen3ForCausalLM.load_weights`` returns.
        """
        # Filter lazily: materializing a list would drain the checkpoint
        # iterator up front and keep every tensor it yields alive until
        # loading finishes.
        # NOTE(review): relies on Qwen3ForCausalLM.load_weights consuming its
        # input in a single pass (as an Iterable parameter implies) — confirm.
        filtered_weights = (
            (name, w) for name, w in weights if not name.startswith("lm_head")
        )
        return Qwen3ForCausalLM.load_weights(self, filtered_weights)
80
+
81
+
82
# Model classes this module exposes as entry points (presumably collected by
# sglang's model registry — confirm against the loader).
EntryClass = [Qwen3ForSequenceClassification]
@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
28
28
  get_pp_group,
29
29
  get_tensor_model_parallel_rank,
30
30
  get_tensor_model_parallel_world_size,
31
- parallel_state,
32
- split_tensor_along_last_dim,
33
- tensor_model_parallel_all_gather,
34
31
  tensor_model_parallel_all_reduce,
35
32
  )
36
33
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
37
34
  from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
38
35
  from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
39
- from sglang.srt.layers.activation import SiluAndMul
40
36
  from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
41
- from sglang.srt.layers.dp_attention import (
42
- get_attention_tp_rank,
43
- get_attention_tp_size,
44
- get_local_attention_dp_size,
45
- )
37
+ from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
46
38
  from sglang.srt.layers.layernorm import RMSNorm
47
39
  from sglang.srt.layers.linear import (
48
- MergedColumnParallelLinear,
49
40
  QKVParallelLinear,
50
41
  ReplicatedLinear,
51
42
  RowParallelLinear,
52
43
  )
53
- from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
44
+ from sglang.srt.layers.logits_processor import LogitsProcessor
45
+ from sglang.srt.layers.moe import get_moe_a2a_backend
54
46
  from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
47
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
55
48
  from sglang.srt.layers.moe.topk import TopK
56
49
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
57
50
  from sglang.srt.layers.radix_attention import RadixAttention
58
51
  from sglang.srt.layers.rotary_embedding import get_rope
59
52
  from sglang.srt.layers.utils import get_layer_id
60
- from sglang.srt.layers.vocab_parallel_embedding import (
61
- ParallelLMHead,
62
- VocabParallelEmbedding,
63
- )
53
+ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
64
54
  from sglang.srt.managers.schedule_batch import global_server_args_dict
65
55
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
66
- from sglang.srt.model_executor.forward_batch_info import (
67
- ForwardBatch,
68
- ForwardMode,
69
- PPProxyTensors,
70
- )
56
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
71
57
  from sglang.srt.model_loader.weight_utils import default_weight_loader
72
58
  from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
73
59
  from sglang.srt.models.qwen2_moe import Qwen2MoeModel
74
- from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
75
60
  from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
76
61
 
77
62
  Qwen3MoeConfig = None
@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
112
97
  intermediate_size=config.moe_intermediate_size,
113
98
  quant_config=quant_config,
114
99
  prefix=add_prefix("experts", prefix),
115
- **(
116
- dict(deepep_mode=global_server_args_dict["deepep_mode"])
117
- if global_server_args_dict["moe_a2a_backend"].is_deepep()
118
- else {}
119
- ),
120
- # Additional args for FusedMoE
121
- **(
122
- dict(
123
- enable_flashinfer_cutlass_moe=True,
124
- )
125
- if global_server_args_dict["enable_flashinfer_cutlass_moe"]
126
- else {}
127
- ),
128
100
  )
129
101
 
130
102
  self.gate = ReplicatedLinear(
@@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
135
107
  prefix=add_prefix("gate", prefix),
136
108
  )
137
109
 
138
- if global_server_args_dict["moe_a2a_backend"].is_deepep():
110
+ if get_moe_a2a_backend().is_deepep():
139
111
  # TODO: we will support tp < ep in the future
140
112
  self.ep_size = get_moe_expert_parallel_world_size()
141
113
  self.num_experts = (
@@ -144,11 +116,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
144
116
  self.top_k = config.num_experts_per_tok
145
117
 
146
118
  def forward(
147
- self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
119
+ self,
120
+ hidden_states: torch.Tensor,
121
+ forward_batch: Optional[ForwardBatch] = None,
122
+ use_reduce_scatter: bool = False,
148
123
  ) -> torch.Tensor:
149
124
 
150
- if not global_server_args_dict["moe_a2a_backend"].is_deepep():
151
- return self.forward_normal(hidden_states)
125
+ if not get_moe_a2a_backend().is_deepep():
126
+ return self.forward_normal(hidden_states, use_reduce_scatter)
152
127
  else:
153
128
  return self.forward_deepep(hidden_states, forward_batch)
154
129
 
@@ -159,7 +134,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
159
134
  if name not in ["correction_bias"]
160
135
  ]
161
136
 
162
- def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
137
+ def forward_normal(
138
+ self,
139
+ hidden_states: torch.Tensor,
140
+ use_reduce_scatter: bool = False,
141
+ ) -> torch.Tensor:
163
142
  num_tokens, hidden_dim = hidden_states.shape
164
143
  hidden_states = hidden_states.view(-1, hidden_dim)
165
144
 
@@ -167,7 +146,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
167
146
  router_logits, _ = self.gate(hidden_states)
168
147
  topk_output = self.topk(hidden_states, router_logits)
169
148
  final_hidden_states = self.experts(hidden_states, topk_output)
170
- if self.tp_size > 1:
149
+ if self.tp_size > 1 and not use_reduce_scatter:
171
150
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
172
151
 
173
152
  return final_hidden_states.view(num_tokens, hidden_dim)
@@ -484,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
484
463
 
485
464
  self.attn_tp_size = get_attention_tp_size()
486
465
  self.attn_tp_rank = get_attention_tp_rank()
487
- self.local_dp_size = get_local_attention_dp_size()
488
466
 
489
467
  # Qwen3MoE all layers are sparse and have no nextn now
490
468
  self.is_layer_sparse = True
@@ -521,6 +499,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
521
499
  layer_scatter_modes=self.layer_scatter_modes,
522
500
  input_layernorm=self.input_layernorm,
523
501
  post_attention_layernorm=self.post_attention_layernorm,
502
+ allow_reduce_scatter=True,
524
503
  )
525
504
 
526
505
  def forward(
@@ -546,7 +525,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
546
525
  hidden_states, residual, forward_batch
547
526
  )
548
527
 
549
- hidden_states = self.mlp(hidden_states, forward_batch)
528
+ # For DP with padding, reduce scatter can be used instead of all-reduce.
529
+ use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
530
+ forward_batch
531
+ )
532
+
533
+ hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
550
534
 
551
535
  hidden_states, residual = self.layer_communicator.postprocess_layer(
552
536
  hidden_states, residual, forward_batch
@@ -765,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module):
765
749
  ("gate_up_proj", "up_proj", 1),
766
750
  ]
767
751
 
768
- expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
752
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
769
753
  ckpt_gate_proj_name="gate_proj",
770
754
  ckpt_down_proj_name="down_proj",
771
755
  ckpt_up_proj_name="up_proj",
@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
25
25
  from sglang.srt.layers.activation import SiluAndMul
26
26
  from sglang.srt.layers.attention.vision import VisionAttention
27
27
  from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
28
- from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
28
+ from sglang.srt.layers.dp_attention import (
29
+ get_attention_tp_rank,
30
+ get_attention_tp_size,
31
+ is_dp_attention_enabled,
32
+ )
29
33
  from sglang.srt.layers.layernorm import RMSNorm
30
34
  from sglang.srt.layers.linear import (
31
35
  ColumnParallelLinear,
@@ -34,6 +38,7 @@ from sglang.srt.layers.linear import (
34
38
  RowParallelLinear,
35
39
  )
36
40
  from sglang.srt.layers.logits_processor import LogitsProcessor
41
+ from sglang.srt.layers.moe import get_moe_a2a_backend
37
42
  from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
38
43
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
39
44
  from sglang.srt.layers.moe.topk import TopK
@@ -146,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
146
151
  prefix=add_prefix("gate", prefix),
147
152
  )
148
153
 
149
- if global_server_args_dict["moe_a2a_backend"].is_deepep():
154
+ if get_moe_a2a_backend().is_deepep():
150
155
  raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
151
156
 
152
157
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -437,7 +442,7 @@ class Step3TextModel(nn.Module):
437
442
  self.embed_tokens = VocabParallelEmbedding(
438
443
  config.vocab_size,
439
444
  config.hidden_size,
440
- enable_tp=not global_server_args_dict["enable_dp_attention"],
445
+ enable_tp=not is_dp_attention_enabled(),
441
446
  prefix=add_prefix("embed_tokens", prefix),
442
447
  )
443
448
 
@@ -33,7 +33,9 @@ from sglang.srt.layers.linear import (
33
33
  RowParallelLinear,
34
34
  )
35
35
  from sglang.srt.layers.logits_processor import LogitsProcessor
36
- from sglang.srt.layers.moe.fused_moe_triton import fused_moe
36
+ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
37
+ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
38
+ from sglang.srt.layers.moe.topk import TopK
37
39
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
38
40
  from sglang.srt.layers.radix_attention import RadixAttention
39
41
  from sglang.srt.layers.rotary_embedding import get_rope
@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
121
123
  ]
122
124
  )
123
125
  self.pack_params()
126
+ self.moe_runner_config = MoeRunnerConfig(inplace=True)
124
127
 
125
128
  self.router = ReplicatedLinear(
126
129
  config.hidden_size,
@@ -129,6 +132,10 @@ class XverseMoE(nn.Module):
129
132
  quant_config=None,
130
133
  prefix=add_prefix("router", prefix),
131
134
  )
135
+ self.topk = TopK(
136
+ top_k=self.top_k,
137
+ renormalize=getattr(self.config, "norm_topk_prob", False),
138
+ )
132
139
 
133
140
  if config.num_shared_experts is not None:
134
141
  intermediate_size = config.intermediate_size * config.num_shared_experts
@@ -167,14 +174,13 @@ class XverseMoE(nn.Module):
167
174
  shared_output = self.shared_experts(hidden_states)
168
175
  # router_logits: (num_tokens, n_experts)
169
176
  router_logits, _ = self.router(hidden_states)
177
+ topk_output = self.topk(hidden_states, router_logits)
170
178
  final_hidden_states = fused_moe(
171
179
  hidden_states,
172
180
  self.w1,
173
181
  self.w2,
174
- router_logits,
175
- self.top_k,
176
- renormalize=getattr(self.config, "norm_topk_prob", False),
177
- inplace=True,
182
+ topk_output,
183
+ self.moe_runner_config,
178
184
  )
179
185
 
180
186
  if self.config.num_shared_experts is not None:
@@ -217,9 +217,9 @@ class BaseMultimodalProcessor(ABC):
217
217
  if videos:
218
218
  kwargs["videos"] = videos
219
219
  if audios:
220
- if self.arch in {
221
- "Gemma3nForConditionalGeneration",
222
- "Qwen2AudioForConditionalGeneration",
220
+ if self._processor.__class__.__name__ in {
221
+ "Gemma3nProcessor",
222
+ "Qwen2AudioProcessor",
223
223
  }:
224
224
  # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
225
225
  kwargs["audio"] = audios
@@ -44,7 +44,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
44
44
  self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
45
45
  self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
46
46
  self.mm_tokens = MultimodalSpecialTokens(
47
- image_token="<image>",
47
+ image_token="<IMG_CONTEXT>",
48
48
  image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
49
49
  ).build(_image_processor)
50
50
 
@@ -218,13 +218,18 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
218
218
 
219
219
  pixel_values = torch.cat(pixel_values, dim=0)
220
220
 
221
+ original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
222
+ input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)
223
+
221
224
  for idx, num_patches in enumerate(num_patches_list):
222
225
  image_tokens = (
223
226
  self.IMG_START_TOKEN
224
227
  + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
225
228
  + self.IMG_END_TOKEN
226
229
  )
227
- input_text = input_text.replace("<image>", image_tokens, 1)
230
+ input_text = input_text.replace(original_placeholder, image_tokens, 1)
231
+
232
+ input_text = input_text.replace(original_placeholder, self.IMG_CONTEXT_TOKEN)
228
233
 
229
234
  input_ids = self.tokenizer(input_text, return_tensors="pt")[
230
235
  "input_ids"
@@ -18,7 +18,7 @@ from sglang.srt.models.llavavid import LlavaVidForCausalLM
18
18
  from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
19
19
  from sglang.srt.multimodal.mm_utils import expand2square, process_anyres_image
20
20
  from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
21
- from sglang.srt.utils import load_image, logger
21
+ from sglang.srt.utils import ImageData, load_image, logger
22
22
  from sglang.utils import get_exception_traceback
23
23
 
24
24
 
@@ -35,7 +35,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
35
35
 
36
36
  @staticmethod
37
37
  def _process_single_image_task(
38
- image_data: Union[str, bytes],
38
+ image_data: Union[str, bytes, ImageData],
39
39
  image_aspect_ratio: Optional[str] = None,
40
40
  image_grid_pinpoints: Optional[str] = None,
41
41
  processor=None,
@@ -44,10 +44,11 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
44
44
  image_processor = processor.image_processor
45
45
 
46
46
  try:
47
- image, image_size = load_image(image_data)
47
+ url = image_data.url if isinstance(image_data, ImageData) else image_data
48
+ image, image_size = load_image(url)
48
49
  if image_size is not None:
49
50
  # It is a video with multiple images
50
- image_hash = hash(image_data)
51
+ image_hash = hash(url)
51
52
  pixel_values = image_processor(image)["pixel_values"]
52
53
  for _ in range(len(pixel_values)):
53
54
  pixel_values[_] = pixel_values[_].astype(np.float16)
@@ -55,7 +56,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
55
56
  return pixel_values, image_hash, image_size
56
57
  else:
57
58
  # It is an image
58
- image_hash = hash(image_data)
59
+ image_hash = hash(url)
59
60
  if image_aspect_ratio == "pad":
60
61
  image = expand2square(
61
62
  image,
@@ -82,7 +83,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
82
83
  logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
83
84
 
84
85
  async def _process_single_image(
85
- self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
86
+ self,
87
+ image_data: Union[bytes, str, ImageData],
88
+ aspect_ratio: str,
89
+ grid_pinpoints: str,
86
90
  ):
87
91
  if self.cpu_executor is not None:
88
92
  loop = asyncio.get_event_loop()
@@ -104,7 +108,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
104
108
 
105
109
  async def process_mm_data_async(
106
110
  self,
107
- image_data: List[Union[str, bytes]],
111
+ image_data: List[Union[str, bytes, ImageData]],
108
112
  input_text,
109
113
  request_obj,
110
114
  *args,