sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
Files changed (70)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +59 -11
  10. sglang/srt/disaggregation/mini_lb.py +45 -8
  11. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  12. sglang/srt/disaggregation/prefill.py +24 -9
  13. sglang/srt/entrypoints/http_server.py +8 -2
  14. sglang/srt/function_call_parser.py +77 -5
  15. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  16. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  17. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  18. sglang/srt/layers/attention/vision.py +2 -0
  19. sglang/srt/layers/layernorm.py +38 -16
  20. sglang/srt/layers/logits_processor.py +2 -2
  21. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  25. sglang/srt/layers/pooler.py +6 -0
  26. sglang/srt/layers/quantization/awq.py +5 -1
  27. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  28. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  29. sglang/srt/layers/radix_attention.py +13 -3
  30. sglang/srt/layers/rotary_embedding.py +170 -126
  31. sglang/srt/managers/data_parallel_controller.py +10 -3
  32. sglang/srt/managers/io_struct.py +7 -0
  33. sglang/srt/managers/mm_utils.py +85 -28
  34. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  35. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  36. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  37. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  38. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  40. sglang/srt/managers/schedule_batch.py +29 -12
  41. sglang/srt/managers/scheduler.py +31 -20
  42. sglang/srt/managers/tokenizer_manager.py +5 -1
  43. sglang/srt/mem_cache/memory_pool.py +87 -0
  44. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  45. sglang/srt/model_executor/forward_batch_info.py +51 -95
  46. sglang/srt/model_executor/model_runner.py +11 -24
  47. sglang/srt/models/deepseek.py +12 -2
  48. sglang/srt/models/deepseek_nextn.py +101 -6
  49. sglang/srt/models/deepseek_v2.py +144 -70
  50. sglang/srt/models/deepseek_vl2.py +9 -4
  51. sglang/srt/models/gemma3_causal.py +1 -1
  52. sglang/srt/models/llama4.py +0 -1
  53. sglang/srt/models/minicpmo.py +5 -1
  54. sglang/srt/models/mllama4.py +2 -2
  55. sglang/srt/models/qwen2_5_vl.py +3 -6
  56. sglang/srt/models/qwen2_vl.py +3 -7
  57. sglang/srt/models/roberta.py +178 -0
  58. sglang/srt/openai_api/adapter.py +18 -8
  59. sglang/srt/server_args.py +15 -22
  60. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  61. sglang/srt/torch_memory_saver_adapter.py +10 -1
  62. sglang/srt/utils.py +2 -1
  63. sglang/test/runners.py +6 -13
  64. sglang/test/test_utils.py +36 -18
  65. sglang/version.py +1 -1
  66. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
  67. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
  68. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  69. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  70. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py

@@ -80,7 +80,15 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import (
+    BumpAllocator,
+    DeepEPMode,
+    add_prefix,
+    get_bool_env_var,
+    get_int_env_var,
+    is_cuda,
+    is_hip,
+)

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -315,12 +323,6 @@ class DeepseekV2MoE(nn.Module):
         self, hidden_states: torch.Tensor, forward_mode: ForwardMode
     ) -> torch.Tensor:
         shared_output = None
-        topk_idx = torch.full(
-            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-        )
-        topk_weights = torch.empty(
-            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-        )
         if (
             forward_mode is not None
             and not forward_mode.is_idle()
@@ -340,6 +342,13 @@ class DeepseekV2MoE(nn.Module):
                 correction_bias=self.correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
             )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
         if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             (
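Note: after this change the zero-row placeholder routing tensors are only allocated in the idle-batch branch instead of unconditionally before the routing call. For reference, the shapes involved look like the following minimal sketch (the top_k value is illustrative):

import torch

top_k = 8  # illustrative value, not taken from a specific model config
# Idle batch: no tokens to route, so both placeholders have zero rows.
topk_idx = torch.full((0, top_k), -1, dtype=torch.int)
topk_weights = torch.empty((0, top_k), dtype=torch.float32)
print(topk_idx.shape, topk_weights.shape)  # torch.Size([0, 8]) torch.Size([0, 8])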
@@ -435,12 +444,12 @@ class DeepseekV2AttentionMLA(nn.Module):

         # For tensor parallel attention
         if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
+            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                 self.hidden_size,
-                self.q_lora_rank,
+                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
+                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -462,6 +471,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
             )
+            self.kv_a_proj_with_mqa = ReplicatedLinear(
+                self.hidden_size,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
+            )
+
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -482,14 +499,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
         )
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

         if rope_scaling:
@@ -549,10 +558,14 @@ class DeepseekV2AttentionMLA(nn.Module):
             "disable_chunked_prefix_cache"
         ]
         self.attention_backend = global_server_args_dict["attention_backend"]
-        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+        self.rocm_fused_decode_mla = get_bool_env_var(
+            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
+        )

         # TODO: Design a finer way to determine the threshold
-        self.chunked_prefix_cache_threshold = 8192
+        self.chunked_prefix_cache_threshold = get_int_env_var(
+            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
+        )

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
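Note: the chunked-prefix-cache threshold and the ROCm fused-decode switch are now read from environment variables. The helpers referenced above come from sglang.srt.utils; the sketch below is only a minimal illustration of the expected behavior, not the actual implementation:

import os

# Minimal sketches, assuming the usual "truthy string" and "integer with
# fallback" semantics; the real helpers in sglang.srt.utils may differ.
def get_bool_env_var(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

def get_int_env_var(name: str, default: int) -> int:
    value = os.getenv(name)
    try:
        return int(value) if value is not None else default
    except ValueError:
        return default

# Example: lower the chunked prefix cache threshold for an experiment.
#   SGL_CHUNKED_PREFIX_CACHE_THRESHOLD=4096 python -m sglang.launch_server ...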
@@ -571,13 +584,17 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MLA
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
+            if forward_batch.extend_prefix_lens_cpu is not None:
+                sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
             if (
                 forward_batch.forward_mode.is_extend()
                 and not self.disable_chunked_prefix_cache
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu)
-                >= self.chunked_prefix_cache_threshold
+                and (
+                    sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+                    or sum_extend_prefix_lens == 0
+                )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
@@ -640,15 +657,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+
         _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a.contiguous())
@@ -682,18 +702,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        q_len = hidden_states.shape[0]
-        q_input = hidden_states.new_empty(
-            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
-        )
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

         if self.use_deep_gemm_bmm:
@@ -729,20 +748,23 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         else:
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
-        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        v_input = latent_cache[..., : self.kv_lora_rank]
-        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
-        k_input = latent_cache.unsqueeze(1)
-        k_input[..., : self.kv_lora_rank] = v_input
-        k_pe = k_input[..., self.kv_lora_rank :]
+        q_nope_out = q_nope_out.transpose(0, 1)
+
+        k_nope = latent_cache[..., : self.kv_lora_rank]
+        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q_input[..., self.kv_lora_rank :] = q_pe
-        k_input[..., self.kv_lora_rank :] = k_pe

-        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        if self.attention_backend == "fa3":
+            attn_output = self.attn_mqa(
+                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
+            )
+        else:
+            q = torch.cat([q_nope_out, q_pe], dim=-1)
+            k = torch.cat([k_nope, k_pe], dim=-1)
+            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.use_deep_gemm_bmm:
@@ -802,13 +824,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
         )
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

         if self.w_kc.dtype == torch.float8_e4m3fnuz:
@@ -829,8 +854,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
         q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
-
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         v_input = latent_cache[..., : self.kv_lora_rank]
         v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
         k_input = latent_cache.unsqueeze(1)
@@ -1001,15 +1024,17 @@ class DeepseekV2AttentionMLA(nn.Module):

         # First do normal mha forward to get output for extended part
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a.contiguous())
@@ -1414,11 +1439,27 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.determine_n_share_experts_fusion()
+        self.model = DeepseekV2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
+
+    def determine_n_share_experts_fusion(
+        self, architecture: str = "DeepseekV3ForCausalLM"
+    ):
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
@@ -1433,7 +1474,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         elif self.n_share_experts_fusion == 0:
             if (
                 torch.cuda.get_device_capability("cuda") >= (9, 0)
-                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
             ):
@@ -1443,18 +1484,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
             )

-        self.model = DeepseekV2Model(
-            config, quant_config, prefix=add_prefix("model", prefix)
-        )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("lm_head", prefix),
-        )
-        self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
-
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

@@ -1592,7 +1621,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         if self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
-            if self.quant_config.get_name() == "w8a8_int8":
+            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
                 suffix_list = [
                     "down_proj.weight",
                     "down_proj.weight_scale",
@@ -1620,11 +1649,11 @@ class DeepseekV2ForCausalLM(nn.Module):
                 desc=f"Cloning {self.n_share_experts_fusion} "
                 "replicas of the shared expert into MoE",
             ):
-                for num_repeat in range(self.n_share_experts_fusion):
-                    for suffix in suffix_list:
-                        shared_expert_weight_name = (
-                            f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                        )
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                    )
+                    for num_repeat in range(self.n_share_experts_fusion):
                         weights_list.append(
                             (
                                 f"model.layers.{moe_layer}."
@@ -1634,7 +1663,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                                 weights_dict[shared_expert_weight_name],
                             )
                         )
-                        names_to_remove += [shared_expert_weight_name]
+                    names_to_remove += [shared_expert_weight_name]
             weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
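Note: the two hunks above only swap the loop nesting so that each shared-expert weight name lands in names_to_remove once per suffix instead of once per replica; the set of cloned weights is unchanged. A small, self-contained illustration of the de-duplication (layer index and replica count are illustrative):

n_share_experts_fusion = 8
suffix_list = ["down_proj.weight", "gate_proj.weight", "up_proj.weight"]

names_to_remove = []
for suffix in suffix_list:
    shared_expert_weight_name = f"model.layers.3.mlp.shared_experts.{suffix}"
    for num_repeat in range(n_share_experts_fusion):
        pass  # clone the shared-expert weight once per replica, as in the diff
    names_to_remove += [shared_expert_weight_name]

# One entry per suffix, not per (suffix, replica) pair.
print(len(names_to_remove))  # 3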
@@ -1651,6 +1680,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )

+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             # TODO(HandH1998): Modify it when nextn is supported.
@@ -1706,11 +1741,50 @@ class DeepseekV2ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue

-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if fuse_qkv_a_proj and (
+                    "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                ):
+                    cached_a_proj[name] = loaded_weight
+                    q_a_proj_name = (
+                        name
+                        if "q_a_proj" in name
+                        else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                    )
+                    kv_a_proj_name = (
+                        name
+                        if "kv_a_proj_with_mqa" in name
+                        else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                    )
+
+                    # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                    if (
+                        q_a_proj_name in cached_a_proj
+                        and kv_a_proj_name in cached_a_proj
+                    ):
+
+                        q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                        kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        fused_weight = torch.cat(
+                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                        )
+
+                        param_name = name.replace(
+                            "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                        )
+                        param = params_dict[param_name]
+
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, fused_weight)
+                        cached_a_proj.pop(q_a_proj_name)
+                        cached_a_proj.pop(kv_a_proj_name)
+                else:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)

         self.post_load_weights()

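Note: this load-time fusion pairs with the new fused_qkv_a_proj_with_mqa layer in the attention constructor: the q_a_proj and kv_a_proj_with_mqa checkpoint weights are concatenated along the output dimension, and the activations are split back at runtime, so one GEMM replaces two. A self-contained sketch of the idea (dimensions are illustrative, not taken from a real checkpoint):

import torch

hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1024, 512, 256, 64

# Checkpoint stores two separate down-projections (out_features x in_features).
q_a_proj_weight = torch.randn(q_lora_rank, hidden_size)
kv_a_proj_weight = torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size)

# Load-time fusion: concatenate along the output dimension (dim=0).
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)

# Runtime: one GEMM, then split the activations back into the two parts.
hidden_states = torch.randn(8, hidden_size)
qkv = hidden_states @ fused_weight.T
q, latent_cache = qkv.split([q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

assert q.shape == (8, q_lora_rank)
assert latent_cache.shape == (8, kv_lora_rank + qk_rope_head_dim)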
sglang/srt/models/deepseek_vl2.py

@@ -12,12 +12,13 @@ from sglang.srt.configs.deepseekvl2 import (
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternImageTokens,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.deepseek import DeepseekForCausalLM
 from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM


@@ -189,7 +190,11 @@ class DeepseekVL2ForCausalLM(nn.Module):

         # ----------- language model ------------
         language_config = config.language_config
-        self.language_model = DeepseekV2ForCausalLM(language_config)
+        if language_config.use_mla:
+            self.language_model = DeepseekV2ForCausalLM(language_config)
+        else:
+            # deepseek-vl2-tiny forbids mla
+            self.language_model = DeepseekForCausalLM(language_config)

     def _init_vision_module(
         self, vision_config, quant_config: Optional[QuantizationConfig]
@@ -249,8 +254,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader(param, loaded_weight)

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        helper = MultiModalityDataPaddingPatternImageTokens(
-            image_token_id=image_inputs.im_token_id
+        helper = MultiModalityDataPaddingPatternMultimodalTokens(
+            [image_inputs.im_token_id]
         )
         return helper.pad_input_tokens(input_ids, image_inputs)

sglang/srt/models/gemma3_causal.py

@@ -189,7 +189,7 @@ class Gemma3Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
-            logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            logit_cap=0.0,
             # Module must also define `get_attention_sliding_window_size` to correctly initialize
             # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
sglang/srt/models/llama4.py

@@ -260,7 +260,6 @@ class Llama4Attention(nn.Module):
         if self.rotary_emb is not None:
             q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
             q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
-            assert (q_out_unused is q_view) and (k_out_unused is k_view)
             del q_view, k_view, q_out_unused, k_out_unused

         if self.qk_norm is not None:
sglang/srt/models/minicpmo.py

@@ -43,6 +43,7 @@ from sglang.srt.managers.mm_utils import (
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     flatten_nested_list,
@@ -1834,7 +1835,10 @@ class MiniCPMO(MiniCPMBaseModel):
             language_model=self.llm,
             image_data_embedding_func=self.get_image_feature,
             audio_data_embedding_func=self.get_audio_feature,
-            placeholder_token_ids=placeholder_token_ids,
+            placeholder_tokens={
+                Modality.IMAGE: placeholder_token_ids,
+                Modality.AUDIO: placeholder_token_ids,
+            },
             positions=positions,
         )
         return hidden_states
sglang/srt/models/mllama4.py

@@ -10,7 +10,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternImageTokens,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -53,7 +53,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         # Get all special token IDs
         im_token_id: int = mm_inputs.im_token_id

-        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(
sglang/srt/models/qwen2_5_vl.py

@@ -49,7 +49,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -488,11 +488,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = mm_inputs.im_start_id
-        im_end_id: int = mm_inputs.im_end_id
-
-        media_token_pairs = [(im_start_id, im_end_id)]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        im_token_id: int = mm_inputs.im_token_id
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
sglang/srt/models/qwen2_vl.py

@@ -42,7 +42,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -490,15 +490,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    # Use grid_t * grid_w * grid_h to pad tokens for each image
-    # add replaced padding by unique image hash
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = mm_inputs.im_start_id
-        im_end_id: int = mm_inputs.im_end_id
+        im_token_id: int = mm_inputs.im_token_id

-        media_token_pairs = [(im_start_id, im_end_id)]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
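Note: all four multimodal models in this diff (deepseek_vl2, mllama4, qwen2_5_vl, qwen2_vl) now pad through MultiModalityDataPaddingPatternMultimodalTokens, keyed on the multimodal token id itself rather than on per-model patterns such as (im_start_id, im_end_id) pairs. The sketch below only illustrates token-id-based matching; it is not the actual sglang implementation, and the token id is made up:

from typing import List

IMAGE_TOKEN_ID = 99999  # hypothetical image token id, for illustration only

def find_mm_token_positions(input_ids: List[int], mm_token_ids: List[int]) -> List[int]:
    # With token-id matching, the padding pattern only needs to know which
    # positions hold multimodal tokens; no start/end delimiters are required.
    return [i for i, tok in enumerate(input_ids) if tok in mm_token_ids]

print(find_mm_token_positions([1, 5, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 9], [IMAGE_TOKEN_ID]))
# [2, 3]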