sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@ from sglang.srt.layers.linear import (
13
13
  )
14
14
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
15
15
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
16
+ from sglang.srt.layers.moe.topk import TopK
16
17
  from sglang.srt.layers.pooler import Pooler, PoolingType
17
18
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
18
19
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -200,15 +201,19 @@ class PhiMoE(nn.Module):
200
201
  quant_config=None,
201
202
  )
202
203
 
204
+ self.topk = TopK(
205
+ top_k=top_k,
206
+ renormalize=False,
207
+ custom_routing_function=phimoe_routing_function,
208
+ )
209
+
203
210
  self.experts = FusedMoE(
204
211
  num_experts=num_experts,
205
212
  top_k=top_k,
206
213
  hidden_size=hidden_size,
207
214
  intermediate_size=intermediate_size,
208
215
  reduce_results=True,
209
- renormalize=False,
210
216
  quant_config=quant_config,
211
- custom_routing_function=phimoe_routing_function,
212
217
  prefix=add_prefix("experts", prefix),
213
218
  )
214
219
 
@@ -219,7 +224,8 @@ class PhiMoE(nn.Module):
219
224
  orig_shape = hidden_states.shape
220
225
  hidden_states = hidden_states.view(-1, self.hidden_size)
221
226
  router_logits, _ = self.gate(hidden_states)
222
- final_hidden_states = self.experts(hidden_states, router_logits)
227
+ topk_output = self.topk(hidden_states, router_logits)
228
+ final_hidden_states = self.experts(hidden_states, topk_output)
223
229
  return final_hidden_states.view(orig_shape)
224
230
 
225
231
 
sglang/srt/models/qwen.py CHANGED
@@ -15,6 +15,7 @@
15
15
  # Adapted from
16
16
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
17
17
 
18
+ import time
18
19
  from typing import Any, Dict, Iterable, Optional, Tuple
19
20
 
20
21
  import torch
@@ -286,6 +287,42 @@ class QWenLMHeadModel(nn.Module):
286
287
  input_ids, hidden_states, self.lm_head, forward_batch
287
288
  )
288
289
 
