sglang 0.4.1.post2__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173) hide show
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/srt/layers/attention/__init__.py +14 -5
  3. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  4. sglang/srt/layers/attention/flashinfer_backend.py +211 -81
  5. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  6. sglang/srt/layers/attention/triton_backend.py +20 -11
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  8. sglang/srt/layers/logits_processor.py +167 -212
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  71. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  72. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  73. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  74. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  76. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  77. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  78. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  79. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  80. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  81. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  82. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  83. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  84. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  85. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
  86. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  89. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  94. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
  101. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
  102. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  103. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  104. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  105. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  106. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  107. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  108. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  109. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  112. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  113. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  114. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  115. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  116. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  117. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  118. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  119. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  120. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  121. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  122. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  123. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  124. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  125. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  126. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  127. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  128. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  129. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  130. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  131. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  132. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  133. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  134. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  135. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  136. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  137. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  138. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  139. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  140. sglang/srt/layers/quantization/fp8.py +2 -2
  141. sglang/srt/layers/sampler.py +57 -21
  142. sglang/srt/layers/torchao_utils.py +17 -3
  143. sglang/srt/managers/detokenizer_manager.py +2 -0
  144. sglang/srt/managers/io_struct.py +12 -3
  145. sglang/srt/managers/schedule_batch.py +26 -2
  146. sglang/srt/managers/schedule_policy.py +159 -90
  147. sglang/srt/managers/scheduler.py +71 -27
  148. sglang/srt/managers/tokenizer_manager.py +29 -20
  149. sglang/srt/managers/tp_worker.py +16 -4
  150. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  151. sglang/srt/model_executor/cuda_graph_runner.py +118 -73
  152. sglang/srt/model_executor/forward_batch_info.py +33 -8
  153. sglang/srt/model_executor/model_runner.py +63 -61
  154. sglang/srt/models/deepseek_v2.py +34 -7
  155. sglang/srt/models/grok.py +97 -26
  156. sglang/srt/openai_api/adapter.py +0 -17
  157. sglang/srt/openai_api/protocol.py +3 -3
  158. sglang/srt/sampling/sampling_batch_info.py +21 -0
  159. sglang/srt/sampling/sampling_params.py +9 -1
  160. sglang/srt/server.py +9 -5
  161. sglang/srt/server_args.py +109 -51
  162. sglang/srt/speculative/build_eagle_tree.py +347 -0
  163. sglang/srt/speculative/eagle_utils.py +618 -0
  164. sglang/srt/speculative/eagle_worker.py +170 -0
  165. sglang/srt/speculative/spec_info.py +5 -0
  166. sglang/srt/utils.py +15 -2
  167. sglang/version.py +1 -1
  168. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
  169. sglang-0.4.1.post4.dist-info/RECORD +329 -0
  170. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
  171. sglang-0.4.1.post2.dist-info/RECORD +0 -197
  172. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
  173. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,618 @@
1
from __future__ import annotations

