sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -24,9 +24,7 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -61,6 +59,115 @@ class PrefillMetadata:
 global_workspace_buffer = None
 
 
+class FlashInferMhaChunkKVRunner:
+    def __init__(
+        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+    ):
+        # Parse Constants
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.data_type = model_runner.dtype
+        self.q_data_type = model_runner.dtype
+
+        # Buffers and wrappers
+        self.qo_indptr = attn_backend.qo_indptr
+        self.workspace_buffer = attn_backend.workspace_buffer
+        self.fmha_backend = attn_backend.fmha_backend
+
+        self.chunk_ragged_wrappers = []
+        self.ragged_wrapper = attn_backend.prefill_wrapper_ragged
+
+    def update_prefix_chunks(self, num_prefix_chunks: int):
+        while num_prefix_chunks > len(self.chunk_ragged_wrappers):
+            ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
+                self.workspace_buffer, "NHD", backend=self.fmha_backend
+            )
+            self.chunk_ragged_wrappers.append(ragged_wrapper)
+
+    def update_wrapper(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.num_prefix_chunks is not None
+        num_prefix_chunks = forward_batch.num_prefix_chunks
+        self.update_prefix_chunks(num_prefix_chunks)
+
+        prefix_lens = forward_batch.extend_prefix_lens
+        seq_lens = forward_batch.seq_lens
+
+        bs = len(seq_lens)
+        qo_indptr = self.qo_indptr
+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+        qo_indptr = qo_indptr[: bs + 1]
+
+        for chunk_idx in range(forward_batch.num_prefix_chunks):
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+            kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=False,
+            )
+        # ragged prefill
+        self.ragged_wrapper.begin_forward(
+            qo_indptr=qo_indptr,
+            kv_indptr=qo_indptr,
+            num_qo_heads=self.num_local_heads,
+            num_kv_heads=self.num_local_heads,
+            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+            head_dim_vo=self.v_head_dim,
+            q_data_type=self.q_data_type,
+            causal=True,
+        )
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
+        logits_soft_cap = layer.logit_cap
+        if forward_batch.attn_attend_prefix_cache:
+            chunk_idx = forward_batch.prefix_chunk_idx
+            assert chunk_idx >= 0
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            o1, s1 = wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=False,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+        else:
+            o1, s1 = self.ragged_wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
+        return o1, s1
+
+
 class FlashInferMLAAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
@@ -72,15 +179,22 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
-
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
+        self.enable_chunk_kv = (
+            not skip_prefill
+            and global_server_args_dict["disaggregation_mode"] != "decode"
+            and not global_server_args_dict["disable_chunked_prefix_cache"]
+            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+        )
+        self.page_size = model_runner.page_size
 
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -96,23 +210,33 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
 
+        self.kv_indices = torch.empty(
+            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
+            # A hack to pre-initialize large batch size for dp attention
+            if model_runner.server_args.enable_dp_attention:
+                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
+
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
-        fmha_backend = "auto"
+        self.fmha_backend = "auto"
         if is_sm100_supported():
-            fmha_backend = "cutlass"
+            self.fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.workspace_buffer, "NHD", backend=fmha_backend
+            self.workspace_buffer, "NHD", backend=self.fmha_backend
         )
 
         if not self.skip_prefill:
@@ -136,6 +260,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                 model_runner, self
             )
+            if self.enable_chunk_kv:
+                self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)
 
         self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
             model_runner, self
@@ -147,6 +273,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -204,16 +331,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-
-
-
-                dtype=torch.int32,
-                device="cuda",
-            )
-        else:
-            cuda_graph_kv_indices = kv_indices_buf
-
-        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+        self.cuda_graph_kv_indices = (
+            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
+        )
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
@@ -239,6 +359,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
+
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -249,7 +370,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
-
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -320,11 +440,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
+            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-
+                num_pages_per_req, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -333,7 +455,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
-
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -369,6 +490,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+
     def forward_extend(
         self,
         q: torch.Tensor,
@@ -380,6 +505,15 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if (
+            forward_batch.attn_attend_prefix_cache is not None
+            and forward_batch.mha_return_lse
+        ):  # MHA Chunk
+            assert self.enable_chunk_kv
+            assert q_rope is None
+            assert k_rope is None
+            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
+            return o1, s1
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
@@ -400,7 +534,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -410,8 +543,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 k = torch.cat([k, k_rope], dim=-1)
             o = self.prefill_wrapper_ragged.forward(
                 qall,
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
@@ -421,6 +554,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
+            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
+
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -482,17 +617,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-
+        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
+        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
         o = q_nope.new_empty(q_nope.shape)
-        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-
-
+            k_buf[:, :, : layer.v_head_dim],
+            k_buf[:, :, layer.v_head_dim :],
             out=o,
         )
 
@@ -511,9 +646,10 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-
+        self.page_size = model_runner.page_size
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -557,13 +693,17 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-
+                self.kv_indices[: kv_indptr[-1]]
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -572,39 +712,40 @@ class FlashInferMLAIndicesUpdaterDecode:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr=q_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
         else:
             wrapper.plan(
-                fast_decode_kwargs["qo_indptr_cpu"],
-                fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices,
-                fast_decode_kwargs["kv_len_arr_cpu"],
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
+                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices=kv_indices,
+                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
            )
 
 
@@ -626,12 +767,14 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -645,7 +788,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
-
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -679,13 +821,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices =
-                paged_kernel_lens_sum,
-                dtype=torch.int32,
-                device=req_pool_indices.device,
-            )
+            kv_indices = self.kv_indices[: kv_indptr[-1]]
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -694,6 +835,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -711,7 +853,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                     self.req_to_token,
                 )
             )
-
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -722,23 +863,30 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                 head_dim_vo=self.v_head_dim,
                 q_data_type=self.q_data_type,
+                causal=True,
             )
         else:
             # mla paged prefill
-
+            if spec_info is not None:
+                assert (
+                    self.page_size == 1
+                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
+                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
+            else:
+                kv_lens = paged_kernel_lens.to(torch.int32)
             wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-
-                True,
-                sm_scale,
-                self.q_data_type,
-                self.data_type,
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=True,
+                sm_scale=sm_scale,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.data_type,
             )
 
 
@@ -833,6 +981,7 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -868,6 +1017,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -884,6 +1034,7 @@ class FlashInferMLAMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,