sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects each version exactly as it appears in its public registry.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
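The per-file diff shown below covers `sglang/srt/layers/attention/aiter_backend.py` (+370 -107 in the listing above), presented hunk by hunk in unified-diff form.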
@@ -32,7 +32,7 @@ try:
         mha_batch_prefill_func,
         paged_attention_ragged,
     )
-    from aiter.mla import mla_decode_fwd
+    from aiter.mla import mla_decode_fwd, mla_prefill_fwd
 except ImportError:
     print(
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
@@ -52,10 +52,8 @@ class ForwardMetadata:
     kv_indices: torch.Tensor
     qo_indptr: torch.Tensor
     kv_last_page_len: torch.Tensor
-    max_extend_len: int
-    max_prefix_extend_len: int
     max_q_len: int
-    max_kv_len: int
+    max_kv_len: Optional[int]


 global_workspace_buffer = None
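This hunk collapses the extend-specific `max_extend_len` / `max_prefix_extend_len` fields into the single `max_q_len` already present, and relaxes `max_kv_len` to `Optional[int]` so decode paths can pass `None`. A minimal sketch of the resulting layout; the `@dataclass` decorator and the leading `kv_indptr` field are assumptions taken from the unchanged context around the hunk:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardMetadata:
    kv_indptr: torch.Tensor  # assumed from the unchanged context above the hunk
    kv_indices: torch.Tensor
    qo_indptr: torch.Tensor
    kv_last_page_len: torch.Tensor
    max_q_len: int  # replaces max_extend_len / max_prefix_extend_len
    max_kv_len: Optional[int]  # decode branches construct ForwardMetadata(..., None)
```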
@@ -71,10 +69,17 @@ class AiterAttnBackend(AttentionBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

         self.device = model_runner.device
         self.is_multimodal = model_runner.model_config.is_multimodal
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -157,13 +162,13 @@ class AiterAttnBackend(AttentionBackend):
         spec_info = forward_batch.spec_info
         qo_indptr = None
         kv_last_page_len = None
-
+        max_q_len = None

         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
-                kv_indices = torch.
+                kv_indices = torch.empty(
                     forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
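Throughout these hunks the indexing scheme is the same: `kv_indptr` holds the exclusive prefix sum of per-request KV lengths, and `create_flashinfer_kv_indices_triton` fills `kv_indices` with the flattened token-pool slots. A self-contained sketch of that bookkeeping with made-up lengths (plain torch on CPU; the triton kernel itself is not reproduced):

```python
import torch

# Hypothetical decode batch of three requests with these KV lengths.
seq_lens = torch.tensor([5, 2, 7], dtype=torch.int32)
bs = seq_lens.numel()

# Same arithmetic as the hunk: kv_indptr[0] stays 0, the rest is a cumsum,
# and kv_indices is sized for seq_lens_sum flattened KV-cache slots.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = torch.empty(int(seq_lens.sum()), dtype=torch.int32)

# Request i's slots live in kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
print(kv_indptr.tolist())  # [0, 5, 7, 14]
```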
@@ -183,39 +188,35 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr = self.qo_indptr_[: bs + 1]
                 qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
                 kv_last_page_len = self.kv_last_page_len[:bs]
-
+                max_q_len = 1

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-
-                None,
-                None,
+                max_q_len,
                 None,
             )

         elif forward_batch.forward_mode.is_draft_extend():
             if self.use_mla:
-
-
-
-
-
-
-
-                    spec_info=None,
+                kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                    spec_info.generate_attn_arg_prefill(
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.seq_lens_sum,
+                        self.req_to_token,
+                    )
                 )
                 self.forward_metadata = ForwardMetadata(
-
-
-
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.
-
-
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -231,30 +232,47 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
         elif forward_batch.forward_mode.is_target_verify():
             if self.use_mla:
-
-
+                draft_num = spec_info.draft_token_num
+                kv_lens = forward_batch.seq_lens + draft_num
+                kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
+                device = forward_batch.seq_lens.device
+
+                qo_indptr = torch.arange(
+                    0,
+                    (1 + bs) * draft_num,
+                    step=draft_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.empty(
+                    kv_lens_sum,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
                     forward_batch.req_pool_indices,
-
-
-
-
-
+                    kv_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
                 )
                 self.forward_metadata = ForwardMetadata(
-
-
-
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.
-
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    draft_num,
                     None,
                 )
             else:
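In the target-verify branch every request verifies exactly `draft_token_num` query tokens, so `qo_indptr` is a fixed-stride `torch.arange` rather than a cumsum, and the KV side grows by the draft tokens. A worked example with invented sizes:

```python
import torch

bs, draft_num = 3, 4  # hypothetical: three requests, four draft tokens each
seq_lens = torch.tensor([10, 3, 6], dtype=torch.int32)

# Uniform query counts make the cumsum collapse into an arange with
# step=draft_num, exactly as the hunk builds qo_indptr.
qo_indptr = torch.arange(0, (1 + bs) * draft_num, step=draft_num, dtype=torch.int32)
assert qo_indptr.tolist() == [0, 4, 8, 12]

# Each request's KV run covers its committed tokens plus the draft tokens,
# matching the hunk's kv_lens / kv_lens_sum arithmetic.
kv_lens = seq_lens + draft_num
kv_lens_sum = int(seq_lens.sum()) + draft_num * bs
assert kv_lens_sum == int(kv_lens.sum())
```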
@@ -271,8 +289,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -283,25 +299,26 @@ class AiterAttnBackend(AttentionBackend):
                 extend_no_prefix = False
             else:
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-
-
+                    forward_batch.extend_prefix_lens,
+                    sum(forward_batch.extend_prefix_lens_cpu),
                     forward_batch.extend_seq_lens,
-
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                     spec_info=None,
                 )
+                self.mla_indices_updater_prefill.kv_indptr += (
+                    self.mla_indices_updater_prefill.qo_indptr
+                )
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
                     self.mla_indices_updater_prefill.kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
-                    self.
-                    self.mla_indices_updater_prefill.
-                    self.mla_indices_updater_prefill.
-                    None,
-                    None,
+                    self.kv_last_page_len[:bs],
+                    self.mla_indices_updater_prefill.max_q_len,
+                    self.mla_indices_updater_prefill.max_kv_len,
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -317,8 +334,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -359,7 +374,7 @@ class AiterAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
             kv_last_page_len = None
-
+            max_q_len = None

             if spec_info is None:
                 kv_indptr = self.kv_indptr
@@ -383,17 +398,15 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr[1 : bs + 1] = torch.cumsum(
                     self.cuda_graph_kv_last_page_len[:bs], dim=0
                 )
-                max_extend_len = 1
                 kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = 1

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-
-                None,
-                None,
+                max_q_len,
                 None,
             )

@@ -419,18 +432,15 @@ class AiterAttnBackend(AttentionBackend):
                 kv_indices,
                 self.req_to_token.stride(0),
             )
-
-
-            kv_last_page_len = None
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = self.num_draft_tokens

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-
-                None,
-                None,
+                max_q_len,
                 None,
             )
         else:
@@ -448,12 +458,41 @@ class AiterAttnBackend(AttentionBackend):
                 self.indices_updater_prefill.kv_indices,
                 None,
                 None,
-                None,
-                None,
                 self.indices_updater_prefill.max_q_len,
                 self.indices_updater_prefill.max_kv_len,
             )
-
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = num_tokens_per_bs
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_q_len,
+                None,
+            )
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")

@@ -488,13 +527,44 @@ class AiterAttnBackend(AttentionBackend):
             kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices

         elif forward_mode.is_target_verify():
-
-
-
-
-
-
-
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_lens = seq_lens + self.num_draft_tokens
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                kv_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+        elif forward_mode.is_draft_extend():
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
             )
         else:
             raise ValueError("Invalid forward mode")
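The CUDA-graph hunks handle speculative query counts in two phases. At capture time the draft-extend branch assumes the worst case of `speculative_num_steps + 1` tokens per request (the drafted steps plus, presumably, one bonus token) and builds a fixed-stride `qo_indptr`; at replay time the real per-request accept lengths are written in with a cumsum. A small sketch of both computations with invented sizes:

```python
import torch

bs, speculative_num_steps = 2, 3
num_tokens_per_bs = speculative_num_steps + 1

# Capture-side branch: every request is assumed to keep the maximum
# num_tokens_per_bs query tokens, so the indptr has a uniform stride.
qo_capture = torch.arange(
    0, bs * num_tokens_per_bs + 1, step=num_tokens_per_bs, dtype=torch.int32
)
assert qo_capture.tolist() == [0, 4, 8]

# Replay-side branch: the actual accept lengths replace the uniform
# stride via a cumsum over spec_info.accept_length.
accept_lens = torch.tensor([2, 4], dtype=torch.int32)
qo_replay = torch.zeros(bs + 1, dtype=torch.int32)
qo_replay[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
assert qo_replay.tolist() == [0, 2, 6]
```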
@@ -530,11 +600,10 @@ class AiterAttnBackend(AttentionBackend):
         )

         if self.use_mla:
-
-
+            max_q_len = self.forward_metadata.max_q_len
+            max_kv_len = self.forward_metadata.max_kv_len
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
-            kv_last_page_lens = self.forward_metadata.kv_last_page_len
             qo_indptr = self.forward_metadata.qo_indptr
             K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
@@ -552,8 +621,8 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     qo_indptr,
-
-
+                    max_q_len,
+                    max_q_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
@@ -599,12 +668,71 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     kv_indptr,
-
-
+                    max_q_len,
+                    max_kv_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
                 return o
+            elif forward_batch.forward_mode.is_target_verify():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                mla_decode_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+            elif forward_batch.forward_mode.is_draft_extend():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                causal = True
+                sliding_window_size = -1
+                kv_indptr = self.forward_metadata.kv_indptr
+                kv_indices = self.forward_metadata.kv_indices
+                mla_prefill_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+                # self.extend_attention_fwd(
+                #     q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                #     k.contiguous(),
+                #     v.contiguous(),
+                #     o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                #     forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                #     forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                #     self.forward_metadata.qo_indptr,
+                #     kv_indptr,
+                #     kv_indices,
+                #     None,
+                #     causal,
+                #     None,
+                #     self.forward_metadata.max_q_len,
+                #     layer.scaling,
+                #     layer.logit_cap,
+                #     sliding_window_size,
+                # )
+                # return o
+            else:
+                raise ValueError(
+                    f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
+                )
         else:
             k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
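The new MLA branches above route target-verify through `mla_decode_fwd` and draft-extend through the newly imported `mla_prefill_fwd`; both first allocate the output with `q.new_empty` and reshape the flat KV pool for the kernel. A shape-only sketch of that preparation with invented dimensions (the aiter kernel calls themselves are omitted, and the reading of the 4-D view is an assumption):

```python
import torch

# Invented sizes for illustration only.
num_tokens, tp_q_head_num, qk_head_dim, v_head_dim = 8, 16, 576, 512

q = torch.randn(num_tokens, tp_q_head_num, qk_head_dim)
kv_buffer = torch.randn(1024, qk_head_dim)  # flat token-to-KV-pool slots

# Output allocated on q's device/dtype, one v_head_dim vector per query head.
o = q.new_empty((q.shape[0], tp_q_head_num, v_head_dim))

# The hunk passes K_Buffer.view(-1, 1, 1, layer.qk_head_dim) to the kernel
# and restores the flat (-1, 1, qk_head_dim) view afterwards.
k_for_kernel = kv_buffer.view(-1, 1, 1, qk_head_dim)
print(o.shape, k_for_kernel.shape)  # (8, 16, 512) and (1024, 1, 1, 576)
```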
@@ -662,7 +790,7 @@ class AiterAttnBackend(AttentionBackend):
                 self.forward_metadata.kv_indptr,
                 self.forward_metadata.kv_indices,
                 self.forward_metadata.kv_last_page_len,
-                self.forward_metadata.
+                self.forward_metadata.max_q_len,
                 layer.scaling,
                 layer.logit_cap,
             )
@@ -816,16 +944,17 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indices = None
         self.qo_indptr = None
         self.kv_last_page_len = None
-        self.
-        self.
+        self.max_q_len = 0
+        self.max_kv_len = 0

     def update(
         self,
         req_pool_indices: torch.Tensor,
-
-
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -834,33 +963,30 @@ class AiterMlaIndicesUpdaterPrefill:
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
-
-
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
-
-        paged_kernel_lens = prefix_lens
-        paged_kernel_lens_sum = prefix_lens_sum
-
         bs = len(req_pool_indices)

         kv_indptr = self.attn_backend.kv_indptr

         if spec_info is None:
             # Normal extend
-            kv_indptr[1 : bs + 1] = torch.cumsum(
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-
+                kv_lens_sum,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
-
+                kv_lens,
                 kv_indptr,
                 None,
                 kv_indices,
@@ -870,16 +996,12 @@ class AiterMlaIndicesUpdaterPrefill:
             qo_indptr = self.attn_backend.qo_indptr
             qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
-
-            max_extend_len = torch.max(extend_lens).item()
-            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
-            kv_indptr += qo_indptr
         else:
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
-
-
+                    kv_lens,
+                    kv_lens_sum,
                     self.req_to_token,
                 )
             )
@@ -887,5 +1009,146 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indptr = kv_indptr
         self.kv_indices = kv_indices
         self.qo_indptr = qo_indptr
-        self.
-        self.
+        self.max_q_len = max_q_len
+        self.max_kv_len = max_kv_len
+
+
+class AiterMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                AiterAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.device = model_runner.device
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
+        assert self.page_size == 1, "Page size must be 1"
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+            self.page_size,
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.empty(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)