sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/phimoe.py CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -200,15 +201,19 @@ class PhiMoE(nn.Module):
             quant_config=None,
         )
 
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+            custom_routing_function=phimoe_routing_function,
+        )
+
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
-            renormalize=False,
             quant_config=quant_config,
-            custom_routing_function=phimoe_routing_function,
             prefix=add_prefix("experts", prefix),
         )
 
@@ -219,7 +224,8 @@ class PhiMoE(nn.Module):
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         return final_hidden_states.view(orig_shape)
 
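Note: the phimoe.py change above is the first of a refactor repeated in qwen2_moe.py and qwen3_moe.py below: top-k expert routing moves out of the FusedMoE layer into a standalone TopK module, and the forward pass hands the precomputed routing result to the experts instead of raw router logits. A minimal sketch of the pattern using a toy routing module (sglang's real TopK signature is only partially visible in this diff, so this is an illustration, not the actual class):

import torch
import torch.nn as nn

class ToyTopK(nn.Module):
    """Toy stand-in for sglang's TopK: turns router logits into
    (weights, expert indices), optionally renormalizing the weights."""

    def __init__(self, top_k: int, renormalize: bool = False):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        # hidden_states is unused here; it is kept to mirror the call shape
        # seen in the diff: self.topk(hidden_states, router_logits).
        probs = torch.softmax(router_logits, dim=-1)
        topk_weights, topk_idx = probs.topk(self.top_k, dim=-1)
        if self.renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_idx

# Old pattern: experts(hidden_states, router_logits) -- routing hidden inside the MoE layer.
# New pattern: routing runs first, its output is passed to the experts.
hidden_states = torch.randn(4, 16)          # (num_tokens, hidden_size)
router_logits = torch.randn(4, 8)           # (num_tokens, num_experts)
topk_output = ToyTopK(top_k=2)(hidden_states, router_logits)
# experts(hidden_states, topk_output) would then consume the precomputed routing.

Decoupling routing from expert execution presumably lets the different MoE backends (fused Triton, EP/DeepEP) consume one common topk_output format.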
sglang/srt/models/qwen.py CHANGED
@@ -15,6 +15,7 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
 
+import time
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -286,6 +287,42 @@ class QWenLMHeadModel(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            forward_batch.hidden_states = self.transformer.wte(input_ids)
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.transformer.h[i]
+            forward_batch.hidden_states = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+            )
+
+        if end == self.transformer.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states = self.transformer.ln_f(
+                forward_batch.hidden_states
+            )
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
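Note: forward_split_prefill, added here and in the same shape to qwen2.py, qwen2_moe.py, qwen3.py, and qwen3_moe.py below, executes a half-open slice [start, end) of decoder layers per call, parking intermediate activations on forward_batch.hidden_states between calls; only the call that reaches the last layer runs the final norm and logits processor. A hedged sketch of how a caller might drive it; the chunk size and surrounding setup are hypothetical, not taken from this diff:

# Hypothetical driver loop: run `step` decoder layers at a time.
# Intermediate calls return None; the final slice returns logits.
def run_split_prefill(model, input_ids, positions, forward_batch, step=8):
    num_layers = model.transformer.config.num_hidden_layers  # QWen layout assumed
    result = None
    for start in range(0, num_layers, step):
        end = min(start + step, num_layers)
        result = model.forward_split_prefill(
            input_ids, positions, forward_batch, (start, end)
        )
    return result  # LogitsProcessorOutput from the final slice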
sglang/srt/models/qwen2.py CHANGED
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -497,7 +497,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -508,9 +508,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-            [getattr(item, "pixel_values_videos") for item in items], dim=0
-        ).type(self.visual.dtype)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
sglang/srt/models/qwen2_audio.py CHANGED
@@ -118,7 +118,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
 
     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
-        input_features = torch.cat([item.audio_features for item in items], dim=0).type(
+        input_features = torch.cat([item.feature for item in items], dim=0).type(
             self.audio_tower.dtype
         )
 
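Note: the two hunks above, together with the qwen2_vl.py and vila.py hunks below, apply one rename: the modality-specific fields (pixel_values, pixel_values_videos, audio_features) are unified into a single MultimodalDataItem.feature attribute, so every extractor concatenates the same field regardless of modality. A toy illustration of the accessor change (the dataclass is a stand-in, not sglang's actual MultimodalDataItem):

from dataclasses import dataclass
import torch

@dataclass
class ToyItem:
    # Before: pixel_values / pixel_values_videos / audio_features per modality.
    # After: one generic tensor payload.
    feature: torch.Tensor

items = [ToyItem(torch.randn(3, 128)), ToyItem(torch.randn(2, 128))]
# The same concatenation now works for image, video, and audio items alike:
batched = torch.cat([item.feature for item in items], dim=0)  # shape (5, 128)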
@@ -43,10 +43,6 @@ from sglang.srt.layers.communicator import (
43
43
  ScatterMode,
44
44
  )
45
45
  from sglang.srt.layers.dp_attention import (
46
- attn_tp_all_gather,
47
- attn_tp_reduce_scatter,
48
- dp_gather_partial,
49
- dp_scatter,
50
46
  get_attention_tp_rank,
51
47
  get_attention_tp_size,
52
48
  get_local_attention_dp_size,
@@ -61,6 +57,7 @@ from sglang.srt.layers.linear import (
61
57
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
62
58
  from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
63
59
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
60
+ from sglang.srt.layers.moe.topk import TopK
64
61
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
65
62
  from sglang.srt.layers.radix_attention import RadixAttention
66
63
  from sglang.srt.layers.rotary_embedding import get_rope
@@ -134,13 +131,17 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
134
131
  f"the number of experts {config.num_experts}."
135
132
  )
136
133
 
134
+ self.topk = TopK(
135
+ top_k=config.num_experts_per_tok,
136
+ renormalize=config.norm_topk_prob,
137
+ )
138
+
137
139
  self.experts = get_moe_impl_class()(
138
140
  layer_id=self.layer_id,
139
- num_experts=config.num_experts,
140
141
  top_k=config.num_experts_per_tok,
142
+ num_experts=config.num_experts,
141
143
  hidden_size=config.hidden_size,
142
144
  intermediate_size=config.moe_intermediate_size,
143
- renormalize=config.norm_topk_prob,
144
145
  quant_config=quant_config,
145
146
  prefix=add_prefix("experts", prefix),
146
147
  # Additional args for FusedMoE
@@ -189,9 +190,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
189
190
 
190
191
  # router_logits: (num_tokens, n_experts)
191
192
  router_logits, _ = self.gate(hidden_states)
192
- final_hidden_states = self.experts(
193
- hidden_states=hidden_states, router_logits=router_logits
194
- )
193
+ topk_output = self.topk(hidden_states, router_logits)
194
+ final_hidden_states = self.experts(hidden_states, topk_output)
195
195
  if shared_output is not None:
196
196
  final_hidden_states = final_hidden_states + shared_output
197
197
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -406,6 +406,7 @@ class Qwen2MoeModel(nn.Module):
406
406
  alt_stream: Optional[torch.cuda.Stream] = None,
407
407
  ) -> None:
408
408
  super().__init__()
409
+ self.config = config
409
410
  self.padding_idx = config.pad_token_id
410
411
  self.vocab_size = config.vocab_size
411
412
  self.pp_group = get_pp_group()
@@ -554,6 +555,49 @@ class Qwen2MoeForCausalLM(nn.Module):
554
555
  else:
555
556
  return hidden_states
556
557
 
558
+ @torch.no_grad()
559
+ def forward_split_prefill(
560
+ self,
561
+ input_ids: torch.Tensor,
562
+ positions: torch.Tensor,
563
+ forward_batch: ForwardBatch,
564
+ split_interval: Tuple[int, int], # [start, end) 0-based
565
+ input_embeds: torch.Tensor = None,
566
+ ):
567
+ start, end = split_interval
568
+ # embed
569
+ if start == 0:
570
+ if input_embeds is None:
571
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
572
+ else:
573
+ forward_batch.hidden_states = input_embeds
574
+
575
+ # decoder layer
576
+ for i in range(start, end):
577
+ with get_global_expert_distribution_recorder().with_current_layer(i):
578
+ layer = self.model.layers[i]
579
+ forward_batch.hidden_states, forward_batch.residual = layer(
580
+ positions,
581
+ forward_batch.hidden_states,
582
+ forward_batch,
583
+ forward_batch.residual,
584
+ )
585
+
586
+ if end == self.model.config.num_hidden_layers:
587
+ # norm
588
+ hidden_states, _ = self.model.norm(
589
+ forward_batch.hidden_states, forward_batch.residual
590
+ )
591
+ forward_batch.hidden_states = hidden_states
592
+ # logits process
593
+ result = self.logits_processor(
594
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
595
+ )
596
+ else:
597
+ result = None
598
+
599
+ return result
600
+
557
601
  @property
558
602
  def start_layer(self):
559
603
  return self.model.start_layer
sglang/srt/models/qwen2_vl.py CHANGED
@@ -484,7 +484,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -495,9 +495,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-            [item.pixel_values_videos for item in items], dim=0
-        ).type(self.visual.dtype)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
sglang/srt/models/qwen3.py CHANGED
@@ -1,5 +1,4 @@
 # Adapted from qwen2.py
-
 import logging
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
@@ -331,6 +330,30 @@ class Qwen3ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
+    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
     @torch.no_grad()
     def forward(
         self,
@@ -367,6 +390,47 @@ class Qwen3ForCausalLM(nn.Module):
         else:
             return hidden_states
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
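Note: the new get_hidden_dim reports an (input_dim, output_dim) pair per projection; Qwen3 configures head_dim independently of hidden_size // num_attention_heads, hence the explicit head_dim arithmetic. Presumably the LoRA machinery touched elsewhere in this diff (lora_manager.py, mem_pool.py) sizes adapter buffers from these pairs. An illustration of the shapes it reports, with made-up config numbers rather than a real Qwen3 checkpoint:

# Hypothetical Qwen3-like config values, only to show what get_hidden_dim returns.
hidden_size, head_dim = 1024, 128
num_attention_heads, num_key_value_heads = 16, 4
intermediate_size = 2816

dims = {
    "q_proj": (hidden_size, head_dim * num_attention_heads),   # (1024, 2048)
    "kv_proj": (hidden_size, head_dim * num_key_value_heads),  # (1024, 512)
    "o_proj": (head_dim * num_attention_heads, hidden_size),   # (2048, 1024)
    "gate_up_proj": (hidden_size, intermediate_size),          # (1024, 2816)
    "down_proj": (intermediate_size, hidden_size),             # (2816, 1024)
}
print(dims["q_proj"])  # LoRA A/B matrices for q_proj would be sized from this pair.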
@@ -38,10 +38,6 @@ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
38
38
  from sglang.srt.layers.activation import SiluAndMul
39
39
  from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
40
40
  from sglang.srt.layers.dp_attention import (
41
- attn_tp_all_gather,
42
- attn_tp_reduce_scatter,
43
- dp_gather_partial,
44
- dp_scatter,
45
41
  get_attention_tp_rank,
46
42
  get_attention_tp_size,
47
43
  get_local_attention_dp_size,
@@ -56,8 +52,7 @@ from sglang.srt.layers.linear import (
56
52
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
57
53
  from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
58
54
  from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
59
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
60
- from sglang.srt.layers.moe.topk import select_experts
55
+ from sglang.srt.layers.moe.topk import TopK
61
56
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
62
57
  from sglang.srt.layers.radix_attention import RadixAttention
63
58
  from sglang.srt.layers.rotary_embedding import get_rope
@@ -102,6 +97,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
102
97
  f"the number of experts {config.num_experts}."
103
98
  )
104
99
 
100
+ self.topk = TopK(
101
+ top_k=config.num_experts_per_tok,
102
+ renormalize=config.norm_topk_prob,
103
+ use_grouped_topk=False,
104
+ )
105
+
105
106
  self.experts = get_moe_impl_class()(
106
107
  num_experts=config.num_experts
107
108
  + global_server_args_dict["ep_num_redundant_experts"],
@@ -109,7 +110,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
109
110
  layer_id=layer_id,
110
111
  hidden_size=config.hidden_size,
111
112
  intermediate_size=config.moe_intermediate_size,
112
- renormalize=config.norm_topk_prob,
113
113
  quant_config=quant_config,
114
114
  prefix=add_prefix("experts", prefix),
115
115
  **(
@@ -143,7 +143,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
143
143
  config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
144
144
  )
145
145
  self.top_k = config.num_experts_per_tok
146
- self.renormalize = config.norm_topk_prob
147
146
 
148
147
  self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
149
148
  group=parallel_state.get_tp_group().device_group,
@@ -180,9 +179,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
180
179
 
181
180
  # router_logits: (num_tokens, n_experts)
182
181
  router_logits, _ = self.gate(hidden_states)
183
- final_hidden_states = self.experts(
184
- hidden_states=hidden_states, router_logits=router_logits
185
- )
182
+ topk_output = self.topk(hidden_states, router_logits)
183
+ final_hidden_states = self.experts(hidden_states, topk_output)
186
184
  if self.tp_size > 1:
187
185
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
188
186
 
@@ -191,17 +189,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
191
189
  def forward_deepep(
192
190
  self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
193
191
  ) -> torch.Tensor:
194
- forward_mode = forward_batch.forward_mode
195
- if is_non_idle_and_non_empty(forward_mode, hidden_states):
192
+ if hidden_states.shape[0] > 0:
196
193
  # router_logits: (num_tokens, n_experts)
197
194
  router_logits, _ = self.gate(hidden_states)
198
-
199
- topk_weights, topk_idx = select_experts(
200
- hidden_states=hidden_states,
201
- router_logits=router_logits,
202
- top_k=self.top_k,
203
- use_grouped_topk=False,
204
- renormalize=self.renormalize,
195
+ topk_weights, topk_idx, _ = self.topk(
196
+ hidden_states,
197
+ router_logits,
205
198
  num_token_non_padded=forward_batch.num_token_non_padded,
206
199
  expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
207
200
  layer_id=self.layer_id,
@@ -267,12 +260,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
267
260
  with get_global_expert_distribution_recorder().with_current_layer(
268
261
  self.layer_id
269
262
  ):
270
- state.topk_weights_local, state.topk_idx_local = select_experts(
263
+ state.topk_weights_local, state.topk_idx_local, _ = self.topk(
271
264
  hidden_states=hidden_states,
272
265
  router_logits=router_logits,
273
- top_k=self.top_k,
274
- use_grouped_topk=False,
275
- renormalize=self.renormalize,
276
266
  num_token_non_padded=state.forward_batch.num_token_non_padded,
277
267
  expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
278
268
  layer_id=self.layer_id,
@@ -745,6 +735,49 @@ class Qwen3MoeForCausalLM(nn.Module):
745
735
  else:
746
736
  return hidden_states
747
737
 
738
+ @torch.no_grad()
739
+ def forward_split_prefill(
740
+ self,
741
+ input_ids: torch.Tensor,
742
+ positions: torch.Tensor,
743
+ forward_batch: ForwardBatch,
744
+ split_interval: Tuple[int, int], # [start, end) 0-based
745
+ input_embeds: torch.Tensor = None,
746
+ ):
747
+ start, end = split_interval
748
+ # embed
749
+ if start == 0:
750
+ if input_embeds is None:
751
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
752
+ else:
753
+ forward_batch.hidden_states = input_embeds
754
+
755
+ # decoder layer
756
+ for i in range(start, end):
757
+ with get_global_expert_distribution_recorder().with_current_layer(i):
758
+ layer = self.model.layers[i]
759
+ forward_batch.hidden_states, forward_batch.residual = layer(
760
+ positions,
761
+ forward_batch.hidden_states,
762
+ forward_batch,
763
+ forward_batch.residual,
764
+ )
765
+
766
+ if end == self.model.config.num_hidden_layers:
767
+ # norm
768
+ hidden_states, _ = self.model.norm(
769
+ forward_batch.hidden_states, forward_batch.residual
770
+ )
771
+ forward_batch.hidden_states = hidden_states
772
+ # logits process
773
+ result = self.logits_processor(
774
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
775
+ )
776
+ else:
777
+ result = None
778
+
779
+ return result
780
+
748
781
  @property
749
782
  def start_layer(self):
750
783
  return self.model.start_layer
sglang/srt/models/vila.py CHANGED
@@ -237,7 +237,7 @@ class VILAForConditionalGeneration(nn.Module):
         return cast(LogitsProcessorOutput, output)
 
     def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
-        pixel_values = cast(Tensor, mm_input[0].pixel_values)
+        pixel_values = cast(Tensor, mm_input[0].feature)
 
         ##### BEGIN COPY modeling_vila.py #####