sglang-0.4.8-py3-none-any.whl → sglang-0.4.9-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/models/kimi_vl.py
+++ b/sglang/srt/models/kimi_vl.py
@@ -154,8 +154,7 @@ class KimiVLForConditionalGeneration(nn.Module):
         return res
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def forward(
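A change that recurs across the model files in this diff (kimi_vl, mllama4, phi4mm, pixtral, qwen2_5_vl): MultiModalityDataPaddingPatternMultimodalTokens is now constructed without explicit token ids. The toy sketch below only illustrates the call-shape change; the stub class is a stand-in, not sglang's implementation, and the real pattern presumably derives the placeholder ids from the MultimodalInputs passed to pad_input_tokens.

# Toy stand-in, only to show the constructor change between 0.4.8 and 0.4.9.
class MultiModalityDataPaddingPatternMultimodalTokens:
    def __init__(self, token_ids=None):  # 0.4.8 passed e.g. [im_token_id]; 0.4.9 passes nothing
        self.token_ids = token_ids

    def pad_input_tokens(self, input_ids, mm_inputs):
        # The real implementation rewrites multimodal placeholder tokens; this stub is a no-op.
        return input_ids

pattern = MultiModalityDataPaddingPatternMultimodalTokens()  # 0.4.9 style: no [im_token_id]
print(pattern.pad_input_tokens([1, 2, 3], mm_inputs=None))   # -> [1, 2, 3]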
--- a/sglang/srt/models/llama.py
+++ b/sglang/srt/models/llama.py
@@ -697,13 +697,19 @@ class LlamaForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
-    def set_eagle3_layers_to_capture(self):
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if not self.pp_group.is_last_rank:
             return
 
-        self.capture_aux_hidden_states = True
-        num_layers = self.config.num_hidden_layers
-        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
 
 
 class Phi3ForCausalLM(LlamaForCausalLM):
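For orientation, a small self-contained sketch of how the new layer_ids argument maps onto captured layers, based only on the hunk above (the default indices and the +1 offset); the num_hidden_layers value and the requested layer_ids are made up.

num_hidden_layers = 32  # hypothetical config value

# Default (layer_ids is None): three fixed capture points.
default_capture = [2, num_hidden_layers // 2, num_hidden_layers - 3]
print(default_capture)             # [2, 16, 29]

# Explicit layer_ids: shifted by +1, since layer i records the output of layer i - 1
# as its aux hidden state.
layer_ids = [1, 15, 28]            # hypothetical request from an EAGLE3 draft model
print([i + 1 for i in layer_ids])  # [2, 16, 29]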
--- a/sglang/srt/models/llama4.py
+++ b/sglang/srt/models/llama4.py
@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
             bias_o_proj=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+        self.config = config
+        is_moe_layer = self._is_moe_layer(layer_id)
+        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)
+
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
@@ -387,6 +389,22 @@ class Llama4DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=is_moe_layer,
+            is_previous_layer_sparse=is_previous_moe_layer,
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def _is_moe_layer(self, layer_id: int) -> bool:
+        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -394,57 +412,26 @@ class Llama4DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            # Self Attention
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
             )
 
-            # Gather
-            if get_tensor_model_parallel_world_size() > 1:
-                # all gather and all reduce
-                if self.local_dp_size != 1:
-                    if self.attn_tp_rank == 0:
-                        hidden_states += residual
-                    hidden_states, local_hidden_states = (
-                        forward_batch.gathered_buffer,
-                        hidden_states,
-                    )
-                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                    dp_scatter(residual, hidden_states, forward_batch)
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
 
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
-
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
 
         return hidden_states, residual
 
--- a/sglang/srt/models/llama_eagle3.py
+++ b/sglang/srt/models/llama_eagle3.py
@@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP
 
 
 class LlamaDecoderLayer(LlamaDecoderLayer):
@@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
             prefix=add_prefix("qkv_proj", prefix),
         )
 
+        if config.model_type == "llama4_text":
+            inter_size = config.intermediate_size_mlp
+        else:
+            inter_size = config.intermediate_size
+
+        self.mlp = LlamaMLP(
+            config.hidden_size, inter_size, config.hidden_act, quant_config, prefix
+        )
+
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
@@ -105,11 +115,19 @@ class LlamaModel(nn.Module):
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
+
         if hasattr(config, "target_hidden_size"):
-            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.target_hidden_size
         else:
-            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.hidden_size
+
+        self.fc = torch.nn.Linear(
+            self.hidden_size_in * 3,
+            config.hidden_size,
+            bias=getattr(config, "bias", False),
+        )
+
+        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
 
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = True
+        self.hot_token_id = None
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+        # Define the parameter mapping for stacked parameters
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "d2t" in name:
                 # d2t stores diffs between draft id and target id
                 self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
-
-            if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
-                new_name = f"model.{name}"
-                super().load_weights([(new_name, loaded_weight)])
-            elif "lm_head" in name:
-                super().load_weights([(name, loaded_weight)])
+                continue
+
+            if "t2d" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param_name = f"model.{name}" if name not in params_dict else name
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Handle regular parameters
+                param_name = name if name in params_dict else f"model.{name}"
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
 
     def get_hot_token_id(self):
         return self.hot_token_id
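For orientation, a minimal self-contained sketch of the rename step driven by the stacked_params_mapping table above; the resolve helper and the example names are illustrative, not part of sglang.

stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

def resolve(name):
    """Map a checkpoint weight name to its fused parameter name and shard id."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # regular (non-stacked) parameter

print(resolve("midlayer.self_attn.k_proj.weight"))  # ('midlayer.self_attn.qkv_proj.weight', 'k')
print(resolve("midlayer.hidden_norm.weight"))       # ('midlayer.hidden_norm.weight', None)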
--- a/sglang/srt/models/llava.py
+++ b/sglang/srt/models/llava.py
@@ -41,16 +41,16 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
 )
-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,
-    unpad_image,
-    unpad_image_shape,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.multimodal.mm_utils import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+    unpad_image_shape,
+)
 from sglang.srt.utils import add_prefix, flatten_nested_list, logger
 
 
--- a/sglang/srt/models/minicpmo.py
+++ b/sglang/srt/models/minicpmo.py
@@ -32,7 +32,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import DynamicCache, EncoderDecoderCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from transformers.models.whisper.modeling_whisper import (
-    WHISPER_ATTENTION_CLASSES,
+    WhisperAttention,
     WhisperConfig,
     WhisperEncoder,
 )
@@ -1090,7 +1090,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig, layer_idx: int = None):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = WhisperAttention(
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
--- a/sglang/srt/models/mistral.py
+++ b/sglang/srt/models/mistral.py
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Inference-only Mistral model."""
 
-from typing import List, Union
+from typing import List
 
 import torch
 from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
--- a/sglang/srt/models/mllama4.py
+++ b/sglang/srt/models/mllama4.py
@@ -16,7 +16,9 @@ from sglang.srt.managers.mm_utils import (
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cpu
+
+_is_cpu = is_cpu()
 
 
 class Llama4ForConditionalGeneration(nn.Module):
@@ -50,10 +52,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(
@@ -110,13 +109,17 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         # rotary embeds should be sliced
         if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.language_model.config.num_key_value_heads
-            )
+            if _is_cpu:
+                dim = self.language_model.config.original_total_num_kv_heads
+            else:
+                dim = self.language_model.config.num_key_value_heads
+            loaded_weight = permute(loaded_weight, dim)
         elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.language_model.config.num_attention_heads
-            )
+            if _is_cpu:
+                dim = self.language_model.config.original_num_attention_heads
+            else:
+                dim = self.language_model.config.num_attention_heads
+            loaded_weight = permute(loaded_weight, dim)
 
         return name, loaded_weight
 
@@ -223,5 +226,34 @@ class Llama4ForConditionalGeneration(nn.Module):
             )
             weight_loader(param, loaded_weight)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
+            self.language_model.set_eagle3_layers_to_capture(layer_ids)
+
+    def get_embed_and_head(self):
+        # For EAGLE3, we delegate to the language model which should have this method
+        # If the language model doesn't have lm_head (like EAGLE3), we return None for head
+        embed = self.language_model.get_embed()
+        if hasattr(self.language_model, "get_embed_and_head"):
+            return self.language_model.get_embed_and_head()
+        elif hasattr(self.language_model, "lm_head"):
+            return embed, self.language_model.lm_head.weight
+        else:
+            # For EAGLE3, head might not be needed
+            return embed, None
+
+    def set_embed_and_head(self, embed, head):
+        if hasattr(self.language_model, "set_embed_and_head"):
+            return self.language_model.set_embed_and_head(embed, head)
+        else:
+            # For EAGLE3, only set embed
+            return self.language_model.set_embed(embed)
+
+    def get_embed(self):
+        return self.language_model.get_embed()
+
+    def set_embed(self, embed):
+        return self.language_model.set_embed(embed)
+
 
 EntryClass = Llama4ForConditionalGeneration
--- a/sglang/srt/models/phi4mm.py
+++ b/sglang/srt/models/phi4mm.py
@@ -446,9 +446,7 @@ class Phi4MMForCausalLM(nn.Module):
         return hidden_states
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def should_apply_lora(self, module_name: str) -> bool:
--- a/sglang/srt/models/pixtral.py
+++ b/sglang/srt/models/pixtral.py
@@ -268,15 +268,14 @@ class PixtralHFVisionModel(nn.Module):
 
     DEFAULT_IMAGE_TOKEN_ID = 10
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, mm_inputs)
 
     def __init__(
         self,
         config: PixtralVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         *,
-        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
         num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
     ) -> None:
@@ -314,11 +313,8 @@ class PixtralHFVisionModel(nn.Module):
         )
 
         # Initialize patch position embedding
-        self.image_token_id = image_token_id
         self.patch_positional_embedding = PixtralRotaryEmbedding(config)
-        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
-            [self.image_token_id]
-        )
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens()
 
     @property
     def dtype(self):
--- a/sglang/srt/models/qwen2.py
+++ b/sglang/srt/models/qwen2.py
@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -100,6 +101,7 @@ class Qwen2Attention(nn.Module):
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        head_dim: Optional[int] = None,
         layer_id: int = 0,
         rope_theta: float = 1000000,
         rope_scaling: Optional[Dict[str, Any]] = None,
@@ -123,7 +125,10 @@ class Qwen2Attention(nn.Module):
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        if head_dim is not None:
+            self.head_dim = head_dim
+        else:
+            self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -185,16 +190,19 @@ class Qwen2DecoderLayer(nn.Module):
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+        head_dim = getattr(config, "head_dim", None)
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
+            head_dim=head_dim,
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
@@ -246,6 +254,7 @@ class Qwen2Model(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -258,6 +267,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
@@ -272,6 +282,7 @@ class Qwen2Model(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
@@ -282,6 +293,9 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
             return self.get_input_embeddings()(input_ids) * self.config.scale_emb
@@ -310,7 +324,12 @@ class Qwen2Model(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -326,8 +345,16 @@ class Qwen2Model(nn.Module):
                 }
             )
         else:
-            hidden_states, _ = self.norm(hidden_states, residual)
-            return hidden_states
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
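A toy, self-contained illustration of the return contract introduced above: Qwen2Model.forward now yields hidden_states alone, or (hidden_states, aux_hidden_states) when layers_to_capture is non-empty. Plain lists stand in for tensors; none of the names below come from sglang.

def toy_forward(capture_layers):
    hidden_states = [0.0, 1.0]  # stand-in for the final hidden states
    aux_hidden_states = [[0.5, 0.5] for _ in capture_layers]
    if len(aux_hidden_states) == 0:
        return hidden_states
    return hidden_states, aux_hidden_states

out = toy_forward(capture_layers=[2, 16, 29])  # EAGLE3-style capture enabled
if isinstance(out, tuple):
    hidden_states, aux_hidden_states = out
else:
    hidden_states, aux_hidden_states = out, []
print(len(aux_hidden_states))                  # 3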
@@ -398,6 +425,7 @@ class Qwen2ForCausalLM(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
+
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()
--- a/sglang/srt/models/qwen2_5_vl.py
+++ b/sglang/srt/models/qwen2_5_vl.py
@@ -493,9 +493,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: