sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -28,6 +28,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -48,7 +49,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
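
The two hunks above are the mechanical side of the 0.4.9 reorganization that moves the EPLB code from sglang.srt.managers into the new sglang.srt.eplb package (entries 19-26 in the file list). Downstream code that still imports the old path will fail on 0.4.9; a version-tolerant import is one way to bridge the gap (this shim is a suggestion, not something shipped by sglang):

    try:
        # sglang >= 0.4.9: EPLB components live under sglang.srt.eplb
        from sglang.srt.eplb.expert_distribution import ExpertDistributionRecorder
    except ImportError:
        # sglang <= 0.4.8
        from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
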
@@ -154,8 +154,7 @@ class KimiVLForConditionalGeneration(nn.Module):
         return res
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def forward(
@@ -697,13 +697,19 @@ class LlamaForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
-    def set_eagle3_layers_to_capture(self):
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if not self.pp_group.is_last_rank:
             return
 
-        self.capture_aux_hidden_states = True
-        num_layers = self.config.num_hidden_layers
-        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
 
 
 class Phi3ForCausalLM(LlamaForCausalLM):
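
The set_eagle3_layers_to_capture change above lets callers pass explicit EAGLE3 capture layers instead of the default early/middle/late triple, shifting user-supplied ids by one because layer i records the output of layer i-1 as its aux hidden state. A standalone sketch of just that selection arithmetic (the helper name is ours, not sglang's):

    from typing import List, Optional

    def select_eagle3_capture_layers(
        num_hidden_layers: int, layer_ids: Optional[List[int]] = None
    ) -> List[int]:
        if layer_ids is None:
            # Default: one early, one middle, one late layer.
            return [2, num_hidden_layers // 2, num_hidden_layers - 3]
        # Shift by one: layer i stores the output of layer i-1 as its aux hidden state.
        return [i + 1 for i in layer_ids]

    assert select_eagle3_capture_layers(32) == [2, 16, 29]
    assert select_eagle3_capture_layers(32, [1, 15, 28]) == [2, 16, 29]
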
@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
             bias_o_proj=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+        self.config = config
+        is_moe_layer = self._is_moe_layer(layer_id)
+        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)
+
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
@@ -387,6 +389,22 @@
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=is_moe_layer,
+            is_previous_layer_sparse=is_previous_moe_layer,
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def _is_moe_layer(self, layer_id: int) -> bool:
+        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
+
     def forward(
         self,
         positions: torch.Tensor,
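
The new _is_moe_layer helper is pure index arithmetic: every interleave_moe_layer_step-th layer is sparse, and the previous layer's sparsity is needed so LayerScatterModes can pick the right communication pattern at the layer boundary. A quick check in isolation (the step value of 2 is only an example, not a real Llama 4 config value):

    def is_moe_layer(layer_id: int, interleave_moe_layer_step: int) -> bool:
        return (layer_id + 1) % interleave_moe_layer_step == 0

    # With a step of 2 and 8 layers, the odd-indexed layers are MoE layers.
    assert [i for i in range(8) if is_moe_layer(i, 2)] == [1, 3, 5, 7]
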
@@ -394,57 +412,26 @@
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            # Self Attention
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
             )
 
-            # Gather
-            if get_tensor_model_parallel_world_size() > 1:
-                # all gather and all reduce
-                if self.local_dp_size != 1:
-                    if self.attn_tp_rank == 0:
-                        hidden_states += residual
-                    hidden_states, local_hidden_states = (
-                        forward_batch.gathered_buffer,
-                        hidden_states,
-                    )
-                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                    dp_scatter(residual, hidden_states, forward_batch)
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
 
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
-
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
 
         return hidden_states, residual
 
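
After this refactor the per-layer data flow is a fixed three-hook sequence: prepare_attn before self-attention, prepare_mlp before the feed-forward block, and postprocess_layer at the end, with the TP/DP gather, scatter, and all-reduce decisions hidden inside LayerCommunicator. The sketch below imitates only that control flow with a single-GPU stand-in; DummyCommunicator and the toy modules are illustrative, not sglang APIs:

    import torch

    class DummyCommunicator:
        # Single-device stand-in: only residual bookkeeping, no collectives.
        def __init__(self, input_layernorm, post_attention_layernorm):
            self.input_layernorm = input_layernorm
            self.post_attention_layernorm = post_attention_layernorm

        def prepare_attn(self, hidden_states, residual, forward_batch=None):
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                residual = hidden_states + residual
                hidden_states = self.input_layernorm(residual)
            return hidden_states, residual

        def prepare_mlp(self, hidden_states, residual, forward_batch=None):
            residual = hidden_states + residual
            hidden_states = self.post_attention_layernorm(residual)
            return hidden_states, residual

        def postprocess_layer(self, hidden_states, residual, forward_batch=None):
            return hidden_states, residual

    dim = 16
    comm = DummyCommunicator(torch.nn.LayerNorm(dim), torch.nn.LayerNorm(dim))
    attn = torch.nn.Linear(dim, dim)  # toy stand-in for self_attn
    mlp = torch.nn.Linear(dim, dim)   # toy stand-in for feed_forward

    def layer_forward(hidden_states, residual=None):
        hidden_states, residual = comm.prepare_attn(hidden_states, residual)
        if hidden_states.shape[0] != 0:
            hidden_states = attn(hidden_states)
        hidden_states, residual = comm.prepare_mlp(hidden_states, residual)
        hidden_states = mlp(hidden_states)
        return comm.postprocess_layer(hidden_states, residual)

    h, r = layer_forward(torch.randn(4, dim))
    assert h.shape == r.shape == (4, dim)
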
@@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP
 
 
 class LlamaDecoderLayer(LlamaDecoderLayer):
@@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
             prefix=add_prefix("qkv_proj", prefix),
         )
 
+        if config.model_type == "llama4_text":
+            inter_size = config.intermediate_size_mlp
+        else:
+            inter_size = config.intermediate_size
+
+        self.mlp = LlamaMLP(
+            config.hidden_size, inter_size, config.hidden_act, quant_config, prefix
+        )
+
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
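
The MLP override above picks the dense-MLP width from the config: Llama 4 text configs name it intermediate_size_mlp, while plain Llama configs use intermediate_size. A minimal illustration with toy config objects (attribute names follow the HF configs, the sizes are made up):

    from types import SimpleNamespace

    llama_cfg = SimpleNamespace(model_type="llama", intermediate_size=11008)
    llama4_cfg = SimpleNamespace(model_type="llama4_text", intermediate_size_mlp=16384)

    def pick_inter_size(config):
        # Llama 4 text configs expose the dense-MLP width under a different name.
        if config.model_type == "llama4_text":
            return config.intermediate_size_mlp
        return config.intermediate_size

    assert pick_inter_size(llama_cfg) == 11008
    assert pick_inter_size(llama4_cfg) == 16384
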
@@ -105,11 +115,19 @@ class LlamaModel(nn.Module):
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
+
         if hasattr(config, "target_hidden_size"):
-            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.target_hidden_size
         else:
-            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.hidden_size
+
+        self.fc = torch.nn.Linear(
+            self.hidden_size_in * 3,
+            config.hidden_size,
+            bias=getattr(config, "bias", False),
+        )
+
+        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
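
The rewritten constructor makes the fc projection explicit: it maps the concatenation of three captured aux hidden states (of width target_hidden_size when the draft and target widths differ) down to the draft model's hidden_size, with the bias controlled by an optional config attribute. A toy shape check under made-up sizes:

    import torch

    hidden_size, target_hidden_size, seq_len = 64, 64, 5  # toy sizes, not real configs
    fc = torch.nn.Linear(target_hidden_size * 3, hidden_size, bias=False)

    # EAGLE3 feeds the concatenation of three captured aux hidden states per token.
    aux_hidden_states = [torch.randn(seq_len, target_hidden_size) for _ in range(3)]
    fused = fc(torch.cat(aux_hidden_states, dim=-1))
    assert fused.shape == (seq_len, hidden_size)
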
@@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
 
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = True
+        self.hot_token_id = None
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+        # Define the parameter mapping for stacked parameters
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "d2t" in name:
                 # d2t stores diffs between draft id and target id
                 self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
-
-            if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
-                new_name = f"model.{name}"
-                super().load_weights([(new_name, loaded_weight)])
-            elif "lm_head" in name:
-                super().load_weights([(name, loaded_weight)])
+                continue
+
+            if "t2d" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param_name = f"model.{name}" if name not in params_dict else name
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Handle regular parameters
+                param_name = name if name in params_dict else f"model.{name}"
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
 
     def get_hot_token_id(self):
         return self.hot_token_id
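
The new load_weights maps checkpoint shard names onto sglang's fused parameters (q/k/v into qkv_proj, gate/up into gate_up_proj), tries the optional "model." prefix, and falls back to plain loading through Python's for/else when no stacked pattern matches. The name remapping in isolation, with made-up weight names and no real checkpoint:

    stacked_params_mapping = [
        # (fused param suffix, checkpoint shard suffix, shard id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]

    def remap(name: str):
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            remapped, shard = name.replace(weight_name, param_name), shard_id
            break
        else:
            # Runs only if the loop never hit `break`, i.e. no stacked pattern matched.
            remapped, shard = name, None
        return remapped, shard

    assert remap("layers.0.self_attn.k_proj.weight") == ("layers.0.self_attn.qkv_proj.weight", "k")
    assert remap("layers.0.mlp.up_proj.weight") == ("layers.0.mlp.gate_up_proj.weight", 1)
    assert remap("norm.weight") == ("norm.weight", None)
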
@@ -41,16 +41,16 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
 )
-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,
-    unpad_image,
-    unpad_image_shape,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.multimodal.mm_utils import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+    unpad_image_shape,
+)
 from sglang.srt.utils import add_prefix, flatten_nested_list, logger
 
 
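
As with the EPLB move, downstream code importing these helpers from the old sglang.srt.mm_utils path will break on 0.4.9 (the module now lives under sglang.srt.multimodal, see entry 119 in the file list). A version-tolerant import is one way to cope (a suggestion, not something sglang ships):

    try:
        # sglang >= 0.4.9
        from sglang.srt.multimodal.mm_utils import unpad_image
    except ImportError:
        # sglang <= 0.4.8
        from sglang.srt.mm_utils import unpad_image
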
@@ -32,7 +32,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import DynamicCache, EncoderDecoderCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from transformers.models.whisper.modeling_whisper import (
-    WHISPER_ATTENTION_CLASSES,
+    WhisperAttention,
     WhisperConfig,
     WhisperEncoder,
 )
@@ -1090,7 +1090,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig, layer_idx: int = None):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = WhisperAttention(
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
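
Both minicpmo.py hunks appear to track newer transformers releases, which dropped the WHISPER_ATTENTION_CLASSES dispatch table in favor of a single WhisperAttention class. Constructing it directly only needs the embedding width and head count; a toy instantiation (sizes are made up, and the exact transformers version boundary is not stated in this diff):

    from transformers.models.whisper.modeling_whisper import WhisperAttention

    # embed_dim must be divisible by num_heads.
    attn = WhisperAttention(embed_dim=384, num_heads=6, dropout=0.0)
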
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Inference-only Mistral model."""
 
-from typing import List, Union
+from typing import List
 
 import torch
 from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector