sglang 0.4.1.post2__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (173)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/srt/layers/attention/__init__.py +14 -5
  3. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  4. sglang/srt/layers/attention/flashinfer_backend.py +211 -81
  5. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  6. sglang/srt/layers/attention/triton_backend.py +20 -11
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  8. sglang/srt/layers/logits_processor.py +167 -212
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  71. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  72. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  73. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  74. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  76. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  77. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  78. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  79. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  80. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  81. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  82. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  83. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  84. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  85. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
  86. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  89. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  94. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
  101. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
  102. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  103. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  104. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  105. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  106. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  107. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  108. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  109. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  112. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  113. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  114. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  115. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  116. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  117. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  118. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  119. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  120. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  121. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  122. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  123. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  124. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  125. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  126. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  127. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  128. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  129. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  130. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  131. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  132. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  133. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  134. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  135. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  136. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  137. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  138. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  139. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  140. sglang/srt/layers/quantization/fp8.py +2 -2
  141. sglang/srt/layers/sampler.py +57 -21
  142. sglang/srt/layers/torchao_utils.py +17 -3
  143. sglang/srt/managers/detokenizer_manager.py +2 -0
  144. sglang/srt/managers/io_struct.py +12 -3
  145. sglang/srt/managers/schedule_batch.py +26 -2
  146. sglang/srt/managers/schedule_policy.py +159 -90
  147. sglang/srt/managers/scheduler.py +71 -27
  148. sglang/srt/managers/tokenizer_manager.py +29 -20
  149. sglang/srt/managers/tp_worker.py +16 -4
  150. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  151. sglang/srt/model_executor/cuda_graph_runner.py +118 -73
  152. sglang/srt/model_executor/forward_batch_info.py +33 -8
  153. sglang/srt/model_executor/model_runner.py +63 -61
  154. sglang/srt/models/deepseek_v2.py +34 -7
  155. sglang/srt/models/grok.py +97 -26
  156. sglang/srt/openai_api/adapter.py +0 -17
  157. sglang/srt/openai_api/protocol.py +3 -3
  158. sglang/srt/sampling/sampling_batch_info.py +21 -0
  159. sglang/srt/sampling/sampling_params.py +9 -1
  160. sglang/srt/server.py +9 -5
  161. sglang/srt/server_args.py +109 -51
  162. sglang/srt/speculative/build_eagle_tree.py +347 -0
  163. sglang/srt/speculative/eagle_utils.py +618 -0
  164. sglang/srt/speculative/eagle_worker.py +170 -0
  165. sglang/srt/speculative/spec_info.py +5 -0
  166. sglang/srt/utils.py +15 -2
  167. sglang/version.py +1 -1
  168. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
  169. sglang-0.4.1.post4.dist-info/RECORD +329 -0
  170. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
  171. sglang-0.4.1.post2.dist-info/RECORD +0 -197
  172. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
  173. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import random
  from collections import defaultdict
  from contextlib import contextmanager
  from enum import Enum, auto
- from typing import Dict, List, Optional
+ from typing import Dict, List, Optional, Set, Union

  import torch

@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
  )


+ class CacheAwarePolicy(Enum):
+     """Scheduling policies that are aware of the tree cache."""
+
+     LPM = "lpm"  # longest prefix match
+     DFS_WEIGHT = "dfs-weight"  # depth-first search weighting
+
+
+ class CacheAgnosticPolicy(Enum):
+     """Scheduling policies that are not aware of the tree cache."""
+
+     FCFS = "fcfs"  # first come first serve
+     LOF = "lof"  # longest output first
+     RANDOM = "random"
+
+
  class SchedulePolicy:
-     def __init__(self, policy: str, tree_cache: BasePrefixCache):
-         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
-             # LPM and DFS-weight is meaningless when the tree cache is disabled.
-             policy = "fcfs"
+     Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

-         self.policy = policy
+     def __init__(self, policy: str, tree_cache: BasePrefixCache):
+         self.policy = self._validate_and_adjust_policy(policy, tree_cache)
          self.tree_cache = tree_cache

          # It is used to find the matching prefix for in-batch prefix caching.
@@ -64,110 +77,166 @@ class SchedulePolicy:
          req_to_token_pool=None, token_to_kv_pool=None, disable=False
      )

-     def calc_priority(self, waiting_queue: List[Req]):
-         if len(waiting_queue) > 128 and self.policy == "lpm":
-             # Turn off the expensive prefix matching and sorting when the #queue is large.
-             policy = "fcfs"
-         else:
-             policy = self.policy
+     def calc_priority(self, waiting_queue: List[Req]) -> bool:
+         policy = self._determine_active_policy(waiting_queue)

-         # Compute matched prefix length
          prefix_computed = False
-         if policy == "lpm" or policy == "dfs-weight":
-             # rid to deprioritize in the current run for in-batch prefix caching.
-             temporary_deprioritized = set()
-             self.waiting_queue_radix_tree.reset()
-
-             for r in waiting_queue:
-                 prefix_ids = r.adjust_max_prefix_ids()
-
-                 # NOTE: the prefix_indices must always be aligned with last_node
-                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                     rid=r.rid, key=prefix_ids
+         if isinstance(policy, CacheAwarePolicy):
+             prefix_computed = True
+             temporary_deprioritized = self._compute_prefix_matches(
+                 waiting_queue, policy
+             )
+             if policy == CacheAwarePolicy.LPM:
+                 SchedulePolicy._sort_by_longest_prefix(
+                     waiting_queue, temporary_deprioritized
                  )
+             elif policy == CacheAwarePolicy.DFS_WEIGHT:
+                 SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
+             else:
+                 raise ValueError(f"Unknown CacheAware Policy: {policy=}")
+         else:
+             if policy == CacheAgnosticPolicy.FCFS:
+                 pass
+             elif policy == CacheAgnosticPolicy.LOF:
+                 SchedulePolicy._sort_by_longest_output(waiting_queue)
+             elif policy == CacheAgnosticPolicy.RANDOM:
+                 SchedulePolicy._sort_randomly(waiting_queue)
+             else:
+                 raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")

-             # NOTE(sang): This logic is for in-batch prefix caching;
-             # If there are more than 1 request that have small matching prefix from
-             # existing cache, but all those requests share the same prefix, we prefer
-             # to schedule only one of them so that we can increase the cache hit rate.
-             # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
-             # threshold means we cannot use in-batch prefix caching for short prefixes.
-             # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
-                 if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
-                     in_batch_matching_prefixes, _ = (
-                         self.waiting_queue_radix_tree.match_prefix(
-                             rid=r.rid, key=prefix_ids
-                         )
-                     )
-                     if (
-                         len(in_batch_matching_prefixes)
-                         >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
-                     ):
-                         temporary_deprioritized.add(r.rid)
-                     else:
-                         # Insert with a dummy key
-                         self.waiting_queue_radix_tree.insert(
-                             prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
-                         )
+         return prefix_computed

-             prefix_computed = True
+     def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
+         if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
+             # Turn off the expensive prefix matching and sorting when the #queue is large.
+             return CacheAgnosticPolicy.FCFS
+         return self.policy
+
+     def _validate_and_adjust_policy(
+         self, policy: str, tree_cache: BasePrefixCache
+     ) -> Policy:
+         """
+         Validates the policy and adjusts it if necessary based on tree cache settings.
+         """
+         try:
+             policy_enum = CacheAwarePolicy(policy)
+             if tree_cache.disable:
+                 # If tree_cache is disabled, fall back to a CacheAgnosticPolicy
+                 return CacheAgnosticPolicy.FCFS
+             return policy_enum
+         except ValueError:
+             try:
+                 return CacheAgnosticPolicy(policy)
+             except ValueError:
+                 raise ValueError(f"Unknown schedule_policy: {policy=}")
+
+     def _compute_prefix_matches(
+         self, waiting_queue: List[Req], policy: CacheAwarePolicy
+     ) -> Set[int]:
+         """
+         Computes and caches the matching prefixes for requests in the waiting queue,
+         and handles in-batch prefix caching logic.
+         """
+         temporary_deprioritized: Set[int] = set()
+         self.waiting_queue_radix_tree.reset()
+
+         for r in waiting_queue:
+             prefix_ids = r.adjust_max_prefix_ids()
+
+             # NOTE: the prefix_indices must always be aligned with last_node
+             r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                 rid=r.rid, key=prefix_ids
+             )

-         if policy == "lpm":
-             # Longest Prefix Match
-             waiting_queue.sort(
-                 key=lambda r: (
-                     -len(r.prefix_indices)
-                     if r.rid not in temporary_deprioritized
-                     else float("inf")
+             # NOTE(sang): This logic is for in-batch prefix caching;
+             # If there are more than 1 request that have small matching prefix from
+             # existing cache, but all those requests share the same prefix, we prefer
+             # to schedule only one of them so that we can increase the cache hit rate.
+             # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
+             # threshold means we cannot use in-batch prefix caching for short prefixes.
+             # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
+             if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
+                 in_batch_matching_prefixes, _ = (
+                     self.waiting_queue_radix_tree.match_prefix(
+                         rid=r.rid, key=prefix_ids
+                     )
                  )
+                 if (
+                     len(in_batch_matching_prefixes)
+                     >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
+                 ):
+                     temporary_deprioritized.add(r.rid)
+                 else:
+                     # Insert with a dummy key
+                     self.waiting_queue_radix_tree.insert(
+                         prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
+                     )
+         return temporary_deprioritized
+
+     @staticmethod
+     def _sort_by_longest_prefix(
+         waiting_queue: List[Req], temporary_deprioritized: Set[int]
+     ) -> None:
+         """Sorts the waiting queue based on the longest prefix match."""
+         waiting_queue.sort(
+             key=lambda r: (
+                 -len(r.prefix_indices)
+                 if r.rid not in temporary_deprioritized
+                 else float("inf")
              )
-         elif policy == "fcfs":
-             # first come first serve
-             pass
-         elif policy == "lof":
-             # longest output first
-             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-         elif policy == "random":
-             random.shuffle(waiting_queue)
-         elif policy == "dfs-weight":
-             # Experimental policy based on custom weights
-             last_node_to_reqs = defaultdict(list)
-             for req in waiting_queue:
-                 last_node_to_reqs[req.last_node].append(req)
-
-             node_to_weight = defaultdict(int)
-             for node in last_node_to_reqs:
-                 node_to_weight[node] = len(last_node_to_reqs[node])
-             self.calc_weight(self.tree_cache.root_node, node_to_weight)
-
-             waiting_queue.clear()
-             self.get_dfs_priority(
-                 self.tree_cache.root_node,
-                 node_to_weight,
-                 last_node_to_reqs,
-                 waiting_queue,
-             )
-         else:
-             raise ValueError(f"Unknown schedule_policy: {policy=}")
+         )

-         return prefix_computed
+     @staticmethod
+     def _sort_by_dfs_weight(
+         waiting_queue: List[Req], tree_cache: BasePrefixCache
+     ) -> None:
+         """Sorts the waiting queue based on a depth-first search weighting."""
+         last_node_to_reqs = defaultdict(list)
+         for req in waiting_queue:
+             last_node_to_reqs[req.last_node].append(req)
+
+         node_to_weight = defaultdict(int)
+         for node in last_node_to_reqs:
+             node_to_weight[node] = len(last_node_to_reqs[node])
+         SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight)
+
+         waiting_queue.clear()
+         SchedulePolicy._get_dfs_priority(
+             tree_cache.root_node,
+             node_to_weight,
+             last_node_to_reqs,
+             waiting_queue,
+         )
+
+     @staticmethod
+     def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
+         """Sorts the waiting queue based on the longest output (max_new_tokens)."""
+         waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)

-     def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
+     @staticmethod
+     def _sort_randomly(waiting_queue: List[Req]) -> None:
+         """Shuffles the waiting queue randomly."""
+         random.shuffle(waiting_queue)
+
+     @staticmethod
+     def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
          for child in cur_node.children.values():
-             self.calc_weight(child, node_to_weight)
+             SchedulePolicy._calc_weight(child, node_to_weight)
              node_to_weight[cur_node] += node_to_weight[child]

-     def get_dfs_priority(
-         self,
+     @staticmethod
+     def _get_dfs_priority(
          cur_node: TreeNode,
          node_to_priority: Dict[TreeNode, int],
          last_node_to_reqs: Dict[TreeNode, List[Req]],
          q: List,
-     ):
+     ) -> None:
          childs = [child for child in cur_node.children.values()]
          childs.sort(key=lambda x: -node_to_priority[x])
          for child in childs:
-             self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+             SchedulePolicy._get_dfs_priority(
+                 child, node_to_priority, last_node_to_reqs, q
+             )
          q.extend(last_node_to_reqs[cur_node])


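Note on the refactor above: scheduling policies are now parsed into two enum families instead of being compared as raw strings, so invalid names fail fast and the cache-disabled fallback is explicit. A minimal standalone sketch of the resolution behavior (resolve is a hypothetical stand-in for _validate_and_adjust_policy, not the sglang API):

    from enum import Enum
    from typing import Union

    class CacheAwarePolicy(Enum):
        LPM = "lpm"
        DFS_WEIGHT = "dfs-weight"

    class CacheAgnosticPolicy(Enum):
        FCFS = "fcfs"
        LOF = "lof"
        RANDOM = "random"

    def resolve(policy: str, cache_disabled: bool) -> Union[CacheAwarePolicy, CacheAgnosticPolicy]:
        # Cache-aware policies are meaningless without the tree cache, so they
        # silently fall back to FCFS; unknown names raise ValueError.
        try:
            aware = CacheAwarePolicy(policy)
            return CacheAgnosticPolicy.FCFS if cache_disabled else aware
        except ValueError:
            return CacheAgnosticPolicy(policy)

    assert resolve("lpm", cache_disabled=True) is CacheAgnosticPolicy.FCFS
    assert resolve("dfs-weight", cache_disabled=False) is CacheAwarePolicy.DFS_WEIGHT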
@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.utils import (
      broadcast_pyobj,
      configure_logger,
@@ -116,6 +117,14 @@ class Scheduler:
          self.enable_overlap = not server_args.disable_overlap_schedule
          self.skip_tokenizer_init = server_args.skip_tokenizer_init
          self.enable_metrics = server_args.enable_metrics
+         self.spec_algorithm = SpeculativeAlgorithm.from_string(
+             server_args.speculative_algorithm
+         )
+         self.decode_mem_cache_buf_multiplier = (
+             self.server_args.speculative_num_draft_tokens
+             if not self.spec_algorithm.is_none()
+             else 1
+         )

          # Init inter-process communication
          context = zmq.Context(2)
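Note: decode_mem_cache_buf_multiplier, defined above, scales the scheduler's per-step KV-cache budget check (used by check_decode_mem in a later hunk): with speculative decoding enabled, one decode step may commit up to speculative_num_draft_tokens tokens per request instead of one. A toy sketch of that accounting, with hypothetical names (available_kv_slots is not an sglang attribute):

    def check_decode_mem(batch_size: int, available_kv_slots: int, buf_multiplier: int = 1) -> bool:
        # Each decode step needs one KV slot per request, times the number of
        # draft tokens that may be committed when speculative decoding is on.
        return batch_size * buf_multiplier <= available_kv_slots

    assert check_decode_mem(batch_size=32, available_kv_slots=128, buf_multiplier=4)
    assert not check_decode_mem(batch_size=33, available_kv_slots=128, buf_multiplier=4)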
@@ -199,6 +208,21 @@ class Scheduler:
              nccl_port=port_args.nccl_port,
          )

+         # Launch a worker for speculative decoding if needed
+         if self.spec_algorithm.is_eagle():
+             from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+             self.draft_worker = EAGLEWorker(
+                 gpu_id=gpu_id,
+                 tp_rank=tp_rank,
+                 server_args=server_args,
+                 nccl_port=port_args.nccl_port,
+                 target_worker=self.tp_worker,
+                 dp_rank=dp_rank,
+             )
+         else:
+             self.draft_worker = None
+
          # Get token and memory info from the model worker
          (
              self.max_total_num_tokens,
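Note: EAGLEWorker itself is new in this release (sglang/srt/speculative/eagle_worker.py in the file list). From the scheduler's point of view it wraps the target tp_worker and, per a later hunk, returns (logits_output, next_token_ids, model_worker_batch, spec_info). For orientation only, here is a generic greedy draft-then-verify loop of the kind speculative decoding performs; this is a schematic with hypothetical callables, not sglang's implementation:

    from typing import Callable, List

    def speculative_step(
        draft_next: Callable[[List[int]], int],   # cheap draft model (hypothetical)
        target_next: Callable[[List[int]], int],  # expensive target model (hypothetical)
        ctx: List[int],
        num_draft_tokens: int,
    ) -> List[int]:
        # 1) The draft model proposes num_draft_tokens tokens autoregressively.
        proposal = []
        for _ in range(num_draft_tokens):
            proposal.append(draft_next(ctx + proposal))
        # 2) The target model verifies the proposals; the greedy variant accepts
        #    the longest prefix it agrees with and replaces the first mismatch.
        accepted = []
        for tok in proposal:
            expect = target_next(ctx + accepted)
            if expect != tok:
                accepted.append(expect)
                break
            accepted.append(tok)
        return accepted

    # Toy demo: both models deterministically continue the sequence 0, 1, 2, ...
    target = lambda seq: len(seq)
    draft = lambda seq: len(seq)
    print(speculative_step(draft, target, ctx=[0, 1], num_draft_tokens=3))  # [2, 3, 4]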
@@ -855,6 +879,7 @@ class Scheduler:
              self.tree_cache,
              self.model_config,
              self.enable_overlap,
+             self.spec_algorithm,
          )
          new_batch.prepare_for_extend()

@@ -888,11 +913,15 @@ class Scheduler:
              return None

          # Check if decode out of memory
-         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
+         if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
+             test_retract and batch.batch_size() > 10
+         ):
              old_ratio = self.new_token_ratio

              retracted_reqs, new_token_ratio = batch.retract_decode()
              self.new_token_ratio = new_token_ratio
+             if self.draft_worker:
+                 self.draft_worker.finish_request(retracted_reqs)

              logger.info(
                  "Decode out of memory happened. "
@@ -926,11 +955,17 @@ class Scheduler:
          self.forward_ct += 1

          if self.is_generation:
-             model_worker_batch = batch.get_model_worker_batch()
              if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                     model_worker_batch
-                 )
+                 if self.spec_algorithm.is_none():
+                     model_worker_batch = batch.get_model_worker_batch()
+                     logits_output, next_token_ids = (
+                         self.tp_worker.forward_batch_generation(model_worker_batch)
+                     )
+                 else:
+                     logits_output, next_token_ids, model_worker_batch, spec_info = (
+                         self.draft_worker.forward_batch_speculative_generation(batch)
+                     )
+                     batch.spec_info = spec_info
              elif batch.forward_mode.is_idle():
                  model_worker_batch = batch.get_model_worker_batch()
                  self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -974,12 +1009,10 @@ class Scheduler:
              logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
          else:
              # Move next_token_ids and logprobs to cpu
+             next_token_ids = next_token_ids.tolist()
              if batch.return_logprob:
                  logits_output.next_token_logprobs = (
-                     logits_output.next_token_logprobs[
-                         torch.arange(len(next_token_ids), device=self.device),
-                         next_token_ids,
-                     ].tolist()
+                     logits_output.next_token_logprobs.tolist()
                  )
                  logits_output.input_token_logprobs = (
                      logits_output.input_token_logprobs.tolist()
@@ -987,7 +1020,6 @@ class Scheduler:
                  logits_output.normalized_prompt_logprobs = (
                      logits_output.normalized_prompt_logprobs.tolist()
                  )
-             next_token_ids = next_token_ids.tolist()

          # Check finish conditions
          logprob_pt = 0
@@ -1064,13 +1096,9 @@ class Scheduler:
              logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
              next_token_logprobs = logits_output.next_token_logprobs
          else:
-             # Move next_token_ids and logprobs to cpu
-             if batch.return_logprob:
-                 next_token_logprobs = logits_output.next_token_logprobs[
-                     torch.arange(len(next_token_ids), device=self.device),
-                     next_token_ids,
-                 ].tolist()
              next_token_ids = next_token_ids.tolist()
+             if batch.return_logprob:
+                 next_token_logprobs = logits_output.next_token_logprobs.tolist()

          self.token_to_kv_pool.free_group_begin()

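Note: the two hunks above apply the same simplification to the prefill and decode paths. Previously next_token_logprobs arrived as a [batch, vocab] tensor and the scheduler gathered one value per sampled token with torch.arange indexing; it now arrives already gathered (the sampler-side change is presumably in sglang/srt/layers/sampler.py, which the file list shows also changed), so a plain .tolist() suffices. For reference, the removed gather is equivalent to:

    import torch

    batch_size, vocab = 4, 16
    logprobs = torch.randn(batch_size, vocab).log_softmax(dim=-1)
    next_token_ids = torch.tensor([3, 7, 0, 9])

    # Old scheduler-side gather: pick each row's logprob at the sampled token id.
    gathered = logprobs[torch.arange(batch_size), next_token_ids]
    assert gathered.shape == (batch_size,)
    # The new code assumes `gathered` is what the worker already returns,
    # so the scheduler only converts it with gathered.tolist().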
@@ -1084,7 +1112,10 @@ class Scheduler:
                  self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                  continue

-             req.output_ids.append(next_token_id)
+             if batch.spec_algorithm.is_none():
+                 # The speculative worker handles output_ids during speculative decoding
+                 req.output_ids.append(next_token_id)
+
              req.check_finished()

              if req.finished():
@@ -1095,10 +1126,10 @@ class Scheduler:
                  req.output_token_logprobs_idx.append(next_token_id)
                  if req.top_logprobs_num > 0:
                      req.output_top_logprobs_val.append(
-                         logits_output.output_top_logprobs_val[i]
+                         logits_output.next_token_top_logprobs_val[i]
                      )
                      req.output_top_logprobs_idx.append(
-                         logits_output.output_top_logprobs_idx[i]
+                         logits_output.next_token_top_logprobs_idx[i]
                      )

              if req.grammar is not None:
@@ -1200,8 +1231,9 @@ class Scheduler:
              req.output_top_logprobs_idx.extend(
                  output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
              )
-             req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
-             req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
+
+             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
+             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

          return num_input_logprobs

@@ -1218,6 +1250,7 @@ class Scheduler:
          decode_ids_list = []
          read_offsets = []
          output_ids = []
+         origin_input_ids = []

          skip_special_tokens = []
          spaces_between_special_tokens = []
@@ -1257,6 +1290,9 @@ class Scheduler:
                  # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                  or (not req.stream and len(req.output_ids) % 50 == 0)
              ):
+                 if self.draft_worker and req.finished():
+                     self.draft_worker.finish_request(req)
+
                  rids.append(req.rid)
                  finished_reasons.append(
                      req.finished_reason.to_json() if req.finished_reason else None
@@ -1266,8 +1302,14 @@ class Scheduler:
                  decode_ids, read_offset = req.init_incremental_detokenize()
                  decode_ids_list.append(decode_ids)
                  read_offsets.append(read_offset)
-                 if self.skip_tokenizer_init:
+                 if self.skip_tokenizer_init or self.server_args.return_token_ids:
                      output_ids.append(req.output_ids)
+                 else:
+                     output_ids = None
+                 if self.server_args.return_token_ids:
+                     origin_input_ids.append(req.origin_input_ids)
+                 else:
+                     origin_input_ids = None
                  skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                  spaces_between_special_tokens.append(
                      req.sampling_params.spaces_between_special_tokens
@@ -1299,6 +1341,7 @@ class Scheduler:
                      decoded_texts,
                      decode_ids_list,
                      read_offsets,
+                     origin_input_ids,
                      output_ids,
                      skip_special_tokens,
                      spaces_between_special_tokens,
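Note: the hunks above plumb origin_input_ids (and, conditionally, output_ids) through BatchTokenIDOut for the new return_token_ids server argument (sglang/srt/server_args.py also changed in this release). Combined with the out_dict.update hunk in the tokenizer manager further below, a /generate-style response would then carry token ids alongside the text; a sketch with illustrative values only:

    # Hypothetical response body; real field values depend on the request.
    out_dict = {
        "text": " Paris.",
        "meta_info": {"prompt_tokens": 5, "completion_tokens": 2},
        # Added only when the server was launched with return_token_ids:
        "input_ids": [791, 6864, 315, 9822, 374],  # recv_obj.origin_input_ids[i]
        "output_ids": [12366, 13],                 # recv_obj.output_ids[i]
    }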
@@ -1321,11 +1364,11 @@ class Scheduler:
          embeddings = []
          prompt_tokens = []
          for req in reqs:
-             assert req.finished()
-             rids.append(req.rid)
-             finished_reasons.append(req.finished_reason.to_json())
-             embeddings.append(req.embedding)
-             prompt_tokens.append(len(req.origin_input_ids))
+             if req.finished():
+                 rids.append(req.rid)
+                 finished_reasons.append(req.finished_reason.to_json())
+                 embeddings.append(req.embedding)
+                 prompt_tokens.append(len(req.origin_input_ids))
          self.send_to_detokenizer.send_pyobj(
              BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
          )
@@ -1381,6 +1424,7 @@ class Scheduler:
              self.tree_cache,
              self.model_config,
              self.enable_overlap,
+             self.spec_algorithm,
          )
          idle_batch.prepare_for_idle()
          return idle_batch
@@ -222,10 +222,8 @@ class TokenizerManager:
          is_single = obj.is_single
          if is_single:
              tokenized_obj = await self._tokenize_one_request(obj)
-             self.send_to_scheduler.send_pyobj(tokenized_obj)
-             async for response in self._wait_one_response(
-                 obj, request, created_time
-             ):
+             self._send_one_request(obj, tokenized_obj, created_time)
+             async for response in self._wait_one_response(obj, request):
                  yield response
          else:
              async for response in self._handle_batch_request(
@@ -306,16 +304,24 @@ class TokenizerManager:

          return tokenized_obj

-     async def _wait_one_response(
+     def _send_one_request(
          self,
          obj: Union[GenerateReqInput, EmbeddingReqInput],
-         request: Optional[fastapi.Request] = None,
+         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
          created_time: Optional[float] = None,
      ):
-         """Wait for the response of one request."""
          event = asyncio.Event()
          state = ReqState([], False, event, obj, created_time=created_time)
          self.rid_to_state[obj.rid] = state
+         self.send_to_scheduler.send_pyobj(tokenized_obj)
+
+     async def _wait_one_response(
+         self,
+         obj: Union[GenerateReqInput, EmbeddingReqInput],
+         request: Optional[fastapi.Request] = None,
+     ):
+         """Wait for the response of one request."""
+         state = self.rid_to_state[obj.rid]

          while True:
              try:
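Note: _wait_one_response previously registered the request state itself, which tied dispatch to consumption of the response generator. Splitting out _send_one_request registers the ReqState and dispatches immediately, so the batch paths below can send every request up front and only then iterate the generators. A minimal sketch of the pattern under that reading (the Manager class and its methods are hypothetical, not the sglang API):

    import asyncio

    class Manager:
        def __init__(self):
            self.rid_to_state = {}

        def send_one(self, rid: str, payload: str) -> None:
            # Register state first, then dispatch; nothing awaits yet.
            self.rid_to_state[rid] = {"event": asyncio.Event(), "out": None}
            # ... send `payload` to the scheduler process here ...

        async def wait_one(self, rid: str) -> str:
            state = self.rid_to_state[rid]   # must have been sent already
            await state["event"].wait()
            return state["out"]

    async def demo():
        m = Manager()
        m.send_one("r1", "hello")            # dispatch immediately
        waiter = asyncio.create_task(m.wait_one("r1"))
        m.rid_to_state["r1"]["out"] = "done"
        m.rid_to_state["r1"]["event"].set()  # simulate the scheduler's reply
        return await waiter

    assert asyncio.run(demo()) == "done"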
@@ -361,10 +367,8 @@ class TokenizerManager:
              for i in range(batch_size):
                  tmp_obj = obj[i]
                  tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                 self.send_to_scheduler.send_pyobj(tokenized_obj)
-                 generators.append(
-                     self._wait_one_response(tmp_obj, request, created_time)
-                 )
+                 self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                 generators.append(self._wait_one_response(tmp_obj, request))
                  rids.append(tmp_obj.rid)
          else:
              # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
@@ -389,10 +393,8 @@ class TokenizerManager:
                  tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                  tokenized_obj.sampling_params.max_new_tokens = 0
                  tokenized_obj.stream = False
-                 self.send_to_scheduler.send_pyobj(tokenized_obj)
-                 await self._wait_one_response(
-                     tmp_obj, request, created_time
-                 ).__anext__()
+                 self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                 await self._wait_one_response(tmp_obj, request).__anext__()

              # Expand requests, assign new rids for them, and send them
              for i in range(batch_size):
@@ -400,10 +402,8 @@ class TokenizerManager:
                  tmp_obj = copy.copy(objs[i])
                  tokenized_obj = copy.copy(tokenized_objs[i])
                  tokenized_obj.rid = tmp_obj.regenerate_rid()
-                 self.send_to_scheduler.send_pyobj(tokenized_obj)
-                 generators.append(
-                     self._wait_one_response(tmp_obj, request, created_time)
-                 )
+                 self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                 generators.append(self._wait_one_response(tmp_obj, request))
                  rids.append(tmp_obj.rid)

          # Wait for all requests
@@ -663,6 +663,13 @@ class TokenizerManager:
                  "text": recv_obj.output_strs[i],
                  "meta_info": meta_info,
              }
+             if self.server_args.return_token_ids:
+                 out_dict.update(
+                     {
+                         "input_ids": recv_obj.origin_input_ids[i],
+                         "output_ids": recv_obj.output_ids[i],
+                     }
+                 )
          elif isinstance(recv_obj, BatchTokenIDOut):
              out_dict = {
                  "token_ids": recv_obj.output_ids[i],
@@ -692,6 +699,7 @@ class TokenizerManager:
              )
          else:
              if completion_tokens >= 2:
+                 # Compute time_per_output_token for the streaming case
                  self.metrics_collector.observe_time_per_output_token(
                      (time.time() - state.first_token_time)
                      / (completion_tokens - 1)
@@ -707,7 +715,8 @@ class TokenizerManager:
              self.metrics_collector.observe_e2e_request_latency(
                  time.time() - state.created_time
              )
-             if completion_tokens >= 1:
+             # Compute time_per_output_token for the non-streaming case
+             if not state.obj.stream and completion_tokens >= 1:
                  self.metrics_collector.observe_time_per_output_token(
                      (time.time() - state.created_time)
                      / completion_tokens
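Note: the last two hunks separate the time-per-output-token (TPOT) metric by mode: streaming requests measure decode intervals from the first token, while non-streaming requests can only divide total latency by the token count; the added not state.obj.stream guard keeps streaming requests from being observed twice. The two estimates as plain formulas, mirroring the code above:

    def tpot_streaming(first_token_time: float, now: float, completion_tokens: int) -> float:
        # (completion_tokens - 1) decode intervals elapse after the first token
        return (now - first_token_time) / (completion_tokens - 1)

    def tpot_non_streaming(created_time: float, now: float, completion_tokens: int) -> float:
        # Only total latency is known; approximate one interval per token.
        return (now - created_time) / completion_tokens

    assert tpot_streaming(1.0, 2.0, 11) == 0.1
    assert tpot_non_streaming(0.0, 2.0, 10) == 0.2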