sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py CHANGED
@@ -18,6 +18,7 @@
 
 import logging
 import os
+from enum import IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -50,13 +51,13 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
-from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
-    input_to_float8,
+    channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.int8_utils import (
@@ -78,7 +79,9 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, bmm_fp8
+    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 else:
     from vllm import _custom_ops as ops
 
@@ -92,6 +95,19 @@ expert_distribution_recorder = ExpertDistributionRecorder()
 logger = logging.getLogger(__name__)
 
 
+class AttnForwardMethod(IntEnum):
+
+    # Use multi-head attention
+    MHA = auto()
+
+    # Use absorbed multi-latent attention
+    MLA = auto()
+
+    # Use multi-head attention, but with KV cache chunked.
+    # This method can avoid OOM when prefix lengths are long.
+    MHA_CHUNKED_KV = auto()
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
@@ -178,7 +194,6 @@ class DeepseekV2MoE(nn.Module):
             else 0
         )
 
-        self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -278,10 +293,7 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_deepep(hidden_states, forward_mode)
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
-            shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
+        shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
@@ -311,8 +323,7 @@ class DeepseekV2MoE(nn.Module):
         ):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            if self.n_shared_experts is not None:
-                shared_output = self.shared_experts(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx = select_experts(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
@@ -324,6 +335,7 @@ class DeepseekV2MoE(nn.Module):
                 correction_bias=self.correction_bias,
             )
         if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                hidden_states,
                topk_idx,
@@ -336,19 +348,15 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                self.num_experts,
                 forward_mode=forward_mode,
             )
-        final_hidden_states = (
-            self.experts(
-                hidden_states=hidden_states,
-                reorder_topk_ids=reorder_topk_ids,
-                seg_indptr=seg_indptr,
-                masked_m=masked_m,
-                expected_m=expected_m,
-                forward_mode=forward_mode,
-            )
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            forward_mode=forward_mode,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
@@ -357,11 +365,19 @@ class DeepseekV2MoE(nn.Module):
                 topk_weights,
                 forward_mode,
             )
+        final_hidden_states *= self.routed_scaling_factor
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
         return final_hidden_states
 
+    def _forward_shared_experts(self, hidden_states):
+        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+            return self.shared_experts(hidden_states)
+        else:
+            return None
+
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     import math
@@ -489,6 +505,7 @@ class DeepseekV2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
@@ -669,6 +686,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mqa", prefix),
         )
 
@@ -679,6 +697,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
         )
 
@@ -689,30 +708,54 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
         self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
 
-    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+        # TODO: Design a finer way to determine the threshold
+        self.chunked_prefix_cache_threshold = 8192
+
+    def dispatch_attn_forward_method(
+        self, forward_batch: ForwardBatch
+    ) -> AttnForwardMethod:
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
-            return (
+            if (
                 not self.flashinfer_mla_disable_ragged
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            )
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif self.attention_backend == "fa3":
-            # Flash Attention: Keep absorbing for all extend/decode
-            return False
+            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not self.disable_chunked_prefix_cache
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and sum(forward_batch.extend_prefix_lens_cpu)
+                >= self.chunked_prefix_cache_threshold
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return AttnForwardMethod.MLA
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            return (
+            if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            )
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
 
     def forward(
         self,
726
769
  ), "short-circuiting allreduce will lead to hangs"
727
770
  return hidden_states
728
771
 
729
- if self.no_absorb(forward_batch):
772
+ attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
773
+
774
+ if attn_forward_method == AttnForwardMethod.MHA:
730
775
  return self.forward_normal(positions, hidden_states, forward_batch)
776
+ elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
777
+ return self.forward_normal_chunked_kv(
778
+ positions, hidden_states, forward_batch
779
+ )
731
780
  else:
732
781
  if _is_hip:
733
782
  if (
@@ -811,8 +860,8 @@ class DeepseekV2AttentionMLA(nn.Module):
811
860
  self.w_kc.to(torch.bfloat16) * self.w_scale,
812
861
  )
813
862
  elif self.w_kc.dtype == torch.float8_e4m3fn:
814
- q_nope_val, q_nope_scale = input_to_float8(
815
- q_nope.transpose(0, 1), torch.float8_e4m3fn
863
+ q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
864
+ q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
816
865
  )
817
866
  q_nope_out = bmm_fp8(
818
867
  q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -842,8 +891,8 @@ class DeepseekV2AttentionMLA(nn.Module):
842
891
  self.w_vc.to(torch.bfloat16) * self.w_scale,
843
892
  )
844
893
  elif self.w_vc.dtype == torch.float8_e4m3fn:
845
- attn_output_val, attn_output_scale = input_to_float8(
846
- attn_output.transpose(0, 1), torch.float8_e4m3fn
894
+ attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
895
+ attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
847
896
  )
848
897
  attn_bmm_output = bmm_fp8(
849
898
  attn_output_val,
@@ -889,8 +938,8 @@ class DeepseekV2AttentionMLA(nn.Module):
889
938
  self.w_kc.to(torch.bfloat16) * self.w_scale,
890
939
  )
891
940
  elif self.w_kc.dtype == torch.float8_e4m3fn:
892
- q_nope_val, q_nope_scale = input_to_float8(
893
- q_nope.transpose(0, 1), torch.float8_e4m3fn
941
+ q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
942
+ q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
894
943
  )
895
944
  q_nope_out = bmm_fp8(
896
945
  q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -985,8 +1034,8 @@ class DeepseekV2AttentionMLA(nn.Module):
985
1034
  self.w_vc.to(torch.bfloat16) * self.w_scale,
986
1035
  )
987
1036
  elif self.w_vc.dtype == torch.float8_e4m3fn:
988
- attn_output_val, attn_output_scale = input_to_float8(
989
- attn_output.transpose(0, 1), torch.float8_e4m3fn
1037
+ attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
1038
+ attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
990
1039
  )
991
1040
  attn_bmm_output = bmm_fp8(
992
1041
  attn_output_val,
@@ -1002,6 +1051,127 @@
 
         return output
 
+    def _chunked_prefix_attn_mha(
+        self,
+        q: torch.Tensor,
+        accum_output: torch.Tensor,
+        accum_lse: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+
+        assert forward_batch.num_prefix_chunks is not None
+        for i in range(forward_batch.num_prefix_chunks):
+            forward_batch.set_prefix_chunk_idx(i)
+
+            # Fetch latent cache from memory pool with precomputed chunked kv indices
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[
+                forward_batch.prefix_chunk_kv_indices[i]
+            ].contiguous()
+
+            kv_a_normed, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
+            kv = self.kv_b_proj(kv_a_normed)[0]
+            kv = kv.view(
+                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            v = kv[..., self.qk_nope_head_dim :]
+            k_nope = kv[..., : self.qk_nope_head_dim]
+
+            k = torch.empty(
+                (
+                    k_nope.shape[0],
+                    self.num_local_heads,
+                    self.qk_nope_head_dim + self.qk_rope_head_dim,
+                ),
+                dtype=v.dtype,
+                device=v.device,
+            )
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+
+            output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+            lse = torch.transpose(lse, 0, 1).contiguous()
+            tmp_output = torch.empty_like(accum_output)
+            tmp_lse = torch.empty_like(accum_lse)
+            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
+            accum_output, accum_lse = tmp_output, tmp_lse
+
+        return accum_output
+
+    def forward_normal_chunked_kv(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # In normal mha, the k and v tensors will become overly large when the prefix length is long.
+        # To avoid this, we split the kv cache into chunks and process them one after another.
+        # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
+        # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
+        # will be helpful for understanding the purpose of this function.
+
+        # First do normal mha forward to get output for extended part
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        lse = torch.transpose(lse, 0, 1).contiguous()
+
+        # Do mha attention with chunked prefix cache if there are any sequence with prefix
+        if any(forward_batch.extend_prefix_lens_cpu):
+            # Only initialize the info once
+            if forward_batch.num_prefix_chunks is None:
+                forward_batch.prepare_chunked_prefix_cache_info(q.device)
+
+            forward_batch.set_attn_attend_prefix_cache(True)
+            attn_output = self._chunked_prefix_attn_mha(
+                q=q,
+                accum_output=attn_output,
+                accum_lse=lse,
+                forward_batch=forward_batch,
+            )
+
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
 
 class DeepseekV2DecoderLayer(nn.Module):
 
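Note (not part of the diff): in the chunk loop above, each chunk's attention output is folded into the running result with merge_state_v2 from sgl_kernel, which combines two partial softmax-attention outputs using their log-sum-exp (LSE) values. A minimal PyTorch sketch of that merge rule, assuming outputs shaped [num_tokens, num_heads, head_dim] and LSE tensors shaped [num_tokens, num_heads]:

import torch

def merge_attn_states(o_a, lse_a, o_b, lse_b):
    # Each partial output is a softmax-weighted sum over its own KV chunk,
    # so reweight by exp(lse_chunk - lse_total) and add the two results.
    lse = torch.logaddexp(lse_a, lse_b)         # LSE over the union of keys
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)  # weight of chunk A
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)  # weight of chunk B
    return w_a * o_a + w_b * o_b, lse

The in-kernel variant writes into preallocated buffers, which is consistent with the torch.empty_like allocations for tmp_output/tmp_lse before each call in the loop above.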
@@ -1407,27 +1577,34 @@ class DeepseekV2ForCausalLM(nn.Module):
             w = self_attn.kv_b_proj.weight
             # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
             # This may affect the accuracy of fp8 model.
-            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+            if w.dtype in (
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
             ):
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if _is_hip:
-                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                            weight=w,
-                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                            input_scale=None,
+                if hasattr(self.quant_config, "weight_block_size"):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
                         )
-                    else:
-                        weight = w
-                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                    w, scale = block_quant_to_tensor_quant(
-                        weight, weight_scale, weight_block_size
-                    )
+                        self_attn.w_scale = scale
+                else:
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale
+                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                     self_attn.w_scale = scale
+
 
             if w.dtype == torch.int8:
                 # block-wise int8 need it
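Note (not part of the diff): channel_quant_to_tensor_quant collapses per-channel FP8 scales into a single per-tensor scale so kv_b_proj can be fed to bmm_fp8, which only accepts per-tensor scales; that is the accuracy trade-off the NOTE in the hunk above mentions. A conceptual sketch of the idea (shapes, names, and clamping are assumptions, not the sglang implementation):

import torch

def channel_to_tensor_requant(w_q: torch.Tensor, channel_scale: torch.Tensor):
    # w_q: fp8 weight quantized per output channel; channel_scale broadcasts over w_q.
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Dequantize with the per-channel scales, then pick one scale for the whole tensor.
    w_fp32 = w_q.to(torch.float32) * channel_scale.to(torch.float32)
    tensor_scale = w_fp32.abs().max() / finfo.max
    # Re-quantize; channels that originally used smaller scales lose some precision here.
    w_requant = (w_fp32 / tensor_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_requant, tensor_scale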
@@ -1466,14 +1643,24 @@
         if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
-            suffix_list = [
-                "down_proj.weight",
-                "down_proj.weight_scale_inv",
-                "gate_proj.weight",
-                "gate_proj.weight_scale_inv",
-                "up_proj.weight",
-                "up_proj.weight_scale_inv",
-            ]
+            if self.quant_config.get_name() == "w8a8_int8":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                ]
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale_inv",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale_inv",
+                    "up_proj.weight",
+                    "up_proj.weight_scale_inv",
+                ]
             names_to_remove = []
             for moe_layer in tqdm(
                 range(
sglang/srt/models/exaone.py CHANGED
@@ -155,6 +155,7 @@ class ExaoneAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
         )
 
     def forward(
sglang/srt/models/gemma.py CHANGED
@@ -137,6 +137,7 @@ class GemmaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/gemma2.py CHANGED
@@ -163,6 +163,7 @@ class Gemma2Attention(nn.Module):
                 if use_sliding_window
                 else None
             ),
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/gemma3_causal.py CHANGED
@@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module):
             # Module must also define `get_attention_sliding_window_size` to correctly initialize
             # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/gpt2.py CHANGED
@@ -78,6 +78,7 @@ class GPT2Attention(nn.Module):
             scaling=self.scale,
             num_kv_heads=total_num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
         )
 
     def forward(
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -87,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
             scaling=self.scale,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/granite.py CHANGED
@@ -158,6 +158,7 @@ class GraniteAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/grok.py CHANGED
@@ -215,6 +215,7 @@ class Grok1Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=logit_cap,
+            quant_config=quant_config,
         )
 
     def forward(
sglang/srt/models/internlm2.py CHANGED
@@ -145,6 +145,7 @@ class InternLM2Attention(nn.Module):
             self.scaling,
             self.num_kv_heads,
             layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/llama.py CHANGED
@@ -63,6 +63,7 @@ class LlamaMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -78,6 +79,7 @@ class LlamaMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -168,6 +170,7 @@ class LlamaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
@@ -281,7 +284,7 @@ class LlamaModel(nn.Module):
         self.layers = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: LlamaDecoderLayer(
-                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
             ),
             prefix="model.layers",
         )
@@ -375,9 +378,7 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = LlamaModel(
-            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
-        )
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
         if self.config.tie_word_embeddings:
@@ -402,6 +403,14 @@ class LlamaForCausalLM(nn.Module):
 
         self.capture_aux_hidden_states = False
 
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return LlamaModel(config, quant_config=quant_config, prefix=prefix)
+
     @torch.no_grad()
     def forward(
         self,
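Note (not part of the diff): the new _init_model hook lets subclasses reuse LlamaForCausalLM's head setup and weight loading while swapping in a different backbone (the Llama 4 model files added in this release presumably build on it). A hypothetical subclass, shown only as an illustration; MyLlamaVariantForCausalLM is not a real sglang class:

from typing import Optional

from transformers import LlamaConfig

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


class MyLlamaVariantForCausalLM(LlamaForCausalLM):
    # Hypothetical example: override the hook to return a customized backbone.
    def _init_model(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # LlamaModel is used as a stand-in; a real variant would return its own model class.
        return LlamaModel(config, quant_config=quant_config, prefix=prefix)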