sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py

@@ -483,7 +483,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
         ).squeeze(1)
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
sglang/srt/layers/attention/flashattention_backend.py

@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
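The hunks above add an optional `sinks` tensor and forward it through a conditional keyword-argument dict, so the same call works against flash-attention builds whose interface predates the new field. A minimal standalone sketch of that pattern, assuming nothing from sglang (the `old_attn` and `new_attn` callables are illustrative stand-ins):

```python
from typing import Optional

import torch


def call_with_optional_sinks(attn_fn, q, k, v, sinks: Optional[torch.Tensor] = None):
    """Forward `sinks` only when it is provided, mirroring the conditional
    keyword-argument pattern in the hunks above. Older attention interfaces
    that do not accept `sinks` keep working because the key is simply absent."""
    kwargs = {}
    if sinks is not None:
        kwargs["sinks"] = sinks
    return attn_fn(q, k, v, **kwargs)


# Stand-in attention functions for illustration only.
def old_attn(q, k, v):
    return q + k + v


def new_attn(q, k, v, sinks=None):
    out = q + k + v
    return out if sinks is None else out + sinks.sum()


x = torch.ones(2, 4)
print(call_with_optional_sinks(old_attn, x, x, x).shape)                  # works without sinks
print(call_with_optional_sinks(new_attn, x, x, x, torch.ones(3)).shape)   # works with sinks
```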
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )

             if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             )
             o, _ = merge_state_v2(
sglang/srt/layers/attention/flashinfer_backend.py

@@ -66,6 +66,10 @@ class PrefillMetadata:
 # Reuse this workspace buffer across all flashinfer wrappers
 global_workspace_buffer = None

+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+

 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
@@ -118,6 +122,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -205,6 +210,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
@@ -215,6 +221,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -229,6 +236,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_verify,
@@ -252,6 +260,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -327,6 +336,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
@@ -358,6 +368,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
@@ -387,6 +398,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
@@ -414,6 +426,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
@@ -423,6 +436,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -434,6 +448,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -581,7 +596,7 @@ class FlashInferAttnBackend(AttentionBackend):


 class FlashInferIndicesUpdaterDecode:
-    def __init__(self, model_runner: ModelRunner, attn_backend:
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -614,6 +629,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -626,6 +642,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -640,30 +657,39 @@ class FlashInferIndicesUpdaterDecode:
             self.kv_indptr[0],
             None,
             spec_info,
+            seq_lens_cpu,
         )

     def update_sliding_window(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        assert self.sliding_window_size is not None
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
-                paged_kernel_lens_tmp = torch.
-                    seq_lens,
-                    torch.tensor(self.sliding_window_size + 1),
+                paged_kernel_lens_tmp = torch.clamp(
+                    seq_lens, max=self.sliding_window_size + 1
                 )
-
+                if seq_lens_cpu is not None:
+                    seq_lens_cpu_tmp = torch.clamp(
+                        seq_lens_cpu, max=self.sliding_window_size + 1
+                    )
+                    paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
+                else:
+                    paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
                 kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
             else:
                 # Full attention
                 paged_kernel_lens_tmp = seq_lens
                 paged_kernel_lens_sum_tmp = seq_lens_sum
+                seq_lens_cpu_tmp = seq_lens_cpu
                 kv_start_idx_tmp = None

             use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
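For the sliding-window wrapper, the new code clamps the per-request KV lengths to `sliding_window_size + 1` and, when a host copy of the lengths is available, sums that copy instead of the device tensor. A small worked example of the same arithmetic, with made-up lengths and window size:

```python
import torch

sliding_window_size = 4
seq_lens = torch.tensor([3, 6, 10])   # device-side lengths (kept on CPU here for illustration)
seq_lens_cpu = seq_lens.clone()       # host copy, as passed through the hunk above

# Sliding-window attention only looks at the last `sliding_window_size + 1` tokens.
paged_kernel_lens = torch.clamp(seq_lens, max=sliding_window_size + 1)
print(paged_kernel_lens.tolist())     # [3, 5, 5]

# Summing the host copy avoids a device-to-host sync on the CUDA tensor.
paged_kernel_lens_sum = torch.clamp(seq_lens_cpu, max=sliding_window_size + 1).sum().item()
print(paged_kernel_lens_sum)          # 13

# The first KV position each request still attends from, within the window.
kv_start_idx = seq_lens - paged_kernel_lens
print(kv_start_idx.tolist())          # [0, 1, 5]
```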
@@ -678,6 +704,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu_tmp,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )

@@ -685,6 +712,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -709,6 +737,7 @@ class FlashInferIndicesUpdaterDecode:
             self.kv_indptr[wrapper_id],
             kv_start_idx,
             spec_info,
+            seq_lens_cpu=seq_lens_cpu,
         )

     def call_begin_forward(
@@ -720,6 +749,7 @@ class FlashInferIndicesUpdaterDecode:
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
@@ -756,6 +786,14 @@ class FlashInferIndicesUpdaterDecode:
                 )
             )

+        global global_override_indptr_cpu
+        locally_override = False
+        if seq_lens_cpu is not None and global_override_indptr_cpu is None:
+            locally_override = True
+            global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
+            global_override_indptr_cpu[0] = 0
+            global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
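`global_override_indptr_cpu` gives flashinfer's plan step a host-side `kv_indptr` so it does not have to copy the device tensor back to the host. The indptr is just a prefix sum of the per-request KV lengths with a leading zero; a minimal sketch of that construction, with made-up lengths (variable names follow the hunk, nothing here is sglang API):

```python
import torch

seq_lens_cpu = torch.tensor([3, 5, 2])  # per-request KV lengths already on the host
bs = seq_lens_cpu.numel()

# indptr[i] is the offset of request i's KV entries; indptr[bs] is the total count.
indptr_cpu = torch.zeros(bs + 1, dtype=torch.int64)
indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
print(indptr_cpu.tolist())              # [0, 3, 8, 10]

# Request i's KV cache slots then live in kv_indices[indptr_cpu[i] : indptr_cpu[i + 1]].
```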
@@ -769,9 +807,12 @@ class FlashInferIndicesUpdaterDecode:
             non_blocking=True,
         )

+        if locally_override:
+            global_override_indptr_cpu = None
+

 class FlashInferIndicesUpdaterPrefill:
-    def __init__(self, model_runner: ModelRunner, attn_backend:
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -806,6 +847,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -820,6 +862,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -828,6 +871,8 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
+            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
+            # and forward_batch.extend_seq_lens_cpu
             paged_kernel_lens = prefix_lens
             paged_kernel_lens_sum = paged_kernel_lens.sum().item()
         else:
@@ -853,6 +898,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -898,6 +944,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -1020,11 +1067,6 @@ class FlashInferIndicesUpdaterPrefill:
         )


-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global global_override_indptr_cpu
-
-
 class FlashInferMultiStepDraftBackend:
     """
     Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -1056,7 +1098,7 @@ class FlashInferMultiStepDraftBackend:
         self.kv_last_page_len = torch.ones(
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
-        self.attn_backends = []
+        self.attn_backends: List[FlashInferAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashInferAttnBackend(
@@ -1176,7 +1218,7 @@ class FlashInferMultiStepDraftBackend:
             encoder_lens=None,
             forward_mode=ForwardMode.DECODE,
             spec_info=forward_batch.spec_info,
-            seq_lens_cpu=
+            seq_lens_cpu=forward_batch.seq_lens_cpu,
         )

         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -81,6 +81,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
sglang/srt/layers/attention/triton_backend.py

@@ -57,16 +57,36 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
         self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

+        # Parse args
         self.skip_prefill = skip_prefill
-
         max_bs = model_runner.req_to_token_pool.size
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+        self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits

+        # Check arguments
         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"
-        self.sliding_window_size = model_runner.sliding_window_size

+        # Initialize buffers
         # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
         if kv_indptr_buf is None:
             self.kv_indptr = torch.zeros(
@@ -87,9 +107,6 @@ class TritonAttnBackend(AttentionBackend):
             # When provided a buffer, create a clone for the second buffer
             self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

-        self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
-
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
@@ -99,29 +116,9 @@ class TritonAttnBackend(AttentionBackend):
                 (max_bs + 1,), dtype=torch.int64, device=model_runner.device
             )

-
-        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
-
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // get_attention_tp_size()
-        )
-        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
-
-        self.static_kv_splits = get_bool_env_var(
-            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
-        )
-        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
-        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
-
+        # Initialize forward metadata
         self.forward_metadata: ForwardMetadata = None

-        self.max_context_len = model_runner.model_config.context_len
-
-        self.device = model_runner.device
-        self.device_core_count = get_device_core_count(model_runner.gpu_id)
-
     def get_num_kv_splits(
         self,
         num_kv_splits: torch.Tensor,
@@ -333,7 +330,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             attn_logits = None
             attn_lse = None
-            max_extend_len =
+            max_extend_len = max(forward_batch.extend_seq_lens_cpu)
             num_kv_splits = None

             self.forward_metadata = ForwardMetadata(
sglang/srt/layers/attention/trtllm_mha_backend.py

@@ -23,10 +23,12 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo

 # Constants
-DEFAULT_WORKSPACE_SIZE_MB =
+DEFAULT_WORKSPACE_SIZE_MB = (
+    512  # Memory workspace size in MB, todo(Yingyi): read from config
+)

 # Reuse this workspace buffer across all TRTLLM MHA wrappers
-
+global_zero_init_workspace_buffer = None


 @dataclass
@@ -73,14 +75,14 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
         # Allocate buffers
-        global
-        if
-
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
                 self.workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer =
+        self.workspace_buffer = global_zero_init_workspace_buffer

         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
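The TRTLLM MHA backend now shares one zero-initialized workspace across all instances through a module-level singleton, allocated lazily on first use. A hedged sketch of that pattern in isolation (`get_workspace`, its size argument, and the CPU default device are illustrative, not sglang APIs):

```python
import torch

DEFAULT_WORKSPACE_SIZE_MB = 512  # value used in the hunk above

# Reused across every backend instance in the process; allocated on first request.
_global_zero_init_workspace_buffer = None


def get_workspace(num_bytes: int, device: str = "cpu") -> torch.Tensor:
    """Return the shared zero-initialized workspace, allocating it lazily."""
    global _global_zero_init_workspace_buffer
    if _global_zero_init_workspace_buffer is None:
        _global_zero_init_workspace_buffer = torch.zeros(
            num_bytes, dtype=torch.uint8, device=device
        )
    return _global_zero_init_workspace_buffer


# Both "backends" end up holding the same underlying buffer
# (a tiny size is used here so the demo stays cheap).
a = get_workspace(1024)
b = get_workspace(1024)
print(a.data_ptr() == b.data_ptr())  # True
```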