from typing import TYPE_CHECKING, List

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.flashinfer_backend import (
    create_flashinfer_kv_indices_triton,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
from sglang.srt.speculative.spec_info import SpecInfo

if TYPE_CHECKING:
    # Fixed: these previously used a bogus "python.sglang..." package prefix,
    # which type checkers cannot resolve. They must match the installed
    # "sglang" package, consistent with the runtime imports above.
    from sglang.srt.layers.sampler import SampleOutput
    from sglang.srt.managers.schedule_batch import ScheduleBatch
    from sglang.srt.server_args import ServerArgs
22
@triton.jit
def eagle_verify_retrive(
    retrive_index,
    accept_mask,
    retrive_cum_len,
    accept_index,
    accept_length,
    extract_index,
    max_len: tl.constexpr,
    draft_token_num: tl.constexpr,
    max_len_upper: tl.constexpr,
):
    """Select, per request, the draft path with the longest accepted prefix.

    One program handles one request (pid). It scans the accept_mask entries
    belonging to this request, picks the entry with the maximal accepted
    length (largest index on ties), stores that length into `accept_length`,
    copies the winning path's token indices into `accept_index`, and records
    one or two entries into `extract_index` (presumably used later to gather
    hidden states — confirm against the caller).
    """
    pid = tl.program_id(axis=0)

    # Boundaries of this request's slice inside the flattened accept_mask.
    retrive_start = tl.load(retrive_cum_len + pid)
    retrive_end = tl.load(retrive_cum_len + pid + 1)
    retrive_len = retrive_end - retrive_start

    accept_ptr = accept_mask + retrive_start
    accept_offset = tl.arange(0, draft_token_num)
    accept_load_mask = accept_offset < retrive_len
    accept_len_list = tl.load(
        accept_ptr + accept_offset, mask=accept_load_mask, other=-1
    )

    accept_len = tl.max(accept_len_list)
    max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
    # Triton's argmax has no right-side tie-break, so emulate it: among all
    # entries equal to the maximum, take the largest index.
    mask_max = accept_len_list == accept_len
    count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
    count = tl.sum(tl.where(mask_max, 1, count_mask))
    if count > 1:
        index = tl.arange(0, draft_token_num)
        mask_left = index != max_index
        remained_index = tl.where(mask_max and mask_left, index, 0)
        max_index = tl.max(remained_index)

    tl.store(accept_length + pid, accept_len)

    # Copy the accepted path (accept_len + 1 entries) into accept_index.
    retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
    retrive_offset = tl.arange(0, max_len_upper)
    retrive_load_mask = retrive_offset < accept_len + 1
    data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
    tl.store(
        accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
    )

    # Full acceptance (accept_len == max_len - 1) records the last two
    # accepted positions; partial acceptance records only the last one.
    extract_load_ptr = accept_index + pid * max_len + accept_len
    if accept_len == max_len - 1:
        extract_data = tl.load(extract_load_ptr - 1)
        tl.store(extract_index + pid * 2, extract_data)
        extract_data = tl.load(extract_load_ptr)
        tl.store(extract_index + pid * 2 + 1, extract_data)
    else:
        extract_data = tl.load(extract_load_ptr)
        tl.store(extract_index + pid * 2, extract_data)
81
@triton.jit
def create_extend_spec_info(
    verified_id,
    seq_len,
    accept_len,
    accept_len_cum,
    positions,
    new_verified_id,
    accept_len_upper: tl.constexpr,
):
    """Write per-token positions and the last verified token id per request.

    One program per request. For request pid with accepted length A and
    sequence length S, stores positions S-A .. S-1 into `positions` at this
    request's offset in the flattened batch (offsets come from the running
    cumsum `accept_len_cum`), then copies the request's final verified token
    into `new_verified_id`.
    """
    pid = tl.program_id(axis=0)
    # Starting slot of this request inside the flattened position stream.
    start = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
    cur_seq_len = tl.load(seq_len + pid)
    cur_accept_len = tl.load(accept_len + pid)

    lane = tl.arange(0, accept_len_upper)
    lane_mask = lane < cur_accept_len
    tl.store(
        positions + start + lane, cur_seq_len - cur_accept_len + lane, lane_mask
    )

    # The last accepted token of this request becomes its new verified id.
    last = tl.load(accept_len_cum + pid) - 1
    tl.store(new_verified_id + pid, tl.load(verified_id + last))
105
@triton.jit
def assign_req_to_token_pool(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Scatter freshly allocated cache slots into the req-to-token table.

    One program per request. Copies this request's contiguous slice of
    `out_cache_loc` into row req_pool_indices[pid] of `req_to_token`, at
    columns start_offset[pid] .. end_offset[pid]-1, in chunks of BLOCK_SIZE.
    """
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    # This request's starting index in out_cache_loc equals the total number
    # of slots consumed by all earlier requests.
    earlier = tl.arange(0, bs_upper)
    starts = tl.load(start_offset + earlier, mask=earlier < pid)
    ends = tl.load(end_offset + earlier, mask=earlier < pid)
    src_base = out_cache_loc + tl.sum(ends - starts, axis=0)

    dst_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    src_offset = tl.arange(0, BLOCK_SIZE)
    for _ in range(tl.cdiv(kv_end - kv_start, BLOCK_SIZE)):
        chunk_mask = dst_offset < kv_end
        vals = tl.load(src_base + src_offset, mask=chunk_mask)
        tl.store(token_pool + dst_offset, vals, mask=chunk_mask)
        dst_offset += BLOCK_SIZE
        src_offset += BLOCK_SIZE
140
@triton.jit
def generate_draft_decode_kv_indices(
    req_pool_indices,
    req_to_token,
    paged_kernel_lens,
    kv_indices,
    iters: tl.constexpr,
    topk: tl.constexpr,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
):
    """Build flattened KV indices for draft decoding.

    Grid is (batch, topk): one program per (request bid, branch topk_id).
    Each branch's index list is the request's committed KV-cache entries
    (shared by all branches), followed by the entries this branch appended
    during earlier draft iterations (strided by `topk` per iteration inside
    the token pool row).
    """
    BLOCK_SIZE: tl.constexpr = 128
    bid = tl.program_id(axis=0)
    topk_id = tl.program_id(axis=1)

    # Sum of the sequence lengths of all earlier requests locates this
    # request's region inside the flat kv_indices buffer.
    lens_lane = tl.arange(0, bs_upper)
    prev_lens = tl.load(paged_kernel_lens + lens_lane, mask=lens_lane < bid)
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(prev_lens)

    kv_ptr = kv_indices + (
        cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    )
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len

    # Copy the committed prefix of the KV cache in BLOCK_SIZE chunks.
    chunk = tl.arange(0, BLOCK_SIZE)
    for _ in range(tl.cdiv(seq_len, BLOCK_SIZE)):
        chunk_mask = chunk < seq_len
        vals = tl.load(token_pool_ptr + chunk, mask=chunk_mask)
        tl.store(kv_ptr + chunk, vals, mask=chunk_mask)
        chunk += BLOCK_SIZE

    # Append this branch's speculative entries: one slot per iteration,
    # offset by topk_id within each group of topk.
    ext = tl.arange(0, iter_upper)
    ext_mask = ext < iters
    ext_vals = tl.load(
        token_pool_ptr + seq_len + ext * topk + topk_id, mask=ext_mask
    )
    tl.store(kv_ptr + seq_len + ext, ext_vals, mask=ext_mask)
181
+ class EAGLEDraftInput(SpecInfo):
182
+ hidden_states: torch.Tensor = None
183
+ verified_id: torch.Tensor = None
184
+ positions: torch.Tensor = None
185
+ accept_length: torch.Tensor = None
186
+ has_finished: bool = False
187
+ unfinished_index: List[int] = None
188
+
189
+ def init(self, server_args: ServerArgs):
190
+ self.prev_mode = ForwardMode.DECODE
191
+ self.sample_output = None
192
+ self.topk: int = server_args.speculative_eagle_topk
193
+ self.num_verify_token: int = server_args.speculative_num_draft_tokens
194
+ self.spec_steps = server_args.speculative_num_steps
195
+
196
+ self.scores: torch.Tensor = None
197
+ self.score_list: List[torch.Tensor] = []
198
+ self.token_list: List[torch.Tensor] = []
199
+ self.origin_score_list: List[torch.Tensor] = [] # used for sampling
200
+ self.parents_list: List[torch.Tensor] = []
201
+ self.cache_list: List[torch.Tenor] = []
202
+ self.iter = 0
203
+ self.root_token: int = None
204
+
205
+ assert self.topk <= 10, "topk should <= 10"
206
+
207
+ def prepare_for_extend(self, batch: ForwardBatch):
208
+ req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
209
+ out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
210
+ batch.out_cache_loc = out_cache_loc
211
+
212
+ pt = 0
213
+ for i, req in enumerate(batch.reqs):
214
+ req.req_pool_idx = req_pool_indices[i]
215
+ pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
216
+ assert seq_len - pre_len == req.extend_input_len
217
+
218
+ if pre_len > 0:
219
+ batch.req_to_token_pool.req_to_token[req.req_pool_idx][
220
+ :pre_len
221
+ ] = req.prefix_indices
222
+
223
+ batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
224
+ out_cache_loc[pt : pt + req.extend_input_len]
225
+ )
226
+
227
+ pt += req.extend_input_len
228
+
229
+ seq_lens = [0] + batch.extend_lens
230
+ input_ids = batch.input_ids.tolist()
231
+ verified_id = batch.spec_info.verified_id.tolist()
232
+ model_input_ids = []
233
+ for i in range(len(seq_lens) - 1):
234
+ model_input_ids.extend(
235
+ input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
236
+ )
237
+ batch.input_ids = torch.tensor(
238
+ model_input_ids, dtype=torch.int32, device="cuda"
239
+ )
240
+
241
+ def capture_for_decode(
242
+ self,
243
+ sample_output: SampleOutput,
244
+ hidden_states: torch.Tensor,
245
+ prev_mode: ForwardMode,
246
+ ):
247
+ self.sample_output = sample_output
248
+ self.prev_mode = prev_mode
249
+ self.hidden_states = hidden_states
250
+
251
+ def prepare_for_decode(self, batch: ScheduleBatch):
252
+ prob = self.sample_output # b * (1/topk), vocab
253
+ top = torch.topk(prob, self.topk, dim=-1)
254
+ topk_index, topk_p = top.indices, top.values # b * (1/topk), topk
255
+ if self.prev_mode == ForwardMode.DECODE:
256
+ scores = torch.mul(
257
+ self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
258
+ ) # (b, topk) mul (b * topk ,topk) -> b, topk, topk
259
+ topk_cs = torch.topk(
260
+ scores.flatten(start_dim=1), self.topk, dim=-1
261
+ ) # (b, topk)
262
+ topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
263
+ self.scores = topk_cs_p
264
+
265
+ selected_input_index = topk_cs_index.flatten() // self.topk # b* topk
266
+
267
+ batch.spec_info.hidden_states = batch.spec_info.hidden_states[
268
+ selected_input_index, :
269
+ ]
270
+ topk_index = topk_index.reshape(-1, self.topk**2)
271
+ batch.input_ids = torch.gather(
272
+ topk_index, index=topk_cs_index, dim=1
273
+ ).flatten()
274
+ batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
275
+ self.score_list.append(scores) # b, topk, topk
276
+ self.token_list.append(topk_index) # b, topk*topk
277
+ self.origin_score_list.append(topk_p.reshape(topk_index.shape))
278
+ self.parents_list.append(
279
+ topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
280
+ ) # b, topk
281
+
282
+ elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
283
+ self.scores = topk_p # b, top_k
284
+ self.score_list.append(topk_p.unsqueeze(1))
285
+ self.token_list.append(topk_index)
286
+ self.origin_score_list.append(topk_p)
287
+ batch.spec_info.hidden_states = (
288
+ batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
289
+ )
290
+ batch.input_ids = topk_index.flatten()
291
+ batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
292
+ self.parents_list.append(
293
+ torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
294
+ .unsqueeze(0)
295
+ .repeat(self.scores.shape[0], 1)
296
+ ) # b, topk+1
297
+ self.cache_list.append(batch.out_cache_loc)
298
+ self.positions = (
299
+ batch.seq_lens[:, None]
300
+ + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
301
+ ).flatten()
302
+
303
+ bs = batch.seq_lens.numel()
304
+ assign_req_to_token_pool[(bs,)](
305
+ batch.req_pool_indices,
306
+ batch.req_to_token_pool.req_to_token,
307
+ batch.seq_lens + self.topk * self.iter,
308
+ batch.seq_lens + self.topk * (self.iter + 1),
309
+ batch.out_cache_loc,
310
+ batch.req_to_token_pool.req_to_token.shape[1],
311
+ triton.next_power_of_2(bs),
312
+ )
313
+ self.iter += 1
314
+
315
+ def prepare_extend_after_decode(self, batch: ScheduleBatch):
316
+ batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
317
+ batch.extend_lens = (self.accept_length + 1).tolist()
318
+
319
+ pt = 0
320
+ seq_lens = batch.seq_lens.tolist()
321
+
322
+ i = 0
323
+
324
+ for req in batch.reqs:
325
+ if req.finished():
326
+ continue
327
+ # assert seq_len - pre_len == req.extend_input_len
328
+ input_len = self.accept_length[i] + 1
329
+ seq_len = seq_lens[i]
330
+ batch.req_to_token_pool.req_to_token[req.req_pool_idx][
331
+ seq_len - input_len : seq_len
332
+ ] = batch.out_cache_loc[pt : pt + input_len]
333
+ pt += input_len
334
+ i += 1
335
+
336
+ self.positions = torch.empty_like(self.verified_id)
337
+ new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
338
+ self.accept_length.add_(1)
339
+
340
+ create_extend_spec_info[(self.accept_length.numel(),)](
341
+ self.verified_id,
342
+ batch.seq_lens,
343
+ self.accept_length,
344
+ torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
345
+ self.positions,
346
+ new_verified_id,
347
+ triton.next_power_of_2(self.spec_steps + 1),
348
+ )
349
+
350
+ batch.input_ids = self.verified_id
351
+ self.verified_id = new_verified_id
352
+
353
    def prepare_for_verify(self, batch: ScheduleBatch):
        """Select the best draft candidates and build the verification input.

        Concatenates the per-iteration draft scores/tokens, keeps the
        `num_verify_token - 1` highest-scoring candidates (the verified root
        token fills the remaining slot), and builds the tree attention mask
        for the target model's verify pass.

        Returns:
            EagleVerifyInput holding the flattened candidate tokens, their
            scores, the tree mask, positions, and retrieval indices.
        """
        score_list = torch.cat(self.score_list, dim=1).flatten(
            1
        )  # b, n, topk; n= 1+(self.iter-1)*self.topk
        ss_token_list = torch.cat(
            self.token_list, dim=1
        )  # b, (self.topk+(self.iter-1)*self.topk)
        origin_token_list = torch.cat(self.origin_score_list, dim=1)
        top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
        top_scores_index = top_scores.indices
        # Sort the selected indices so the tree layout is position-stable.
        top_scores_index = torch.sort(top_scores_index).values

        draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
        scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
        # Prepend the last verified token as the root of each draft tree.
        draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
        # The final parents entry belongs to tokens that were never scored.
        parent_list = torch.cat(self.parents_list[:-1], dim=1)

        tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
            parent_list,
            top_scores_index,
            batch.seq_lens,
            self.topk,
            self.iter - 1,
            self.num_verify_token,
        )

        return EagleVerifyInput(
            draft_tokens.flatten(),
            scores.flatten(),
            tree_mask,
            position,
            retrive_index,
            retrive_cum_len,
            self.num_verify_token,
        )
388
+
389
    def generate_attn_arg_decode(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        req_to_token: torch.Tensor,
    ):
        """Build flattened KV indices and cumulative lengths for draft decode.

        Each request contributes `topk` query rows (one per draft branch), so
        the effective batch size is `topk * num_requests`.

        Returns:
            Tuple of (bs, kv_indices, cum_kv_seq_len).
        """
        seq_num = req_pool_indices.numel()
        bs = self.topk * req_pool_indices.numel()
        seq_len = self.positions.reshape(-1).contiguous()

        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        # Each branch attends to its prefix plus its own new token, hence +1.
        cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0)
        total_len = torch.sum(paged_kernel_lens).item()

        # Capacity bound: every branch can see the full shared prefix plus the
        # per-branch slots appended during `iter` draft steps.
        kv_indices = torch.empty(
            (total_len * self.topk + seq_num * self.iter * self.topk,),
            dtype=torch.int32,
            device="cuda",
        )

        # Grid: one program per (request, branch) pair.
        generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
            req_pool_indices,
            req_to_token,
            paged_kernel_lens,
            kv_indices,
            self.iter,
            self.topk,
            req_to_token.shape[1],
            triton.next_power_of_2(seq_num),
            triton.next_power_of_2(self.spec_steps),
        )
        return bs, kv_indices, cum_kv_seq_len
421
+
422
+ def clear(self):
423
+ self.iter = 0
424
+ self.score_list.clear()
425
+ self.positions = None
426
+
427
+ def clear_draft_cache(self, batch):
428
+ draft_cache = torch.cat(self.cache_list, dim=0)
429
+ batch.token_to_kv_pool.free(draft_cache)
430
+
431
+ def generate_attn_arg_prefill(
432
+ self,
433
+ req_pool_indices: torch.Tensor,
434
+ paged_kernel_lens: torch.Tensor,
435
+ req_to_token: torch.Tensor,
436
+ ):
437
+ bs = self.accept_length.numel()
438
+ qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
439
+ qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
440
+
441
+ cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
442
+ cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
443
+ kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
444
+
445
+ create_flashinfer_kv_indices_triton[(bs,)](
446
+ req_to_token,
447
+ req_pool_indices,
448
+ paged_kernel_lens,
449
+ cum_kv_seq_len,
450
+ None,
451
+ kv_indices,
452
+ req_to_token.size(1),
453
+ )
454
+
455
+ return kv_indices, cum_kv_seq_len, qo_indptr, None
456
+
457
+ def merge_batch(self, spec_info: EAGLEDraftInput):
458
+
459
+ self.hidden_states = torch.cat(
460
+ [self.hidden_states, spec_info.hidden_states], axis=0
461
+ )
462
+ self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
463
+ # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
464
+ self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
465
+
466
+
467
+ class EagleVerifyInput(SpecInfo):
468
+ def __init__(
469
+ self,
470
+ draft_token: torch.Tensor,
471
+ draft_score: torch.Tensor,
472
+ tree_mask: torch.Tensor,
473
+ positions: torch.Tensor,
474
+ retrive_index: torch.Tensor,
475
+ retrive_cum_len: torch.Tensor,
476
+ draft_token_num: int,
477
+ ):
478
+ self.draft_token = draft_token
479
+ self.draft_score = draft_score
480
+ self.custom_mask = tree_mask
481
+ self.positions = positions
482
+ self.retrive_index = retrive_index
483
+ self.retrive_cum_len = retrive_cum_len
484
+ self.draft_token_num = draft_token_num
485
+
486
    def prepare_for_verify(self, batch: ScheduleBatch):
        """Stage the draft tokens as batch input and allocate their KV slots.

        Every request carries exactly `draft_token_num` candidates, so each
        request's token-table row is filled for the range
        [seq_len, seq_len + draft_token_num).
        """
        batch.input_ids = self.draft_token
        batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
        bs = batch.seq_lens.numel()
        # One program per request writes its slice of out_cache_loc into the
        # request-to-token map.
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.seq_lens + self.draft_token_num,
            batch.out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            triton.next_power_of_2(bs),
        )
499
+
500
+ def generate_attn_arg_prefill(
501
+ self,
502
+ req_pool_indices: torch.Tensor,
503
+ paged_kernel_lens: torch.Tensor,
504
+ req_to_token: torch.Tensor,
505
+ ):
506
+ batch_size = len(req_pool_indices)
507
+ qo_indptr = torch.arange(
508
+ 0,
509
+ (1 + batch_size) * self.draft_token_num,
510
+ step=self.draft_token_num,
511
+ dtype=torch.int32,
512
+ device="cuda",
513
+ )
514
+
515
+ cum_kv_seq_len = torch.zeros(
516
+ (batch_size + 1,), dtype=torch.int32, device="cuda"
517
+ )
518
+
519
+ paged_kernel_lens = paged_kernel_lens + self.draft_token_num
520
+ cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
521
+
522
+ kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
523
+
524
+ create_flashinfer_kv_indices_triton[(batch_size,)](
525
+ req_to_token,
526
+ req_pool_indices,
527
+ paged_kernel_lens,
528
+ cum_kv_seq_len,
529
+ None,
530
+ kv_indices,
531
+ req_to_token.size(1),
532
+ )
533
+ return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
534
+
535
    def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
        """Compare target-model predictions against the draft tree and commit
        the longest accepted path of each request.

        Returns:
            (draft_input, logits_output, verified_id, finished_extend_len):
            the state seeding the next draft round, the logits filtered to the
            accepted positions, the accepted token ids, and a
            {rid: accepted_len} map.
        """
        predict = torch.argmax(logits_output.next_token_logits, dim=-1)
        # Pad both tensors with a -1 sentinel so gathering with
        # retrive_index == -1 yields the sentinel rather than a real token.
        predict = torch.cat(
            [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
        )
        draft_token = torch.cat(
            [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
            dim=-1,
        )
        target_predict = predict[self.retrive_index]
        candidates = draft_token[self.retrive_index]
        # logits = logits_output.next_token_logits[self.retrive_index]
        # target_predict = torch.argmax(logits[:, :-1], dim=-1)
        # A path is accepted while each draft token equals the target's
        # prediction for the preceding position; cumprod zeroes everything
        # after the first mismatch, so the sum is the accepted depth.
        accept_mask = candidates[:, 1:] == target_predict[:, :-1]
        accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
        bs = self.retrive_cum_len.numel() - 1

        max_draft_len = self.retrive_index.shape[-1]
        accept_index = torch.full(
            (bs, max_draft_len), -1, dtype=torch.long, device="cuda"
        )
        accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
        extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
        # One program per request picks the deepest accepted path in its tree.
        eagle_verify_retrive[(bs,)](
            self.retrive_index.contiguous(),
            accept_mask.contiguous(),
            self.retrive_cum_len,
            accept_index,
            accept_length,
            extract_index,
            max_draft_len,
            self.draft_token_num,
            triton.next_power_of_2(max_draft_len),
        )

        # Drop the -1 padding; what remains are flat indices of accepted tokens.
        accept_index = accept_index[accept_index != -1]
        # extract_index = extract_index[extract_index != 0]

        draft_input = EAGLEDraftInput()

        # .tolist() synchronizes with the device before the Python loop below.
        accept_length_cpu = accept_length.tolist()
        verified_id = predict[accept_index]
        verified_id_cpu = verified_id.tolist()

        # Free the KV slots of rejected draft tokens, then commit the accepted
        # ones into the request-to-token map.
        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
        evict_mask[accept_index] = False
        mem_need_free_idx = batch.out_cache_loc[evict_mask]
        batch.token_to_kv_pool.free(mem_need_free_idx)
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.seq_lens + accept_length + 1,
            batch.out_cache_loc[accept_index],
            batch.req_to_token_pool.req_to_token.shape[1],
            triton.next_power_of_2(bs),
        )
        batch.seq_lens.add_(accept_length + 1)
        new_accept_index = []
        unfinished_index = []
        finished_extend_len = {}  # {rid:accept_length + 1}
        # retracted_reqs, new_token_ratio = batch.retract_decode()

        # Walk requests in batch order; `low` tracks each request's slice of
        # the flat accepted-token arrays (verified_len + 1 tokens apiece).
        low = 0
        for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
            req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
            req.check_finished()
            if req.finished():
                draft_input.has_finished = True
            else:
                new_accept_index.append(accept_index[low : low + verified_len + 1])
                unfinished_index.append(i)
            low += verified_len + 1
            finished_extend_len[req.rid] = verified_len + 1

        # Seed the next draft round only with still-running requests.
        if len(new_accept_index) > 0:
            new_accept_index = torch.cat(new_accept_index, dim=0)
            draft_input.verified_id = predict[new_accept_index]
            draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
            draft_input.accept_length = accept_length[unfinished_index]
            draft_input.unfinished_index = unfinished_index

        logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
        return draft_input, logits_output, verified_id, finished_extend_len