sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/hunyuan.py
@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -152,13 +153,16 @@ class HunYuanSparseMoeBlock(nn.Module):
             else config.moe_intermediate_size[layer_id]
         )

+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=True if top_k > 1 else False,
+        )
+
         self.experts = FusedMoE(
             num_experts=config.num_experts,
-            top_k=top_k,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=False,
-            renormalize=True if top_k > 1 else False,
             quant_config=quant_config,
         )

@@ -195,9 +199,8 @@ class HunYuanSparseMoeBlock(nn.Module):

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
@@ -206,6 +209,42 @@ class HunYuanSparseMoeBlock(nn.Module):
         return final_hidden_states.view(orig_shape)


+def get_head_dim(config):
+    if hasattr(config, "head_dim"):
+        return int(config.head_dim)
+    if hasattr(config, "attention_head_dim"):
+        return int(config.attention_head_dim)
+
+    # since some hunyuan model don't follow the self.hidden_size // self.total_num_heads rule
+    # wrong setting may cause runtime error, just throw error if this field is missing.
+    raise ValueError("Missing head dim config, try set head_dim in config.json")
+
+
+def check_head_dim(config):
+    # Some models may lack `head_dim` and use `attention_head_dim` instead.
+    # This attribute is also used by flashinfer_backend.py, so we check for
+    # consistency and raise an error if it's not met to avoid silent failures.
+    # Although we could adapt the HunYuan model to use `attention_head_dim`,
+    # flashinfer expects `head_dim`, so we enforce its presence for correctness.
+    calc_head_dim = config.hidden_size // config.num_attention_heads
+
+    if hasattr(config, "attention_head_dim"):
+        if calc_head_dim != config.attention_head_dim and not hasattr(
+            config, "head_dim"
+        ):
+            # in this case, flash infer(and other components may calculate wrong value.)
+            raise ValueError(
+                f"HunYuan model config error: calculated head_dim {calc_head_dim} != attention_head_dim {config.attention_head_dim}"
+                + f"\nPlease Add head_dim:{config.attention_head_dim} in config.json to make sure correctly inference."
+            )
+
+        if hasattr(config, "head_dim") and config.attention_head_dim != config.head_dim:
+            raise ValueError(
+                f"HunYuan model config error: head_dim({config.head_dim}) != attention_head_dim({config.attention_head_dim})"
+                + f"\nPlease change head_dim:{config.attention_head_dim} in config.json to make sure correctly inference."
+            )
+
+
 class HunYuanAttention(nn.Module):

     def __init__(
@@ -240,9 +279,11 @@ class HunYuanAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(
-            config, "head_dim", self.hidden_size // self.total_num_heads
-        )
+        # Prioritize `head_dim` but fall back to `attention_head_dim` for Hunyuan models.
+        self.head_dim = get_head_dim(config)
+
+        check_head_dim(config)
+
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -493,7 +534,6 @@ class HunYuanModel(nn.Module):
         hidden_states = self.get_input_embeddings(input_ids)
         residual = None

-        cla_factor = _get_cla_factor(self.config)
         prev_kv_states = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -560,6 +600,11 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
         if config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight

+        self.hidden_size = config.hidden_size
+        self.head_dim = get_head_dim(config)
+
+        check_head_dim(config)
+
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
         self.sampler = Sampler()
@@ -582,16 +627,14 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
             self.config, "num_key_value_heads", self.config.num_attention_heads
         )
         num_key_value_groups = num_attention_heads // num_kv_heads
-        hidden_size = self.config.hidden_size
-        attention_head_dim = self.config.hidden_size // num_attention_heads

         qkv = qkv.reshape(
-            num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
+            num_kv_heads, num_key_value_groups + 2, self.head_dim, self.hidden_size
        )
         q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
-        q = q.reshape(-1, hidden_size)
-        k = k.reshape(-1, hidden_size)
-        v = v.reshape(-1, hidden_size)
+        q = q.reshape(-1, self.hidden_size)
+        k = k.reshape(-1, self.hidden_size)
+        v = v.reshape(-1, self.hidden_size)
         return torch.concat((q, k, v))
         # return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),

@@ -768,4 +811,8 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
         )


-EntryClass = HunYuanMoEV1ForCausalLM
+class HunYuanDenseV1ForCausalLM(HunYuanMoEV1ForCausalLM):
+    pass
+
+
+EntryClass = [HunYuanMoEV1ForCausalLM, HunYuanDenseV1ForCausalLM]
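The recurring change in these model diffs is that top-k routing moves out of FusedMoE into a standalone TopK module: the block computes router logits, turns them into a topk_output with TopK, and hands that to the experts. The same refactor appears again in the llama4.py, mixtral.py, and olmoe.py hunks below. A minimal sketch of the new wiring, assuming only the TopK/FusedMoE APIs shown in the hunks above; the gate layer and class name are illustrative:

    # Hedged sketch of the new MoE wiring; everything except TopK/FusedMoE and the
    # (hidden_states, topk_output) call convention is illustrative.
    import torch
    from torch import nn

    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
    from sglang.srt.layers.moe.topk import TopK


    class SparseMoeBlockSketch(nn.Module):
        def __init__(self, num_experts, top_k, hidden_size, intermediate_size, quant_config=None):
            super().__init__()
            self.gate = nn.Linear(hidden_size, num_experts, bias=False)  # stand-in router
            # Routing options (top_k, renormalize, custom routing) now live on TopK ...
            self.topk = TopK(top_k=top_k, renormalize=top_k > 1)
            # ... and are no longer constructor arguments of FusedMoE.
            self.experts = FusedMoE(
                num_experts=num_experts,
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                reduce_results=False,
                quant_config=quant_config,
            )

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            router_logits = self.gate(hidden_states)
            topk_output = self.topk(hidden_states, router_logits)
            return self.experts(hidden_states, topk_output)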
sglang/srt/models/internvl.py
@@ -510,7 +510,7 @@ class InternVLChatModel(nn.Module):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = torch.cat([item.pixel_values for item in items])
+        pixel_values = torch.cat([item.feature for item in items])
         image_features = self.extract_feature(pixel_values)
         return image_features

sglang/srt/models/kimi_vl.py
@@ -144,7 +144,7 @@ class KimiVLForConditionalGeneration(nn.Module):

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = (
-            torch.cat([item.pixel_values for item in items], dim=0)
+            torch.cat([item.feature for item in items], dim=0)
             .type(self.vision_tower.dtype)
             .to(self.vision_tower.device)
         )
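These two hunks are the first instances of a rename that repeats in the llava, llavavid, minicpmo, minicpmv, mistral, mllama, and mllama4 hunks below: multimodal models now read the tensor payload from MultimodalDataItem.feature instead of modality-specific attributes such as item.pixel_values or item.audio_features. A minimal consumer sketch, where vision_tower stands in for each model's own encoder:

    # Sketch only: `vision_tower` is a placeholder; the `item.feature` accessor is the
    # part taken from this diff.
    from typing import List

    import torch

    from sglang.srt.managers.schedule_batch import MultimodalDataItem


    def get_image_feature(vision_tower, items: List[MultimodalDataItem]) -> torch.Tensor:
        # Every modality now exposes its raw payload through `item.feature`.
        pixel_values = torch.cat([item.feature for item in items], dim=0)
        return vision_tower(pixel_values)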
sglang/srt/models/llama.py
@@ -480,6 +480,47 @@ class LlamaForCausalLM(nn.Module):
         else:
             return hidden_states

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ) -> Optional[LogitsProcessorOutput]:
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
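forward_split_prefill runs prefill over a half-open layer interval, keeping the running activations on forward_batch.hidden_states / forward_batch.residual between calls; only the call that reaches the last layer applies the final norm and logits processor. A hedged driver sketch (the chunk size and loop are illustrative, not part of the diff):

    # Illustrative driver; only forward_split_prefill itself comes from the diff above.
    def run_split_prefill(model, input_ids, positions, forward_batch, split_size=8):
        num_layers = model.model.config.num_hidden_layers
        result = None
        for start in range(0, num_layers, split_size):
            end = min(start + split_size, num_layers)
            # Intermediate calls return None; the chunk that reaches the last layer
            # returns the LogitsProcessorOutput.
            result = model.forward_split_prefill(
                input_ids, positions, forward_batch, (start, end)
            )
        return result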
sglang/srt/models/llama4.py
@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -103,14 +104,17 @@ class Llama4MoE(nn.Module):
             prefix=add_prefix("router", prefix),
         )

+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=False,
+            custom_routing_function=Llama4MoE.custom_routing_function,
+        )
+
         self.experts = FusedMoE(
             num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            custom_routing_function=Llama4MoE.custom_routing_function,
             intermediate_size=intermediate_size_moe,
             reduce_results=False,
-            renormalize=False,
             quant_config=quant_config,
             apply_router_weight_on_input=True,
             prefix=add_prefix("experts", prefix),
@@ -147,10 +151,8 @@ class Llama4MoE(nn.Module):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
-        routed_out = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        routed_out = self.experts(hidden_states, topk_output)
         return shared_out, routed_out

     def _forward_core_shared_routed_overlap(self, hidden_states):
@@ -163,10 +165,8 @@ class Llama4MoE(nn.Module):
         with self.device_module.stream(alt_stream):
             # router_scores: [num_tokens, num_experts]
             router_logits, _ = self.router(hidden_states)
-            routed_out = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-            )
+            topk_output = self.topk(hidden_states, router_logits)
+            routed_out = self.experts(hidden_states, topk_output)
         self.device_module.current_stream().wait_stream(alt_stream)

         return shared_out, routed_out
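In _forward_core_shared_routed_overlap, the routed path (router, TopK, experts) runs on an alternate stream while the shared expert runs on the default stream, and wait_stream re-synchronizes before the two outputs are returned. A rough sketch of that pattern with plain torch.cuda streams; the free-function shape, the module arguments, and the initial wait_stream are placeholders, and only the routed-path body mirrors the hunk above:

    # Rough sketch with plain CUDA streams; the diff uses self.device_module and module
    # attributes instead of these placeholders.
    import torch


    def overlapped_moe(hidden_states, shared_expert, router, topk, experts, alt_stream):
        # Assumed setup: let the alternate stream wait for work already queued on the
        # default stream before launching the routed path.
        alt_stream.wait_stream(torch.cuda.current_stream())
        shared_out = shared_expert(hidden_states)  # default stream
        with torch.cuda.stream(alt_stream):  # routed path on the alternate stream
            router_logits, _ = router(hidden_states)
            topk_output = topk(hidden_states, router_logits)
            routed_out = experts(hidden_states, topk_output)
        torch.cuda.current_stream().wait_stream(alt_stream)
        return shared_out, routed_out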
sglang/srt/models/llava.py
@@ -186,7 +186,7 @@ class LlavaBaseForCausalLM(nn.Module):
             bs = forward_batch.batch_size
             pixel_values = flatten_nested_list(
                 [
-                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    [item.feature for item in image_inputs[i].mm_items]
                     for i in range(bs)
                     if need_vision[i]
                 ]
@@ -753,7 +753,7 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         features = []
         for item in items:
             # in each item, we assume pixel_values is always batched
-            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            pixel_values, image_sizes = item.feature, item.image_sizes
             image_outputs = self.vision_tower(
                 pixel_values, image_sizes, output_hidden_states=True
             )
sglang/srt/models/llavavid.py
@@ -135,7 +135,7 @@ class LlavaVidForCausalLM(nn.Module):
         if need_vision.any():
             pixel_values = flatten_nested_list(
                 [
-                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    [item.feature for item in image_inputs[i].mm_items]
                     for i in range(bs)
                     if need_vision[i]
                 ]
sglang/srt/models/minicpm.py
@@ -138,8 +138,6 @@ class MiniCPMAttention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        # set rope as fp32 instead of bf16
-        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
sglang/srt/models/minicpmo.py
@@ -1552,9 +1552,7 @@ class MiniCPMO(MiniCPMBaseModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        wavforms = flatten_nested_list(
-            [item.audio_features for item in items if item.audio_features]
-        )
+        wavforms = flatten_nested_list([item.feature for item in items if item.feature])
         # list, [[x1, x2], [y1], [z1]]
         audio_feature_lens_raw = flatten_nested_list(
             [item.audio_feature_lens for item in items if item.audio_feature_lens]
@@ -1659,9 +1657,7 @@ class MiniCPMO(MiniCPMBaseModel):
             List[List[torch.Tensor]]: audio embeddings
         """
         # (bs, 80, frames) or [], multi audios need filled in advance
-        wavforms = flatten_nested_list(
-            [item.audio_features for item in items if item.audio_features]
-        )
+        wavforms = flatten_nested_list([item.feature for item in items if item.feature])
         # list, [[x1, x2], [y1], [z1]]
         audio_feature_lens_raw = flatten_nested_list(
             [item.audio_feature_lens for item in items if item.audio_feature_lens]
@@ -1778,7 +1774,7 @@ class MiniCPMO(MiniCPMBaseModel):

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        pixel_values = flatten_nested_list([item.feature for item in items])
         tgt_sizes = torch.stack(
             flatten_nested_list([item.tgt_size for item in items]), dim=0
         )
sglang/srt/models/minicpmv.py
@@ -724,7 +724,7 @@ class MiniCPMV2_6(MiniCPMBaseModel):

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        pixel_values = flatten_nested_list([item.feature for item in items])
         tgt_sizes = torch.stack(
             flatten_nested_list([item.tgt_size for item in items]), dim=0
         )
sglang/srt/models/mistral.py
@@ -56,7 +56,7 @@ class Mistral3ForConditionalGeneration:
         features = []
         for item in items:
             # in each item, we assume pixel_values is always batched
-            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            pixel_values, image_sizes = item.feature, item.image_sizes
             image_outputs = self.vision_tower(
                 pixel_values, image_sizes, output_hidden_states=True
             )
sglang/srt/models/mixtral.py
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -86,6 +87,12 @@ class MixtralMoE(nn.Module):
             quant_config=None,
             prefix=add_prefix("gate", prefix),
         )
+
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=True,
+        )
+
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=num_experts,
@@ -93,7 +100,6 @@ class MixtralMoE(nn.Module):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
@@ -105,7 +111,8 @@ class MixtralMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(orig_shape)
sglang/srt/models/mllama.py
@@ -838,9 +838,7 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        pixel_values = torch.cat(
-            [item.pixel_values for item in mm_inputs.mm_items], dim=0
-        )
+        pixel_values = torch.cat([item.feature for item in mm_inputs.mm_items], dim=0)
         pad_values = [item.pad_value for item in mm_inputs.mm_items]

         num_concurrent_media, num_tiles = pixel_values.shape[1:3]
@@ -862,7 +860,7 @@ class MllamaForConditionalGeneration(nn.Module):

             if not forward_batch.encoder_cached[i] and mm_input is not None:
                 pixel_values = torch.cat(
-                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                    [item.feature for item in mm_input.mm_items], dim=0
                 )
                 max_num_images = max(max_num_images, pixel_values.shape[1])

@@ -897,7 +895,7 @@ class MllamaForConditionalGeneration(nn.Module):

                 encoder_lens_need.append(forward_batch.encoder_lens[k])
                 pixel_values = torch.cat(
-                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                    [item.feature for item in mm_input.mm_items], dim=0
                 )
                 for j in range(pixel_values.shape[1]):
                     img = pixel_values[0, j]
sglang/srt/models/mllama4.py
@@ -23,6 +23,7 @@ from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
     MultimodalInputs,
+    global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -55,13 +56,17 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.quant_config = quant_config

         # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
-        self.has_vision = self._has_vision_weights(config)
-        if not self.has_vision:
+        self.has_vision_weights = self._has_vision_weights(config)
+        if not self.has_vision_weights:
             logger.warning(
                 "No vision weights found in checkpoint. Model will run in text-only mode. "
                 "Multimodal capabilities (image processing) will be unavailable."
             )

+        self.has_vision = (
+            self.has_vision_weights and global_server_args_dict["enable_multimodal"]
+        )
+
         if self.has_vision:
             self.vision_model = Llama4VisionModel(config.vision_config)
             self.multi_modal_projector = Llama4MultiModalProjector(config)
@@ -81,6 +86,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(
             config.text_config if hasattr(config, "text_config") else config
         )
+        self.padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens()

     def _has_vision_weights(self, config) -> bool:
         """Check if the model has vision components by examining the checkpoint."""
@@ -135,8 +141,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         return False

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
-        return pattern.pad_input_tokens(input_ids, mm_inputs)
+        return self.padding_pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(
         self,
@@ -147,7 +152,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             raise ValueError("Vision model not available for text-only checkpoint")

         pixel_values = (
-            torch.concat([item.pixel_values for item in items])
+            torch.concat([item.feature for item in items])
             .to(next(self.vision_model.parameters()).device)
             .type(next(self.vision_model.parameters()).dtype)
         )
@@ -269,7 +274,9 @@ class Llama4ForConditionalGeneration(nn.Module):

     def _should_skip_weight(self, name: str) -> bool:
         """Check if we should skip loading this weight."""
-        return "vision" in name and not self.has_vision
+        return not self.has_vision and (
+            "vision" in name or "multi_modal_projector" in name
+        )

     def _transform_weight_name(self, name: str) -> str:
         """Transform weight name by adding language_model prefix if needed."""
sglang/srt/models/olmoe.py
@@ -32,6 +32,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -76,13 +77,16 @@ class OlmoeMoE(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+        )
+
         self.experts = FusedMoE(
             num_experts=num_experts,
-            top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
-            renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
@@ -94,9 +98,8 @@ class OlmoeMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         return final_hidden_states.view(orig_shape)
