sglang 0.4.1.post2__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -0
- sglang/srt/layers/attention/__init__.py +14 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +211 -81
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/logits_processor.py +167 -212
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/managers/detokenizer_manager.py +2 -0
- sglang/srt/managers/io_struct.py +12 -3
- sglang/srt/managers/schedule_batch.py +26 -2
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +71 -27
- sglang/srt/managers/tokenizer_manager.py +29 -20
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/model_executor/cuda_graph_runner.py +118 -73
- sglang/srt/model_executor/forward_batch_info.py +33 -8
- sglang/srt/model_executor/model_runner.py +63 -61
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +97 -26
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +21 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +9 -5
- sglang/srt/server_args.py +109 -51
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +618 -0
- sglang/srt/speculative/eagle_worker.py +170 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +15 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
- sglang-0.4.1.post4.dist-info/RECORD +329 -0
- {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
- sglang-0.4.1.post2.dist-info/RECORD +0 -197
- {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -23,6 +23,7 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -42,7 +43,6 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
-    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     trust_remote_code: bool = True
     dtype: str = "auto"
@@ -54,6 +54,8 @@
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
+    skip_tokenizer_init: bool = False
+    return_token_ids: bool = False
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
@@ -107,14 +109,6 @@
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Double Sparsity
-    enable_double_sparsity: bool = False
-    ds_channel_config_path: str = None
-    ds_heavy_channel_num: int = 32
-    ds_heavy_token_num: int = 256
-    ds_heavy_channel_type: str = "qk"
-    ds_sparse_decode_threshold: int = 4096
-
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
@@ -124,6 +118,21 @@
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = "outlines"
 
+    # Speculative decoding
+    speculative_draft_model_path: Optional[str] = None
+    speculative_algorithm: Optional[str] = None
+    speculative_num_steps: int = 5
+    speculative_num_draft_tokens: int = 64
+    speculative_eagle_topk: int = 8
+
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
@@ -239,6 +248,17 @@
                 "Overlap scheduler is disabled."
             )
 
+        # Speculative Decoding
+        if self.speculative_algorithm == "EAGLE":
+            self.prefill_only_one_req = True
+            self.disable_cuda_graph_padding = True
+            self.disable_radix_cache = True
+            self.disable_overlap_schedule = True
+            self.chunked_prefill_size = -1
+            logger.info(
+                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
+            )
+
         # GGUF
         if (
             self.load_format == "auto" or self.load_format == "gguf"
@@ -275,11 +295,6 @@
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
-        parser.add_argument(
-            "--skip-tokenizer-init",
-            action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
-        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -387,6 +402,17 @@
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
+        parser.add_argument(
+            "--return-token-ids",
+            action="store_true",
+            default=ServerArgs.return_token_ids,
+            help="Whether to return token IDs in the output, this may introduce additional overhead.",
+        )
 
         # Memory and scheduling
         parser.add_argument(
@@ -595,43 +621,6 @@
             default=ServerArgs.json_model_override_args,
        )
 
-        # Double Sparsity
-        parser.add_argument(
-            "--enable-double-sparsity",
-            action="store_true",
-            help="Enable double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-channel-config-path",
-            type=str,
-            default=ServerArgs.ds_channel_config_path,
-            help="The path of the double sparsity channel config",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-num",
-            type=int,
-            default=ServerArgs.ds_heavy_channel_num,
-            help="The number of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-token-num",
-            type=int,
-            default=ServerArgs.ds_heavy_token_num,
-            help="The number of heavy tokens in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-type",
-            type=str,
-            default=ServerArgs.ds_heavy_channel_type,
-            help="The type of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-sparse-decode-threshold",
-            type=int,
-            default=ServerArgs.ds_sparse_decode_threshold,
-            help="The type of heavy channels in double sparsity attention",
-        )
-
         # LoRA
         parser.add_argument(
             "--lora-paths",
@@ -671,6 +660,75 @@
             help="Choose the backend for grammar-guided decoding.",
         )
 
+        # Speculative decoding
+        parser.add_argument(
+            "--speculative-algorithm",
+            type=str,
+            choices=["EAGLE"],
+            help="Speculative algorithm.",
+        )
+        parser.add_argument(
+            "--speculative-draft-model-path",
+            type=str,
+            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
+        )
+        parser.add_argument(
+            "--speculative-num-steps",
+            type=int,
+            help="The number of steps sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_steps,
+        )
+        parser.add_argument(
+            "--speculative-num-draft-tokens",
+            type=int,
+            help="The number of token sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_draft_tokens,
+        )
+        parser.add_argument(
+            "--speculative-eagle-topk",
+            type=int,
+            help="The number of token sampled from draft model in eagle2 each step.",
+            choices=[1, 2, 4, 8],
+            default=ServerArgs.speculative_eagle_topk,
+        )
+
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
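The flags added above are the user-facing entry point to the new EAGLE speculative decoding path. The following standalone sketch (plain argparse, not sglang itself) mirrors those flag definitions to show their defaults and constraints; the draft-model path is only a placeholder. Note that, per the check added in this diff, setting the algorithm to "EAGLE" also disables the radix cache, chunked prefill, CUDA graph padding, and the overlap scheduler.

# Standalone sketch (plain argparse, not sglang): mirrors the speculative-decoding
# flags added to ServerArgs.add_cli_args in this release. Flag names, choices, and
# defaults are taken from the diff above; "path/to/draft-model" is a placeholder.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--speculative-algorithm", type=str, choices=["EAGLE"])
parser.add_argument("--speculative-draft-model-path", type=str)
parser.add_argument("--speculative-num-steps", type=int, default=5)
parser.add_argument("--speculative-num-draft-tokens", type=int, default=64)
parser.add_argument("--speculative-eagle-topk", type=int, choices=[1, 2, 4, 8], default=8)

args = parser.parse_args(
    [
        "--speculative-algorithm", "EAGLE",
        "--speculative-draft-model-path", "path/to/draft-model",  # placeholder path
    ]
)
print(args.speculative_algorithm)  # EAGLE
print(args.speculative_num_steps)  # 5 (default)
print(args.speculative_eagle_topk)  # 8 (default)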
sglang/srt/speculative/build_eagle_tree.py
ADDED
@@ -0,0 +1,347 @@
+import cutex
+import torch
+
+# parent_table [bs,topk*depth+)]
+# selected_index [bs,draft_token_num-1)]
+# verified_seq_len [bs]
+# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token]
+# positions [bs*draft_token]
+# retrive_index [b, draft_token, depth+2]
+kernels = cutex.SourceModule(
+    """
+//cuda
+__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
+                           Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
+    int bid = blockIdx.x;
+    int tid = threadIdx.x;
+    if (tid >= draft_token_num){
+        return;
+    }
+    int seq_tree_idx = draft_token_num * draft_token_num * bid;
+    for(int i=0; i<bid; i++){
+        seq_tree_idx += verified_seq_len[i] * draft_token_num;
+    }
+    int seq_len = verified_seq_len[bid];
+    int token_tree_idx = seq_tree_idx + (seq_len+draft_token_num)*tid + seq_len + 1;
+    for(int i=0; i<draft_token_num-1; i++){
+        tree_mask[token_tree_idx+i] = false;
+    }
+
+    int position = 0;
+    if (tid==0){
+        positions[bid*draft_token_num] = seq_len;
+        retrive_index[bid][0][0] = bid * draft_token_num;
+        return;
+    }
+
+    int depends_order[10];
+
+    int cur_position = tid-1;
+    while(true){
+        depends_order[position] = cur_position+1;
+        position += 1;
+        tree_mask[token_tree_idx+cur_position] = true;
+        int parent_tb_idx = selected_index[bid][cur_position]/topk;
+        if(parent_tb_idx==0){
+            break;
+        }
+
+        int token_idx = parent_list[bid][parent_tb_idx];
+        for(cur_position=0; cur_position<draft_token_num;cur_position++){
+            if(selected_index[bid][cur_position]==token_idx){
+                break;
+            }
+        }
+    }
+    positions[bid*draft_token_num+tid] = position + seq_len;
+
+    int is_leaf = 0;
+    for(int i=1;i<draft_token_num;i++){
+        if(tree_mask[seq_tree_idx + i * (draft_token_num+seq_len) + seq_len + tid])
+        {
+            is_leaf ++;
+        }
+    }
+    if(is_leaf==1){
+        for(int i=0; i<position; i++){
+            retrive_index[bid][tid][position-i] = depends_order[i] + bid * draft_token_num;
+        }
+        retrive_index[bid][tid][0] = bid*draft_token_num;
+    }
+
+
+
+}
+//!cuda
+""",
+    float_bits=16,  # change to 16 to use half precision as `float` type in the above source code.
+    boundscheck=True,  # turning on for debug and off for performance (to use full threads of a block), default is on.
+)
+
+
+def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
+    bs = seq_lens.numel()
+    device = parent_list.device
+    tree_mask = torch.full(
+        (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
+        True,
+        device=device,
+    )
+    retrive_index = torch.full(
+        (bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
+    )
+    positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)
+
+    kernels.build_tree(
+        parent_list,
+        top_score_index,
+        seq_lens.to(torch.int32),
+        tree_mask,
+        positions,
+        retrive_index,
+        topk,
+        depth,
+        draft_token,
+        grid=(bs, 1, 1),
+        block=(64, 1, 1),
+    )
+    index = retrive_index.sum(dim=-1) != -depth - 2
+    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
+    retrive_cum_len = torch.zeros(
+        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
+    )
+    retrive_cum_len[1:] = cum_len
+    retrive_index = retrive_index[index]
+    return tree_mask, positions, retrive_index, retrive_cum_len
+
+
+if __name__ == "__main__":
+
+    def findp(p_i, index, parent_list):
+        pos = index // 10
+        index_list = index.tolist()
+        parent_list = parent_list.tolist()
+        res = [p_i]
+        while True:
+            p = pos[p_i]
+            if p == 0:
+                break
+            token_idx = parent_list[p]
+            p_i = index_list.index(token_idx)
+            res.append(p_i)
+        return res
+
+    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
+        mask = []
+        positions = []
+        retrive_index = []
+        for i, lens in enumerate(seq_len.tolist()):
+            first_mask = torch.full((lens + draft_token,), True)
+            first_mask[-(draft_token - 1) :] = False
+            positions.append(lens)
+            mask.append(first_mask)
+            seq_order = []
+            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
+            r_index = [first_index]
+            for j in range(draft_token - 1):
+                mask.append(torch.full((lens + 1,), True))
+                idx = findp(j, index, parent_list)
+
+                seq_order.append(idx)
+                positions.append(len(idx) + seq_len)
+                t = torch.full((draft_token - 1,), False)
+                t[idx] = True
+                mask.append(t)
+
+            for i in range(1, draft_token - 1):
+                is_leaf = 0
+                for j in range(draft_token - 1):
+                    if i in seq_order[j]:
+                        is_leaf += 1
+
+                if is_leaf == 1:
+                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
+                    for _ in range(max_depth + 1 - len(seq_order[i])):
+                        order_list.append(-1)
+                    order = torch.Tensor(order_list).cuda().to(torch.long)
+                    r_index.append(order)
+            retrive_index.append(torch.stack(r_index))
+
+        return (
+            torch.cat(mask).cuda(),
+            torch.Tensor(positions).cuda().to(torch.long),
+            torch.stack(retrive_index),
+        )
+
+    index = (
+        torch.Tensor(
+            [
+                0,
+                1,
+                2,
+                3,
+                10,
+                11,
+                12,
+                13,
+                20,
+                21,
+                22,
+                30,
+                110,
+                130,
+                150,
+                160,
+                210,
+                211,
+                212,
+                213,
+                214,
+                215,
+                216,
+                217,
+                218,
+                219,
+                220,
+                230,
+                310,
+                311,
+                312,
+                313,
+                314,
+                315,
+                316,
+                317,
+                320,
+                321,
+                322,
+                330,
+                360,
+                380,
+                390,
+                410,
+                411,
+                412,
+                413,
+                414,
+                415,
+                416,
+                417,
+                418,
+                419,
+                420,
+                421,
+                422,
+                423,
+                430,
+                431,
+                440,
+                441,
+                460,
+                470,
+            ]
+        )
+        .to(torch.long)
+        .cuda()
+    )
+
+    parent_list = (
+        torch.Tensor(
+            [
+                -1,
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                11,
+                12,
+                20,
+                30,
+                21,
+                13,
+                22,
+                40,
+                23,
+                110,
+                130,
+                160,
+                150,
+                190,
+                120,
+                111,
+                121,
+                200,
+                180,
+                210,
+                211,
+                212,
+                213,
+                214,
+                215,
+                216,
+                220,
+                230,
+                217,
+                310,
+                311,
+                312,
+                313,
+                320,
+                314,
+                321,
+                315,
+                316,
+                317,
+            ]
+        )
+        .to(torch.long)
+        .cuda()
+    )
+
+    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
+    bs = verified_seq_len.shape[0]
+    topk = 10
+    depth = 5  # depth <= 10
+    draft_token = 64
+
+    tree_mask = torch.full(
+        (
+            torch.sum(verified_seq_len).item() * draft_token
+            + draft_token * draft_token * bs,
+        ),
+        True,
+    ).cuda()
+    retrive_index = torch.full(
+        (bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
+    )
+    positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)
+
+    kernels.build_tree(
+        parent_list.unsqueeze(0),
+        index.unsqueeze(0),
+        verified_seq_len,
+        tree_mask,
+        positions,
+        retrive_index,
+        topk,
+        depth,
+        draft_token,
+        grid=(bs, 1, 1),
+        block=(64, 1, 1),
+    )
+    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
+
+    c_mask, c_positions, c_retive_index = create_mask(
+        verified_seq_len, draft_token, index, parent_list, depth
+    )
+
+    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
+    assert torch.allclose(positions, c_positions), "positions has error."
+    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."