sglang 0.4.1.post2__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/srt/layers/attention/__init__.py +14 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +211 -81
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/logits_processor.py +167 -212
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/managers/detokenizer_manager.py +2 -0
- sglang/srt/managers/io_struct.py +12 -3
- sglang/srt/managers/schedule_batch.py +26 -2
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +71 -27
- sglang/srt/managers/tokenizer_manager.py +29 -20
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/model_executor/cuda_graph_runner.py +118 -73
- sglang/srt/model_executor/forward_batch_info.py +33 -8
- sglang/srt/model_executor/model_runner.py +63 -61
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +97 -26
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +21 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +9 -5
- sglang/srt/server_args.py +109 -51
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +618 -0
- sglang/srt/speculative/eagle_worker.py +170 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +15 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
- sglang-0.4.1.post4.dist-info/RECORD +329 -0
- {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
- sglang-0.4.1.post2.dist-info/RECORD +0 -197
- {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -63,6 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
 
 
@@ -214,6 +215,7 @@ def extend(reqs, model_runner):
         tree_cache=None,
         model_config=model_runner.model_config,
         enable_overlap=False,
+        spec_algorithm=SpeculativeAlgorithm.NONE,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
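
For context: SpeculativeAlgorithm is the enum introduced in sglang/srt/speculative/spec_info.py in this release, and non-speculative call sites such as bench_one_batch.py now pass SpeculativeAlgorithm.NONE explicitly. A minimal sketch of the shape the call site assumes (the exact members and helpers are assumptions; EAGLE is inferred from the eagle_* modules added in this diff):

from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    # Assumed members: NONE for ordinary decoding, EAGLE for the
    # tree-based speculative decoding added by the new eagle_* files.
    NONE = auto()
    EAGLE = auto()

    def is_none(self) -> bool:  # hypothetical convenience helper
        return self == SpeculativeAlgorithm.NONE

# Usage mirroring the call site above:
#   ScheduleBatch.init_new(..., spec_algorithm=SpeculativeAlgorithm.NONE)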
sglang/srt/layers/attention/__init__.py
CHANGED
@@ -1,10 +1,14 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
-
-from sglang.srt.
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+    from sglang.srt.speculative.spec_info import SpecInfo
 
 
 class AttentionBackend(ABC):
@@ -22,9 +26,12 @@ class AttentionBackend(ABC):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor]
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -35,7 +42,9 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor]
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
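
The abstract interface now threads the forward mode and speculative metadata through both CUDA-graph hooks, and capture is parameterized by num_tokens as well as bs. A minimal sketch of what a concrete backend must accept after this change (MyBackend is hypothetical and not part of the package; the body is elided):

from typing import Optional
import torch

from sglang.srt.layers.attention import AttentionBackend

class MyBackend(AttentionBackend):
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,                  # new: may exceed bs under speculative decoding
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",      # new: decode vs. target-verify, etc.
        spec_info: Optional["SpecInfo"],  # new: speculative-decoding metadata
    ):
        ...  # placeholder body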
sglang/srt/layers/attention/double_sparsity_backend.py
CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 import torch
-import torch.nn as nn
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
 
         self.forward_metadata = None
 
-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
             ds_req_to_token,
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO(Andy): Support CUDA graph for double sparse attention
-        raise ValueError(
-            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
-            device="cuda",
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.forward_metadata = (
-            self.cuda_graph_start_loc,
-            self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
-            None,
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self,
         q,
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 import triton
@@ -18,12 +18,13 @@ import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
 
 if is_flashinfer_available():
     from flashinfer import (
@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
         # Two wrappers: one for sliding window attention and one for full attention.
         # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         self.prefill_wrappers_paged = []
+        self.prefill_wrappers_verify = []
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             self.prefill_wrappers_paged.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
+            self.prefill_wrappers_verify.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
+        elif forward_batch.forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_paged,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, False, False
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_verify,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_verify, False, False
+            )
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=None,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix
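
The new branches above route each forward mode to its own wrapper family; target-verify gets dedicated prefill_wrappers_verify so that verification metadata cannot clobber the wrappers used for ordinary prefill in the same scheduling step. A condensed summary (a sketch, not package code; the speculative mode semantics are inferred from the eagle_* modules added in this release):

# Which FlashInfer wrapper family serves each ForwardMode after this change.
WRAPPER_FOR_MODE = {
    "DECODE": "decode_wrappers",                 # one new token per request
    "EXTEND": "prefill_wrappers_paged",          # ordinary prefill with cached prefix
    "DRAFT_EXTEND": "prefill_wrappers_paged",    # draft model extends accepted tokens (assumed)
    "TARGET_VERIFY": "prefill_wrappers_verify",  # target model scores draft tokens (assumed)
}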
@@ -180,37 +216,82 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
+        self.cuda_graph_custom_mask = torch.zeros(
+            (max_bs * self.max_context_len),
+            dtype=torch.uint8,
+            device="cuda",
+        )
+        self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+        self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: torch.Tensor
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        [11 deleted lines not rendered in the source diff]
+        if forward_mode.is_decode():
+            decode_wrappers = []
+            for i in range(self.num_wrappers):
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                            :num_tokens
+                        ],
+                    )
                 )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_decode.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                decode_wrappers=decode_wrappers,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
             )
-        [11 deleted lines not rendered in the source diff]
+            self.decode_cuda_graph_metadata[bs] = decode_wrappers
+            self.forward_metadata = DecodeMetadata(decode_wrappers)
+        elif forward_mode.is_target_verify():
+            prefill_wrappers = []
+            for i in range(self.num_wrappers):
+                prefill_wrappers.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+                        custom_mask_buf=self.cuda_graph_custom_mask,
+                        qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                    )
+                )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=prefill_wrappers,
+                use_ragged=False,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -218,24 +299,41 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: torch.Tensor
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        [7 deleted lines not rendered in the source diff]
+        if forward_mode.is_decode():
+            self.indices_updater_decode.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                decode_wrappers=self.decode_cuda_graph_metadata[bs],
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        elif forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        else:
+            raise ValueError("Invalid forward mode")
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
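
One detail worth noting in the capture path above: graphs are keyed by bs, but the decode buffers are sliced by num_tokens, because a speculative decode step scores several candidate tokens per request. A toy illustration with invented numbers (the per-request draft count is an assumption, not a value from this diff):

bs = 4                           # requests in the captured batch
draft_tokens = 4                 # assumed candidates scored per request
num_tokens = bs * draft_tokens   # 16 query rows in one decode step

# Buffer slices taken in init_forward_metadata_capture_cuda_graph:
#   kv_indptr[i][: num_tokens + 1]   -> 17 entries here
#   kv_last_page_len[:num_tokens]    -> 16 entries here
# Without speculation num_tokens == bs, recovering the old sizing.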
@@ -293,9 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):
 
     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -348,7 +446,6 @@ class FlashInferIndicesUpdaterDecode:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
-
         self.attn_backend = attn_backend
 
         # Buffers and wrappers
@@ -371,7 +468,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -382,7 +480,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -392,6 +491,7 @@ class FlashInferIndicesUpdaterDecode:
             seq_lens_sum,
             self.kv_indptr[0],
             None,
+            spec_info,
         )
 
     def update_sliding_window(
@@ -400,7 +500,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -424,6 +525,7 @@ class FlashInferIndicesUpdaterDecode:
                 paged_kernel_lens_sum_tmp,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
+                spec_info,
             )
 
     def update_cross_attention(
@@ -432,7 +534,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -452,6 +555,7 @@ class FlashInferIndicesUpdaterDecode:
             seq_lens_sum,
             self.kv_indptr[wrapper_id],
             kv_start_idx,
+            spec_info,
         )
 
     def call_begin_forward(
@@ -462,23 +566,30 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
+        spec_info: Optional[SpecInfo],
     ):
-        [16 deleted lines not rendered in the source diff]
+        if spec_info is None:
+            bs = len(req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+        else:
+            bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
+                req_pool_indices,
+                paged_kernel_lens,
+                self.req_to_token,
+            )
 
         wrapper.end_forward()
         wrapper.begin_forward(
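
In the non-speculative branch above, kv_indptr is an exclusive prefix sum of the per-request KV lengths and kv_indices concatenates each request's token slots from the req_to_token pool; the speculative branch delegates the same job to spec_info.generate_attn_arg_decode. A CPU-side toy rendering of the arithmetic (invented lengths, and a plain loop in place of create_flashinfer_kv_indices_triton):

import torch

paged_kernel_lens = torch.tensor([5, 2, 3], dtype=torch.int32)  # invented KV lengths
bs = len(paged_kernel_lens)

# kv_indptr[i] : kv_indptr[i + 1] delimits request i's slice of kv_indices.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
print(kv_indptr.tolist())  # [0, 5, 7, 10]

# Toy req_to_token pool: row i holds the KV-cache slot ids of request i.
req_to_token = torch.arange(3 * 8, dtype=torch.int32).reshape(3, 8)
kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
for i in range(bs):  # what the Triton kernel does, one program id per request
    kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[i, : paged_kernel_lens[i]]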
@@ -507,7 +618,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
-
         self.attn_backend = attn_backend
 
         # Buffers and wrappers
@@ -534,7 +644,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -547,7 +658,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -568,6 +680,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.kv_indptr[0],
             self.qo_indptr[0],
             use_ragged,
+            spec_info,
         )
 
     def update_sliding_window(
@@ -578,7 +691,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -607,6 +721,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.kv_indptr[wrapper_id],
                 self.qo_indptr[wrapper_id],
                 use_ragged,
+                spec_info,
             )
 
     def update_cross_attention(
@@ -617,7 +732,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -643,6 +759,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.kv_indptr[wrapper_id],
             self.qo_indptr[wrapper_id],
             use_ragged,
+            spec_info,
         )
 
     def call_begin_forward(
@@ -658,25 +775,37 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
+        spec_info: Optional[SpecInfo],
     ):
         bs = len(req_pool_indices)
-        [14 deleted lines not rendered in the source diff]
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
 
-        [2 deleted lines not rendered in the source diff]
+            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    self.req_to_token,
+                )
+            )
 
         # extend part
         if use_ragged:
@@ -702,6 +831,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.head_dim,
             1,
             q_data_type=self.q_data_type,
+            custom_mask=custom_mask,
         )
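
Here qo_indptr partitions the flattened query tokens the same way kv_indptr partitions KV slots, but by the newly extended lengths seq_lens - prefix_lens; in the speculative branch, spec_info.generate_attn_arg_prefill returns these together with a custom_mask encoding the draft tree, which is why begin_forward now also receives custom_mask=custom_mask. A toy check of the normal-extend arithmetic (invented lengths):

import torch

seq_lens = torch.tensor([7, 4, 6])     # invented: totals after this extend step
prefix_lens = torch.tensor([5, 0, 2])  # invented: tokens already in the KV cache

qo_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
# Request i's new query rows are qo_indptr[i] : qo_indptr[i + 1].
print(qo_indptr.tolist())  # [0, 2, 6, 10]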