sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/olmoe.py
@@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module):
             intermediate_size=intermediate_size,
             reduce_results=True,
             quant_config=quant_config,
-            tp_size=tp_size,
             layer_id=layer_id,
             prefix=add_prefix("experts", prefix),
         )
sglang/srt/models/phi4mm.py
@@ -54,25 +54,6 @@ VISION_ENCODER_TO_PROCESSING_CONFIG = {
 }


-def get_navit_vision_model():
-    vision_config = {
-        "hidden_size": 1152,
-        "image_size": 448,
-        "intermediate_size": 4304,
-        "model_type": "siglip_vision_model",
-        "num_attention_heads": 16,
-        "num_hidden_layers": 26,  # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
-        "patch_size": 14,
-    }
-    model_config = SiglipVisionConfig(**vision_config)
-
-    vision_model = Idefics2VisionTransformer(
-        config=model_config, require_post_norm=False
-    )
-
-    return vision_model
-
-
 class Phi4MMImageEncoder(nn.Module):
     """Image embedding."""

@@ -88,8 +69,9 @@ class Phi4MMImageEncoder(nn.Module):
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
         self.type_feature = "patch"
-
-        self.img_processor = get_navit_vision_model()
+        self.img_processor = Idefics2VisionTransformer(
+            config=config.vision_config, require_post_norm=False
+        )

         pe_weight = self.img_processor.embeddings.position_embedding.weight
         L, D = pe_weight.size()
sglang/srt/models/qwen2_5_vl.py
@@ -117,6 +117,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        num_dummy_heads: int = 0,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -157,6 +158,7 @@
             flatten_batch=flatten_batch,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
+            num_dummy_heads=num_dummy_heads,
         )
         self.mlp = Qwen2_5_VLMLP(
             dim,
sglang/srt/models/qwen2_moe.py
@@ -17,8 +17,6 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""

 import logging
-from dataclasses import dataclass
-from enum import Enum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.eplb.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
@@ -45,7 +40,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -55,8 +49,8 @@ from sglang.srt.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -149,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
         )

         self.gate = ReplicatedLinear(
@@ -340,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):

         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # Qwen2MoE all layers are sparse and have no nextn now
         self.is_layer_sparse = True
sglang/srt/models/qwen3.py
@@ -327,8 +327,8 @@ class Qwen3ForCausalLM(nn.Module):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.get_input_embeddings()

     @torch.no_grad()
     def forward(
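
Note: the hunk above changes the accessor's contract. A minimal, self-contained sketch of the difference (plain torch with a bare nn.Embedding standing in for the model's embedding layer; not sglang code):

    # Sketch only: old contract returned the looked-up tensor,
    # new contract returns the nn.Embedding module itself and the
    # caller performs the lookup explicitly.
    import torch
    import torch.nn as nn

    embed = nn.Embedding(num_embeddings=8, embedding_dim=4)

    input_ids = torch.tensor([1, 2, 3])
    hidden_states = embed(input_ids)  # caller applies the returned module
    print(hidden_states.shape)  # torch.Size([3, 4])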
sglang/srt/models/qwen3_classification.py
@@ -42,7 +42,13 @@ class Qwen3ForSequenceClassification(nn.Module):
         # Use normalize=True for qwen3 embedding based on official implementation
         # Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
         # Official code: output = F.normalize(output, p=2, dim=1)
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        normalize = True
+
+        # We don't want to normalize the embedding if we have a classification head
+        if config.id2label is not None or config.label2id is not None:
+            normalize = False
+
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=normalize)

         self.eos_token_id = config.eos_token_id

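An illustrative one-liner (toy values, not from the model) on why the pooled output is left unnormalized when a classification head is present: L2 normalization rescales every row to unit norm, discarding the magnitude information that classification logits depend on.

    import torch
    import torch.nn.functional as F

    # Two pooled vectors that differ only in magnitude become identical
    # after L2 normalization.
    pooled = torch.tensor([[3.0, 4.0], [0.3, 0.4]])
    print(F.normalize(pooled, p=2, dim=1))
    # tensor([[0.6000, 0.8000],
    #         [0.6000, 0.8000]])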
sglang/srt/models/qwen3_moe.py
@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    parallel_state,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    get_local_attention_dp_size,
-)
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
-    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import get_layer_id
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty

 Qwen3MoeConfig = None
@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
         )

         self.gate = ReplicatedLinear(
@@ -135,7 +107,7 @@
             prefix=add_prefix("gate", prefix),
         )

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -150,7 +122,7 @@
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

-        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if not get_moe_a2a_backend().is_deepep():
             return self.forward_normal(hidden_states, use_reduce_scatter)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -491,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):

         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # Qwen3MoE all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -778,7 +749,7 @@
             ("gate_up_proj", "up_proj", 1),
         ]

-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
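
The recurring change above (and in step3_vl.py below) swaps raw `global_server_args_dict["moe_a2a_backend"]` lookups for the `get_moe_a2a_backend()` accessor from `sglang.srt.layers.moe`. A generic, runnable sketch of that pattern (hypothetical enum values; not sglang's actual implementation):

    # Hypothetical sketch of a typed accessor replacing ad-hoc dict lookups.
    from enum import Enum

    class MoeA2ABackend(Enum):
        NONE = "none"
        DEEPEP = "deepep"

        def is_deepep(self) -> bool:
            return self is MoeA2ABackend.DEEPEP

    _moe_a2a_backend = MoeA2ABackend.NONE  # set once from server args at startup

    def get_moe_a2a_backend() -> MoeA2ABackend:
        return _moe_a2a_backend

    # Call sites then read like the diffed code:
    if not get_moe_a2a_backend().is_deepep():
        print("normal MoE forward path")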
sglang/srt/models/step3_vl.py
@@ -38,6 +38,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -150,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
sglang/srt/models/xverse_moe.py
@@ -33,7 +33,9 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
             ]
         )
         self.pack_params()
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)

         self.router = ReplicatedLinear(
             config.hidden_size,
@@ -129,6 +132,10 @@
             quant_config=None,
             prefix=add_prefix("router", prefix),
         )
+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=getattr(self.config, "norm_topk_prob", False),
+        )

         if config.num_shared_experts is not None:
             intermediate_size = config.intermediate_size * config.num_shared_experts
@@ -167,14 +174,13 @@
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.w1,
             self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=getattr(self.config, "norm_topk_prob", False),
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )

         if self.config.num_shared_experts is not None:
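
As these hunks show, `fused_moe` no longer takes `router_logits`/`top_k`/`renormalize`/`inplace` as loose arguments: routing is computed once by a `TopK` module and passed in as `topk_output`, and execution flags travel in a `MoeRunnerConfig`. A toy, self-contained analogue of that design (the dataclass and `topk_route` helper are stand-ins, not sglang's implementation):

    from dataclasses import dataclass
    import torch

    @dataclass
    class MoeRunnerConfig:  # stand-in for sglang.srt.layers.moe.moe_runner
        inplace: bool = True

    def topk_route(hidden_states, router_logits, top_k, renormalize):
        # hidden_states is unused here; it mirrors the TopK call shape above.
        weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)
        if renormalize:
            weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, ids

    router_logits = torch.randn(4, 8)  # (num_tokens, n_experts)
    topk_output = topk_route(None, router_logits, top_k=2, renormalize=False)
    cfg = MoeRunnerConfig(inplace=True)
    print(topk_output[1], cfg)  # routing result + config travel as two values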
sglang/srt/multimodal/processors/base_processor.py
@@ -217,9 +217,9 @@ class BaseMultimodalProcessor(ABC):
         if videos:
             kwargs["videos"] = videos
         if audios:
-            if self.arch in {
-                "Gemma3nForConditionalGeneration",
-                "Qwen2AudioForConditionalGeneration",
+            if self._processor.__class__.__name__ in {
+                "Gemma3nProcessor",
+                "Qwen2AudioProcessor",
             }:
                 # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
                 kwargs["audio"] = audios
sglang/srt/multimodal/processors/internvl.py
@@ -44,7 +44,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
         self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
         self.mm_tokens = MultimodalSpecialTokens(
-            image_token="<image>",
+            image_token="<IMG_CONTEXT>",
             image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
         ).build(_image_processor)

@@ -218,13 +218,18 @@ class InternVLImageProcessor(BaseMultimodalProcessor):

         pixel_values = torch.cat(pixel_values, dim=0)

+        original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
+        input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)
+
         for idx, num_patches in enumerate(num_patches_list):
             image_tokens = (
                 self.IMG_START_TOKEN
                 + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                 + self.IMG_END_TOKEN
             )
-            input_text = input_text.replace("<image>", image_tokens, 1)
+            input_text = input_text.replace(original_placeholder, image_tokens, 1)
+
+        input_text = input_text.replace(original_placeholder, self.IMG_CONTEXT_TOKEN)

         input_ids = self.tokenizer(input_text, return_tensors="pt")[
             "input_ids"
sglang/srt/multimodal/processors/llava.py
@@ -18,7 +18,7 @@ from sglang.srt.models.llavavid import LlavaVidForCausalLM
 from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
 from sglang.srt.multimodal.mm_utils import expand2square, process_anyres_image
 from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
-from sglang.srt.utils import load_image, logger
+from sglang.srt.utils import ImageData, load_image, logger
 from sglang.utils import get_exception_traceback


@@ -35,7 +35,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):

     @staticmethod
     def _process_single_image_task(
-        image_data: Union[str, bytes],
+        image_data: Union[str, bytes, ImageData],
         image_aspect_ratio: Optional[str] = None,
         image_grid_pinpoints: Optional[str] = None,
         processor=None,
@@ -44,10 +44,11 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         image_processor = processor.image_processor

         try:
-            image, image_size = load_image(image_data)
+            url = image_data.url if isinstance(image_data, ImageData) else image_data
+            image, image_size = load_image(url)
             if image_size is not None:
                 # It is a video with multiple images
-                image_hash = hash(image_data)
+                image_hash = hash(url)
                 pixel_values = image_processor(image)["pixel_values"]
                 for _ in range(len(pixel_values)):
                     pixel_values[_] = pixel_values[_].astype(np.float16)
@@ -55,7 +56,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
                 return pixel_values, image_hash, image_size
             else:
                 # It is an image
-                image_hash = hash(image_data)
+                image_hash = hash(url)
                 if image_aspect_ratio == "pad":
                     image = expand2square(
                         image,
@@ -82,7 +83,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())

     async def _process_single_image(
-        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+        self,
+        image_data: Union[bytes, str, ImageData],
+        aspect_ratio: str,
+        grid_pinpoints: str,
     ):
         if self.cpu_executor is not None:
             loop = asyncio.get_event_loop()
@@ -104,7 +108,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):

     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes]],
+        image_data: List[Union[str, bytes, ImageData]],
         input_text,
         request_obj,
         *args,
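
The llava.py hunks all widen the accepted input type to include ImageData and normalize it to a URL string up front. A self-contained sketch of that normalization (the dataclass below is a stand-in; the real ImageData lives in sglang.srt.utils):

    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class ImageData:  # stand-in for sglang.srt.utils.ImageData
        url: str

    def to_url(image_data: Union[str, bytes, ImageData]) -> Union[str, bytes]:
        # Mirrors the diffed line: unwrap ImageData, pass strings/bytes through.
        return image_data.url if isinstance(image_data, ImageData) else image_data

    assert to_url(ImageData(url="https://example.com/cat.png")) == "https://example.com/cat.png"
    assert to_url(b"raw-image-bytes") == b"raw-image-bytes"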