290
+ @torch.no_grad()
291
+ def forward_split_prefill(
292
+ self,
293
+ input_ids: torch.Tensor,
294
+ positions: torch.Tensor,
295
+ forward_batch: ForwardBatch,
296
+ split_interval: Tuple[int, int], # [start, end) 0-based
297
+ ):
298
+ start, end = split_interval
299
+ # embed
300
+ if start == 0:
301
+ forward_batch.hidden_states = self.transformer.wte(input_ids)
302
+
303
+ # decoder layer
304
+ for i in range(start, end):
305
+ layer = self.transformer.h[i]
306
+ forward_batch.hidden_states = layer(
307
+ positions,
308
+ forward_batch.hidden_states,
309
+ forward_batch,
310
+ )
311
+
312
+ if end == self.transformer.config.num_hidden_layers:
313
+ # norm
314
+ forward_batch.hidden_states = self.transformer.ln_f(
315
+ forward_batch.hidden_states
316
+ )
317
+ # logits process
318
+ result = self.logits_processor(
319
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
320
+ )
321
+ else:
322
+ result = None
323
+
324
+ return result
325
+
289
326
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
290
327
  stacked_params_mapping = [
291
328
  # (param_name, shard_name, shard_id)
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
481
481
  else:
482
482
  return hidden_states
483
483
 
484
+ @torch.no_grad()
485
+ def forward_split_prefill(
486
+ self,
487
+ input_ids: torch.Tensor,
488
+ positions: torch.Tensor,
489
+ forward_batch: ForwardBatch,
490
+ split_interval: Tuple[int, int], # [start, end) 0-based
491
+ input_embeds: torch.Tensor = None,
492
+ ):
493
+ start, end = split_interval
494
+ # embed
495
+ if start == 0:
496
+ if input_embeds is None:
497
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
498
+ else:
499
+ forward_batch.hidden_states = input_embeds
500
+ # decoder layer
501
+ for i in range(start, end):
502
+ layer = self.model.layers[i]
503
+ forward_batch.hidden_states, forward_batch.residual = layer(
504
+ positions,
505
+ forward_batch.hidden_states,
506
+ forward_batch,
507
+ forward_batch.residual,
508
+ )
509
+
510
+ if end == self.model.config.num_hidden_layers:
511
+ # norm
512
+ hidden_states, _ = self.model.norm(
513
+ forward_batch.hidden_states, forward_batch.residual
514
+ )
515
+ forward_batch.hidden_states = hidden_states
516
+ # logits process
517
+ result = self.logits_processor(
518
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
519
+ )
520
+ else:
521
+ result = None
522
+
523
+ return result
524
+
484
525
  @property
485
526
  def start_layer(self):
486
527
  return self.model.start_layer
@@ -497,7 +497,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
497
497
 
498
498
  def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
499
499
  # in qwen-vl, last dim is the same
500
- pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
500
+ pixel_values = torch.cat([item.feature for item in items], dim=0).type(
501
501
  self.visual.dtype
502
502
  )
503
503
  image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -508,9 +508,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
508
508
 
509
509
  def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
510
510
  # in qwen-vl, last dim is the same
511
- pixel_values = torch.cat(
512
- [getattr(item, "pixel_values_videos") for item in items], dim=0
513
- ).type(self.visual.dtype)
511
+ pixel_values = torch.cat([item.feature for item in items], dim=0).type(
512
+ self.visual.dtype
513
+ )
514
514
  video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
515
515
  assert pixel_values.dim() == 2, pixel_values.dim()
516
516
  assert video_grid_thw.dim() == 2, video_grid_thw.dim()
@@ -118,7 +118,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
118
118
 
119
119
  def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
120
120
  # Extract audio features from input items
121
- input_features = torch.cat([item.audio_features for item in items], dim=0).type(
121
+ input_features = torch.cat([item.feature for item in items], dim=0).type(
122
122
  self.audio_tower.dtype
123
123
  )
124
124
 
@@ -61,6 +61,7 @@ from sglang.srt.layers.linear import (
61
61
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
62
62
  from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
63
63
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
64
+ from sglang.srt.layers.moe.topk import TopK
64
65
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
65
66
  from sglang.srt.layers.radix_attention import RadixAttention
66
67
  from sglang.srt.layers.rotary_embedding import get_rope
@@ -134,13 +135,17 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
134
135
  f"the number of experts {config.num_experts}."
135
136
  )
136
137
 
138
+ self.topk = TopK(
139
+ top_k=config.num_experts_per_tok,
140
+ renormalize=config.norm_topk_prob,
141
+ )
142
+
137
143
  self.experts = get_moe_impl_class()(
138
144
  layer_id=self.layer_id,
139
- num_experts=config.num_experts,
140
145
  top_k=config.num_experts_per_tok,
146
+ num_experts=config.num_experts,
141
147
  hidden_size=config.hidden_size,
142
148
  intermediate_size=config.moe_intermediate_size,
143
- renormalize=config.norm_topk_prob,
144
149
  quant_config=quant_config,
145
150
  prefix=add_prefix("experts", prefix),
146
151
  # Additional args for FusedMoE
@@ -189,9 +194,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
189
194
 
190
195
  # router_logits: (num_tokens, n_experts)
191
196
  router_logits, _ = self.gate(hidden_states)
192
- final_hidden_states = self.experts(
193
- hidden_states=hidden_states, router_logits=router_logits
194
- )
197
+ topk_output = self.topk(hidden_states, router_logits)
198
+ final_hidden_states = self.experts(hidden_states, topk_output)
195
199
  if shared_output is not None:
196
200
  final_hidden_states = final_hidden_states + shared_output
197
201
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -406,6 +410,7 @@ class Qwen2MoeModel(nn.Module):
406
410
  alt_stream: Optional[torch.cuda.Stream] = None,
407
411
  ) -> None:
408
412
  super().__init__()
413
+ self.config = config
409
414
  self.padding_idx = config.pad_token_id
410
415
  self.vocab_size = config.vocab_size
411
416
  self.pp_group = get_pp_group()
@@ -554,6 +559,49 @@ class Qwen2MoeForCausalLM(nn.Module):
554
559
  else:
555
560
  return hidden_states
556
561
 
562
+ @torch.no_grad()
563
+ def forward_split_prefill(
564
+ self,
565
+ input_ids: torch.Tensor,
566
+ positions: torch.Tensor,
567
+ forward_batch: ForwardBatch,
568
+ split_interval: Tuple[int, int], # [start, end) 0-based
569
+ input_embeds: torch.Tensor = None,
570
+ ):
571
+ start, end = split_interval
572
+ # embed
573
+ if start == 0:
574
+ if input_embeds is None:
575
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
576
+ else:
577
+ forward_batch.hidden_states = input_embeds
578
+
579
+ # decoder layer
580
+ for i in range(start, end):
581
+ with get_global_expert_distribution_recorder().with_current_layer(i):
582
+ layer = self.model.layers[i]
583
+ forward_batch.hidden_states, forward_batch.residual = layer(
584
+ positions,
585
+ forward_batch.hidden_states,
586
+ forward_batch,
587
+ forward_batch.residual,
588
+ )
589
+
590
+ if end == self.model.config.num_hidden_layers:
591
+ # norm
592
+ hidden_states, _ = self.model.norm(
593
+ forward_batch.hidden_states, forward_batch.residual
594
+ )
595
+ forward_batch.hidden_states = hidden_states
596
+ # logits process
597
+ result = self.logits_processor(
598
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
599
+ )
600
+ else:
601
+ result = None
602
+
603
+ return result
604
+
557
605
  @property
558
606
  def start_layer(self):
559
607
  return self.model.start_layer
@@ -484,7 +484,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
484
484
 
485
485
  def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
486
486
  # in qwen-vl, last dim is the same
487
- pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
487
+ pixel_values = torch.cat([item.feature for item in items], dim=0).type(
488
488
  self.visual.dtype
489
489
  )
490
490
  image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -495,9 +495,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
495
495
 
496
496
  def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
497
497
  # in qwen-vl, last dim is the same
498
- pixel_values = torch.cat(
499
- [item.pixel_values_videos for item in items], dim=0
500
- ).type(self.visual.dtype)
498
+ pixel_values = torch.cat([item.feature for item in items], dim=0).type(
499
+ self.visual.dtype
500
+ )
501
501
  video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
502
502
  assert pixel_values.dim() == 2, pixel_values.dim()
503
503
  assert video_grid_thw.dim() == 2, video_grid_thw.dim()
@@ -1,5 +1,4 @@
1
1
  # Adapted from qwen2.py
2
-
3
2
  import logging
4
3
  from functools import partial
5
4
  from typing import Any, Dict, Iterable, List, Optional, Tuple
@@ -331,6 +330,30 @@ class Qwen3ForCausalLM(nn.Module):
331
330
  def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
332
331
  return self.model.get_input_embeddings(input_ids)
333
332
 
333
+ def get_hidden_dim(self, module_name: str) -> Tuple[int]:
334
+ # return input_dim, output_dim
335
+ if module_name in ["q_proj", "qkv_proj"]:
336
+ return (
337
+ self.config.hidden_size,
338
+ self.config.head_dim * self.config.num_attention_heads,
339
+ )
340
+ elif module_name in ["o_proj"]:
341
+ return (
342
+ self.config.head_dim * self.config.num_attention_heads,
343
+ self.config.hidden_size,
344
+ )
345
+ elif module_name in ["kv_proj"]:
346
+ return (
347
+ self.config.hidden_size,
348
+ self.config.head_dim * self.config.num_key_value_heads,
349
+ )
350
+ elif module_name == "gate_up_proj":
351
+ return self.config.hidden_size, self.config.intermediate_size
352
+ elif module_name == "down_proj":
353
+ return self.config.intermediate_size, self.config.hidden_size
354
+ else:
355
+ raise NotImplementedError()
356
+
334
357
  @torch.no_grad()
335
358
  def forward(
336
359
  self,
@@ -367,6 +390,47 @@ class Qwen3ForCausalLM(nn.Module):
367
390
  else:
368
391
  return hidden_states
369
392
 
393
+ @torch.no_grad()
394
+ def forward_split_prefill(
395
+ self,
396
+ input_ids: torch.Tensor,
397
+ positions: torch.Tensor,
398
+ forward_batch: ForwardBatch,
399
+ split_interval: Tuple[int, int], # [start, end) 0-based
400
+ input_embeds: torch.Tensor = None,
401
+ ):
402
+ start, end = split_interval
403
+ # embed
404
+ if start == 0:
405
+ if input_embeds is None:
406
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
407
+ else:
408
+ forward_batch.hidden_states = input_embeds
409
+ # decoder layer
410
+ for i in range(start, end):
411
+ layer = self.model.layers[i]
412
+ forward_batch.hidden_states, forward_batch.residual = layer(
413
+ positions,
414
+ forward_batch.hidden_states,
415
+ forward_batch,
416
+ forward_batch.residual,
417
+ )
418
+
419
+ if end == self.model.config.num_hidden_layers:
420
+ # norm
421
+ hidden_states, _ = self.model.norm(
422
+ forward_batch.hidden_states, forward_batch.residual
423
+ )
424
+ forward_batch.hidden_states = hidden_states
425
+ # logits process
426
+ result = self.logits_processor(
427
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
428
+ )
429
+ else:
430
+ result = None
431
+
432
+ return result
433
+
370
434
  @property
371
435
  def start_layer(self):
372
436
  return self.model.start_layer
@@ -56,8 +56,7 @@ from sglang.srt.layers.linear import (
56
56
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
57
57
  from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
58
58
  from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
59
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
60
- from sglang.srt.layers.moe.topk import select_experts
59
+ from sglang.srt.layers.moe.topk import TopK
61
60
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
62
61
  from sglang.srt.layers.radix_attention import RadixAttention
63
62
  from sglang.srt.layers.rotary_embedding import get_rope
@@ -102,6 +101,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
102
101
  f"the number of experts {config.num_experts}."
103
102
  )
104
103
 
104
+ self.topk = TopK(
105
+ top_k=config.num_experts_per_tok,
106
+ renormalize=config.norm_topk_prob,
107
+ use_grouped_topk=False,
108
+ )
109
+
105
110
  self.experts = get_moe_impl_class()(
106
111
  num_experts=config.num_experts
107
112
  + global_server_args_dict["ep_num_redundant_experts"],
@@ -109,7 +114,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
109
114
  layer_id=layer_id,
110
115
  hidden_size=config.hidden_size,
111
116
  intermediate_size=config.moe_intermediate_size,
112
- renormalize=config.norm_topk_prob,
113
117
  quant_config=quant_config,
114
118
  prefix=add_prefix("experts", prefix),
115
119
  **(
@@ -143,7 +147,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
143
147
  config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
144
148
  )
145
149
  self.top_k = config.num_experts_per_tok
146
- self.renormalize = config.norm_topk_prob
147
150
 
148
151
  self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
149
152
  group=parallel_state.get_tp_group().device_group,
@@ -180,9 +183,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
180
183
 
181
184
  # router_logits: (num_tokens, n_experts)
182
185
  router_logits, _ = self.gate(hidden_states)
183
- final_hidden_states = self.experts(
184
- hidden_states=hidden_states, router_logits=router_logits
185
- )
186
+ topk_output = self.topk(hidden_states, router_logits)
187
+ final_hidden_states = self.experts(hidden_states, topk_output)
186
188
  if self.tp_size > 1:
187
189
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
188
190
 
@@ -195,13 +197,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
195
197
  if is_non_idle_and_non_empty(forward_mode, hidden_states):
196
198
  # router_logits: (num_tokens, n_experts)
197
199
  router_logits, _ = self.gate(hidden_states)
198
-
199
- topk_weights, topk_idx = select_experts(
200
- hidden_states=hidden_states,
201
- router_logits=router_logits,
202
- top_k=self.top_k,
203
- use_grouped_topk=False,
204
- renormalize=self.renormalize,
200
+ topk_weights, topk_idx, _ = self.topk(
201
+ hidden_states,
202
+ router_logits,
205
203
  num_token_non_padded=forward_batch.num_token_non_padded,
206
204
  expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
207
205
  layer_id=self.layer_id,
@@ -267,12 +265,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
267
265
  with get_global_expert_distribution_recorder().with_current_layer(
268
266
  self.layer_id
269
267
  ):
270
- state.topk_weights_local, state.topk_idx_local = select_experts(
268
+ state.topk_weights_local, state.topk_idx_local, _ = self.topk(
271
269
  hidden_states=hidden_states,
272
270
  router_logits=router_logits,
273
- top_k=self.top_k,
274
- use_grouped_topk=False,
275
- renormalize=self.renormalize,
276
271
  num_token_non_padded=state.forward_batch.num_token_non_padded,
277
272
  expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
278
273
  layer_id=self.layer_id,
@@ -745,6 +740,49 @@ class Qwen3MoeForCausalLM(nn.Module):
745
740
  else:
746
741
  return hidden_states
747
742
 
743
+ @torch.no_grad()
744
+ def forward_split_prefill(
745
+ self,
746
+ input_ids: torch.Tensor,
747
+ positions: torch.Tensor,
748
+ forward_batch: ForwardBatch,
749
+ split_interval: Tuple[int, int], # [start, end) 0-based
750
+ input_embeds: torch.Tensor = None,
751
+ ):
752
+ start, end = split_interval
753
+ # embed
754
+ if start == 0:
755
+ if input_embeds is None:
756
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
757
+ else:
758
+ forward_batch.hidden_states = input_embeds
759
+
760
+ # decoder layer
761
+ for i in range(start, end):
762
+ with get_global_expert_distribution_recorder().with_current_layer(i):
763
+ layer = self.model.layers[i]
764
+ forward_batch.hidden_states, forward_batch.residual = layer(
765
+ positions,
766
+ forward_batch.hidden_states,
767
+ forward_batch,
768
+ forward_batch.residual,
769
+ )
770
+
771
+ if end == self.model.config.num_hidden_layers:
772
+ # norm
773
+ hidden_states, _ = self.model.norm(
774
+ forward_batch.hidden_states, forward_batch.residual
775
+ )
776
+ forward_batch.hidden_states = hidden_states
777
+ # logits process
778
+ result = self.logits_processor(
779
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
780
+ )
781
+ else:
782
+ result = None
783
+
784
+ return result
785
+
748
786
  @property
749
787
  def start_layer(self):
750
788
  return self.model.start_layer
sglang/srt/models/vila.py CHANGED
@@ -237,7 +237,7 @@ class VILAForConditionalGeneration(nn.Module):
237
237
  return cast(LogitsProcessorOutput, output)
238
238
 
239
239
  def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
240
- pixel_values = cast(Tensor, mm_input[0].pixel_values)
240
+ pixel_values = cast(Tensor, mm_input[0].feature)
241
241
 
242
242
  ##### BEGIN COPY modeling_vila.py #####
243
243