sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +330 -200
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +12 -5
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +25 -13
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +1 -0
- sglang/srt/layers/radix_attention.py +13 -1
- sglang/srt/layers/rotary_embedding.py +12 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +48 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +1 -0
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -18,6 +18,7 @@
 
 import logging
 import os
+from enum import IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -50,13 +51,13 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
-from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
-    input_to_float8,
+    channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.int8_utils import (
@@ -78,7 +79,9 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, bmm_fp8
+    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 else:
     from vllm import _custom_ops as ops
 
@@ -92,6 +95,19 @@ expert_distribution_recorder = ExpertDistributionRecorder()
 logger = logging.getLogger(__name__)
 
 
+class AttnForwardMethod(IntEnum):
+
+    # Use multi-head attention
+    MHA = auto()
+
+    # Use absorbed multi-latent attention
+    MLA = auto()
+
+    # Use multi-head attention, but with KV cache chunked.
+    # This method can avoid OOM when prefix lengths are long.
+    MHA_CHUNKED_KV = auto()
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
@@ -178,7 +194,6 @@ class DeepseekV2MoE(nn.Module):
             else 0
         )
 
-        self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -278,10 +293,7 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_deepep(hidden_states, forward_mode)
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
-            shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
+        shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
@@ -311,8 +323,7 @@ class DeepseekV2MoE(nn.Module):
         ):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            if self.n_shared_experts is not None:
-                shared_output = self.shared_experts(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx = select_experts(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
@@ -324,6 +335,7 @@
                 correction_bias=self.correction_bias,
             )
         if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                 hidden_states,
                 topk_idx,
@@ -336,19 +348,15 @@
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                self.num_experts,
                 forward_mode=forward_mode,
             )
-        final_hidden_states = (
-            self.experts(
-                hidden_states=hidden_states,
-                reorder_topk_ids=reorder_topk_ids,
-                seg_indptr=seg_indptr,
-                masked_m=masked_m,
-                expected_m=expected_m,
-                forward_mode=forward_mode,
-            )
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            forward_mode=forward_mode,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
@@ -357,11 +365,19 @@
                 topk_weights,
                 forward_mode,
             )
+        final_hidden_states *= self.routed_scaling_factor
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
         return final_hidden_states
 
+    def _forward_shared_experts(self, hidden_states):
+        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+            return self.shared_experts(hidden_states)
+        else:
+            return None
+
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     import math
@@ -489,6 +505,7 @@ class DeepseekV2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
@@ -669,6 +686,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mqa", prefix),
         )
 
@@ -679,6 +697,7 @@
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mha", prefix),
         )
 
@@ -689,30 +708,54 @@
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
         self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
 
-    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+        # TODO: Design a finer way to determine the threshold
+        self.chunked_prefix_cache_threshold = 8192
+
+    def dispatch_attn_forward_method(
+        self, forward_batch: ForwardBatch
+    ) -> AttnForwardMethod:
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
-            return (
+            if (
                 not self.flashinfer_mla_disable_ragged
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            )
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif self.attention_backend == "fa3":
-            # Flash Attention:
-            return False
+            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not self.disable_chunked_prefix_cache
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and sum(forward_batch.extend_prefix_lens_cpu)
+                >= self.chunked_prefix_cache_threshold
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return AttnForwardMethod.MLA
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            return (
+            if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            )
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
 
     def forward(
         self,
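
For orientation: this hunk replaces the old boolean `no_absorb` check with a three-way dispatch over the new `AttnForwardMethod` enum. The sketch below restates the routing policy outside the class; it is deliberately simplified (the real method also inspects target-verify/draft-extend modes and the flashinfer ragged-prefill flag), and `choose_method` with its arguments is an illustrative stand-in, not part of sglang.

from enum import IntEnum, auto

class AttnForwardMethod(IntEnum):
    MHA = auto()             # plain multi-head attention
    MLA = auto()             # absorbed multi-latent attention
    MHA_CHUNKED_KV = auto()  # MHA that walks the cached prefix in chunks

def choose_method(backend: str, is_extend: bool, prefix_len: int,
                  chunk_threshold: int = 8192) -> AttnForwardMethod:
    # Simplified policy: fa3 prefills with a long cached prefix use chunked MHA,
    # prefix-free prefills use plain MHA, everything else falls back to absorbed MLA.
    if backend == "fa3" and is_extend and prefix_len >= chunk_threshold:
        return AttnForwardMethod.MHA_CHUNKED_KV
    if is_extend and prefix_len == 0:
        return AttnForwardMethod.MHA
    return AttnForwardMethod.MLA
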
@@ -726,8 +769,14 @@ class DeepseekV2AttentionMLA(nn.Module):
             ), "short-circuiting allreduce will lead to hangs"
             return hidden_states
 
-        if self.no_absorb(forward_batch):
+        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
+
+        if attn_forward_method == AttnForwardMethod.MHA:
             return self.forward_normal(positions, hidden_states, forward_batch)
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            return self.forward_normal_chunked_kv(
+                positions, hidden_states, forward_batch
+            )
         else:
             if _is_hip:
                 if (
@@ -811,8 +860,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
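
Context for the change above: `per_tensor_quant_mla_fp8` collapses the activation to a single FP8 scale that `bmm_fp8` consumes alongside the weight scale. A rough PyTorch equivalent is sketched below (assuming a PyTorch build with float8 dtypes); it is illustrative only, not the fused sgl-kernel implementation.

import torch

def per_tensor_quant_fp8_reference(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # One scale for the whole tensor: amax divided by the format's max value (448 for e4m3fn).
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().clamp(min=1e-12).float()
    scale = amax / finfo.max
    x_q = (x.float() / scale).clamp(finfo.min, finfo.max).to(dtype)
    # bmm_fp8 is then called as bmm_fp8(x_q, w_fp8, scale, w_scale, out_dtype), as in the hunk above.
    return x_q, scale
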
@@ -842,8 +891,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -889,8 +938,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -985,8 +1034,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -1002,6 +1051,127 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def _chunked_prefix_attn_mha(
+        self,
+        q: torch.Tensor,
+        accum_output: torch.Tensor,
+        accum_lse: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+
+        assert forward_batch.num_prefix_chunks is not None
+        for i in range(forward_batch.num_prefix_chunks):
+            forward_batch.set_prefix_chunk_idx(i)
+
+            # Fetch latent cache from memory pool with precomputed chunked kv indices
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[
+                forward_batch.prefix_chunk_kv_indices[i]
+            ].contiguous()
+
+            kv_a_normed, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
+            kv = self.kv_b_proj(kv_a_normed)[0]
+            kv = kv.view(
+                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            v = kv[..., self.qk_nope_head_dim :]
+            k_nope = kv[..., : self.qk_nope_head_dim]
+
+            k = torch.empty(
+                (
+                    k_nope.shape[0],
+                    self.num_local_heads,
+                    self.qk_nope_head_dim + self.qk_rope_head_dim,
+                ),
+                dtype=v.dtype,
+                device=v.device,
+            )
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+
+            output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+            lse = torch.transpose(lse, 0, 1).contiguous()
+            tmp_output = torch.empty_like(accum_output)
+            tmp_lse = torch.empty_like(accum_lse)
+            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
+            accum_output, accum_lse = tmp_output, tmp_lse
+
+        return accum_output
+
+    def forward_normal_chunked_kv(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # In normal mha, the k and v tensors will become overly large when the prefix length is long.
+        # To avoid this, we split the kv cache into chunks and process them one after another.
+        # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
+        # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
+        # will be helpful for understanding the purpose of this function.
+
+        # First do normal mha forward to get output for extended part
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        lse = torch.transpose(lse, 0, 1).contiguous()
+
+        # Do mha attention with chunked prefix cache if there are any sequence with prefix
+        if any(forward_batch.extend_prefix_lens_cpu):
+            # Only initialize the info once
+            if forward_batch.num_prefix_chunks is None:
+                forward_batch.prepare_chunked_prefix_cache_info(q.device)
+
+            forward_batch.set_attn_attend_prefix_cache(True)
+            attn_output = self._chunked_prefix_attn_mha(
+                q=q,
+                accum_output=attn_output,
+                accum_lse=lse,
+                forward_batch=forward_batch,
+            )
+
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
 
 class DeepseekV2DecoderLayer(nn.Module):
 
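
The chunked path above is exact, not approximate: softmax attention over a union of key sets can be recombined from per-chunk partial outputs as long as each partial result carries its log-sum-exp (LSE), which is what the fused `merge_state_v2` call does. A plain-PyTorch sketch of the same merge rule (illustrative shapes; the real kernel uses its own buffer layout and runs fused on GPU):

import torch

def merge_attn_states(o_a, lse_a, o_b, lse_b):
    # o_*:   [num_tokens, num_heads, head_dim] partial attention outputs
    # lse_*: [num_tokens, num_heads] log-sum-exp of the corresponding attention logits
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)  # renormalization weight of partial result a
    w_b = torch.exp(lse_b - lse_max)  # renormalization weight of partial result b
    denom = w_a + w_b
    out = (o_a * w_a.unsqueeze(-1) + o_b * w_b.unsqueeze(-1)) / denom.unsqueeze(-1)
    lse = lse_max + torch.log(denom)  # LSE over the union of both key sets
    return out, lse
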
@@ -1407,27 +1577,34 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w = self_attn.kv_b_proj.weight
                 # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
                 # This may affect the accuracy of fp8 model.
-                if w.dtype in (
+                if w.dtype in (
                     torch.float8_e4m3fn,
                     torch.float8_e4m3fnuz,
                 ):
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            if _is_hip:
+                                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                    weight=w,
+                                    weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                    input_scale=None,
+                                )
+                            else:
+                                weight = w
+                                weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                            w, scale = block_quant_to_tensor_quant(
+                                weight, weight_scale, weight_block_size
                             )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
+                            self_attn.w_scale = scale
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale
+                        w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                         self_attn.w_scale = scale
+
                 if w.dtype == torch.int8:
                     if hasattr(self.quant_config, "weight_block_size"):
                         # block-wise int8 need it
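
Background for the hunk above: `bmm_fp8` only accepts one scale per tensor (see the NOTE in the diff), so block- or channel-quantized `kv_b_proj` weights are dequantized and requantized once at load time. Below is a rough reference for the block-wise case, assuming an [N, K] FP8 weight with [ceil(N/bn), ceil(K/bk)] dequantization scales; it is an illustrative sketch, not the fp8_utils implementation.

import torch

def block_to_tensor_quant_reference(w_q, scale_inv, block_shape):
    bn, bk = block_shape
    n, k = w_q.shape
    # Expand the per-block scales to elementwise scales, then dequantize.
    s = scale_inv.repeat_interleave(bn, dim=0)[:n]
    s = s.repeat_interleave(bk, dim=1)[:, :k]
    w = w_q.to(torch.float32) * s
    # Requantize with a single per-tensor scale, as required by per-tensor bmm_fp8.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().amax().clamp(min=1e-12) / finfo.max
    w_q_new = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_q_new, scale
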
@@ -1466,14 +1643,24 @@
         if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
-            suffix_list = [
-                "down_proj.weight",
-                "down_proj.weight_scale_inv",
-                "gate_proj.weight",
-                "gate_proj.weight_scale_inv",
-                "up_proj.weight",
-                "up_proj.weight_scale_inv",
-            ]
+            if self.quant_config.get_name() == "w8a8_int8":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                ]
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale_inv",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale_inv",
+                    "up_proj.weight",
+                    "up_proj.weight_scale_inv",
+                ]
             names_to_remove = []
             for moe_layer in tqdm(
                 range(
sglang/srt/models/exaone.py
CHANGED
sglang/srt/models/gemma.py
CHANGED
sglang/srt/models/gemma2.py
CHANGED
@@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module):
             # Module must also define `get_attention_sliding_window_size` to correctly initialize
             # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/gpt2.py
CHANGED
sglang/srt/models/gpt_bigcode.py
CHANGED
sglang/srt/models/granite.py
CHANGED
sglang/srt/models/grok.py
CHANGED
sglang/srt/models/internlm2.py
CHANGED