sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -24,9 +24,7 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.utils import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -61,6 +59,115 @@ class PrefillMetadata:
 global_workspace_buffer = None
 
 
+class FlashInferMhaChunkKVRunner:
+    def __init__(
+        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+    ):
+        # Parse Constants
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.data_type = model_runner.dtype
+        self.q_data_type = model_runner.dtype
+
+        # Buffers and wrappers
+        self.qo_indptr = attn_backend.qo_indptr
+        self.workspace_buffer = attn_backend.workspace_buffer
+        self.fmha_backend = attn_backend.fmha_backend
+
+        self.chunk_ragged_wrappers = []
+        self.ragged_wrapper = attn_backend.prefill_wrapper_ragged
+
+    def update_prefix_chunks(self, num_prefix_chunks: int):
+        while num_prefix_chunks > len(self.chunk_ragged_wrappers):
+            ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
+                self.workspace_buffer, "NHD", backend=self.fmha_backend
+            )
+            self.chunk_ragged_wrappers.append(ragged_wrapper)
+
+    def update_wrapper(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.num_prefix_chunks is not None
+        num_prefix_chunks = forward_batch.num_prefix_chunks
+        self.update_prefix_chunks(num_prefix_chunks)
+
+        prefix_lens = forward_batch.extend_prefix_lens
+        seq_lens = forward_batch.seq_lens
+
+        bs = len(seq_lens)
+        qo_indptr = self.qo_indptr
+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+        qo_indptr = qo_indptr[: bs + 1]
+
+        for chunk_idx in range(forward_batch.num_prefix_chunks):
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+            kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=False,
+            )
+        # ragged prefill
+        self.ragged_wrapper.begin_forward(
+            qo_indptr=qo_indptr,
+            kv_indptr=qo_indptr,
+            num_qo_heads=self.num_local_heads,
+            num_kv_heads=self.num_local_heads,
+            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+            head_dim_vo=self.v_head_dim,
+            q_data_type=self.q_data_type,
+            causal=True,
+        )
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
+        logits_soft_cap = layer.logit_cap
+        if forward_batch.attn_attend_prefix_cache:
+            chunk_idx = forward_batch.prefix_chunk_idx
+            assert chunk_idx >= 0
+            wrapper = self.chunk_ragged_wrappers[chunk_idx]
+            o1, s1 = wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=False,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+        else:
+            o1, s1 = self.ragged_wrapper.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
+        return o1, s1
+
+
 class FlashInferMLAAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
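Not part of the diff, just orientation for readers: `FlashInferMhaChunkKVRunner.forward` returns the attention output together with its log-sum-exp (`forward_return_lse`), which is what lets the caller combine the causal result over new tokens with the non-causal results over each prefix chunk. A minimal sketch of that standard merge follows; the helper name `merge_attn_states` is illustrative (not the function sglang calls), and it assumes natural-log LSE values.

```python
import torch


def merge_attn_states(o_a, lse_a, o_b, lse_b):
    """Merge two partial softmax-attention results computed over disjoint KV ranges.

    o_*:   [num_tokens, num_heads, v_head_dim] partial outputs
    lse_*: [num_tokens, num_heads] log-sum-exp of the attention logits
    """
    lse = torch.logaddexp(lse_a, lse_b)         # normalizer over the union of KV
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)  # weight of partial result a
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)  # weight of partial result b
    return w_a * o_a + w_b * o_b, lse
```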
@@ -72,11 +179,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
-
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
+        self.enable_chunk_kv = (
+            not skip_prefill
+            and global_server_args_dict["disaggregation_mode"] != "decode"
+            and not global_server_args_dict["disable_chunked_prefix_cache"]
+            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+        )
+        self.page_size = model_runner.page_size
 
         # Allocate buffers
         global global_workspace_buffer
@@ -97,23 +210,33 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
 
+        self.kv_indices = torch.empty(
+            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
+            # A hack to pre-initialize large batch size for dp attention
+            if model_runner.server_args.enable_dp_attention:
+                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
+
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
-        fmha_backend = "auto"
+        self.fmha_backend = "auto"
         if is_sm100_supported():
-            fmha_backend = "cutlass"
+            self.fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.workspace_buffer, "NHD", backend=fmha_backend
+            self.workspace_buffer, "NHD", backend=self.fmha_backend
         )
 
         if not self.skip_prefill:
@@ -137,6 +260,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                 model_runner, self
             )
+            if self.enable_chunk_kv:
+                self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)
 
         self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
             model_runner, self
@@ -148,6 +273,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -205,16 +331,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        if kv_indices_buf is None:
-            cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
-                dtype=torch.int32,
-                device="cuda",
-            )
-        else:
-            cuda_graph_kv_indices = kv_indices_buf
-
-        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+        self.cuda_graph_kv_indices = (
+            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
+        )
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
@@ -240,6 +359,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
+
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -250,7 +370,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
-
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -321,11 +440,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
+            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                kv_len_arr_cpu, dim=0
+                num_pages_per_req, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -334,7 +455,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
-
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -370,6 +490,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+
     def forward_extend(
         self,
         q: torch.Tensor,
@@ -381,6 +505,15 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if (
+            forward_batch.attn_attend_prefix_cache is not None
+            and forward_batch.mha_return_lse
+        ):  # MHA Chunk
+            assert self.enable_chunk_kv
+            assert q_rope is None
+            assert k_rope is None
+            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
+            return o1, s1
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
@@ -401,7 +534,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -411,8 +543,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 k = torch.cat([k, k_rope], dim=-1)
             o = self.prefill_wrapper_ragged.forward(
                 qall,
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
@@ -422,6 +554,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
+            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
+
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -483,17 +617,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_nope = reshaped_q[:, :, : layer.v_head_dim]
             q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
+        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
         o = q_nope.new_empty(q_nope.shape)
-        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-            k_buffer[:, :, : layer.v_head_dim],
-            k_buffer[:, :, layer.v_head_dim :],
+            k_buf[:, :, : layer.v_head_dim],
+            k_buf[:, :, layer.v_head_dim :],
             out=o,
         )
 
@@ -512,9 +646,10 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-
+        self.page_size = model_runner.page_size
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -558,13 +693,17 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+                self.kv_indices[: kv_indptr[-1]]
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -573,39 +712,40 @@ class FlashInferMLAIndicesUpdaterDecode:
             None,
             kv_indices,
             self.req_to_token.shape[1],
+            self.page_size,
         )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr=q_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
         else:
             wrapper.plan(
-                fast_decode_kwargs["qo_indptr_cpu"],
-                fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices,
-                fast_decode_kwargs["kv_len_arr_cpu"],
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
+                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices=kv_indices,
+                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
 
 
@@ -627,12 +767,14 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -646,7 +788,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
-
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -680,13 +821,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum,
-                dtype=torch.int32,
-                device=req_pool_indices.device,
-            )
+            kv_indices = self.kv_indices[: kv_indptr[-1]]
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -695,6 +835,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -712,7 +853,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                     self.req_to_token,
                 )
             )
-
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -723,23 +863,30 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                 head_dim_vo=self.v_head_dim,
                 q_data_type=self.q_data_type,
+                causal=True,
             )
         else:
             # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            if spec_info is not None:
+                assert (
+                    self.page_size == 1
+                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
+                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
+            else:
+                kv_lens = paged_kernel_lens.to(torch.int32)
             wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                True,
-                sm_scale,
-                self.q_data_type,
-                self.data_type,
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=True,
+                sm_scale=sm_scale,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.data_type,
             )
 
 
@@ -834,6 +981,7 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -869,6 +1017,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -885,6 +1034,7 @@ class FlashInferMLAMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,