sglang 0.4.1.post1__py3-none-any.whl → 0.4.1.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. sglang/bench_offline_throughput.py +1 -0
  2. sglang/srt/configs/model_config.py +11 -2
  3. sglang/srt/layers/attention/__init__.py +0 -1
  4. sglang/srt/layers/attention/flashinfer_backend.py +54 -41
  5. sglang/srt/layers/logits_processor.py +30 -2
  6. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  7. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  8. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  71. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  72. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  73. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  74. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
  75. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  76. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -26
  77. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  78. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  84. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  85. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  89. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  91. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  93. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  95. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  96. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  97. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  98. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  100. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  103. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  104. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  105. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  106. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  107. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  108. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  109. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  112. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  113. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  114. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  115. sglang/srt/layers/quantization/fp8.py +42 -2
  116. sglang/srt/layers/quantization/fp8_kernel.py +77 -18
  117. sglang/srt/layers/quantization/fp8_utils.py +8 -2
  118. sglang/srt/managers/detokenizer_manager.py +2 -0
  119. sglang/srt/managers/io_struct.py +40 -9
  120. sglang/srt/managers/schedule_batch.py +22 -15
  121. sglang/srt/managers/scheduler.py +69 -21
  122. sglang/srt/managers/session_controller.py +102 -27
  123. sglang/srt/managers/tokenizer_manager.py +48 -10
  124. sglang/srt/managers/tp_worker.py +7 -0
  125. sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
  126. sglang/srt/model_executor/forward_batch_info.py +42 -3
  127. sglang/srt/model_executor/model_runner.py +4 -0
  128. sglang/srt/models/llama.py +11 -0
  129. sglang/srt/models/llama_eagle.py +132 -0
  130. sglang/srt/openai_api/adapter.py +60 -2
  131. sglang/srt/openai_api/protocol.py +48 -0
  132. sglang/srt/server.py +26 -3
  133. sglang/srt/server_args.py +24 -30
  134. sglang/srt/speculative/spec_info.py +19 -0
  135. sglang/srt/utils.py +62 -0
  136. sglang/version.py +1 -1
  137. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/METADATA +3 -3
  138. sglang-0.4.1.post3.dist-info/RECORD +305 -0
  139. sglang-0.4.1.post1.dist-info/RECORD +0 -195
  140. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/LICENSE +0 -0
  141. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py
@@ -331,6 +331,7 @@ def throughput_test(
         extra_request_body=extra_request_body,
         profile=bench_args.profile,
     )
+    backend.shutdown()
 
     if bench_args.result_filename:
         with open(bench_args.result_filename, "a") as fout:
sglang/srt/configs/model_config.py
@@ -15,7 +15,7 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional, Union
+from typing import List, Optional, Set, Union
 
 import torch
 from transformers import PretrainedConfig
@@ -47,6 +47,7 @@ class ModelConfig:
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
+
         # Parse args
         self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
@@ -130,7 +131,8 @@ class ModelConfig:
         # Veirfy quantization
         self._verify_quantization()
 
-        # Multimodel attrs
+        # Cache attributes
+        self.hf_eos_token_id = self.get_hf_eos_token_id()
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)
 
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
@@ -271,6 +273,13 @@ class ModelConfig:
             self.quantization,
         )
 
+    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
+        eos_ids = getattr(self.hf_config, "eos_token_id", None)
+        if eos_ids:
+            # it can be either int or list of int
+            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+        return eos_ids
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
sglang/srt/layers/attention/__init__.py
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
 from typing import Optional
 
 import torch
-from torch import nn
 
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
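The `ModelConfig.get_hf_eos_token_id` helper added in the model_config.py hunks above normalizes whatever the HuggingFace config stores under `eos_token_id` (a single int, a list of ints, or nothing) into an optional set, and caches it on the config object, presumably so downstream stop-condition checks can use simple set membership. A minimal standalone sketch of that normalization; the `SimpleNamespace` stand-in for `hf_config` is purely illustrative:

```python
from types import SimpleNamespace
from typing import Optional, Set


def normalize_eos_token_id(hf_config) -> Optional[Set[int]]:
    # Mirrors ModelConfig.get_hf_eos_token_id: eos_token_id may be an int,
    # a list of ints, or absent; absent/None stays None.
    eos_ids = getattr(hf_config, "eos_token_id", None)
    if eos_ids:
        eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
    return eos_ids


print(normalize_eos_token_id(SimpleNamespace(eos_token_id=2)))           # {2}
print(normalize_eos_token_id(SimpleNamespace(eos_token_id=[2, 32000])))  # {2, 32000}
print(normalize_eos_token_id(SimpleNamespace()))                         # None
```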
sglang/srt/layers/attention/flashinfer_backend.py
@@ -8,8 +8,9 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 """
 
 import os
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Union
 
 import torch
 import triton
@@ -38,12 +39,25 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()
 
 
+@dataclass
+class DecodeMetadata:
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+
+
+@dataclass
+class PrefillMetadata:
+    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+    use_ragged: bool
+    extend_no_prefix: bool
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
 
+        # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
@@ -52,7 +66,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 model_runner.tp_size
             ),
         )
-
         self.max_context_len = model_runner.model_config.context_len
 
         assert not (
@@ -120,8 +133,8 @@ class FlashInferAttnBackend(AttentionBackend):
         )
 
         # Other metadata
-        self.forward_metadata = None
-        self.cuda_graph_metadata = {}
+        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+        self.decode_cuda_graph_metadata = {}
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
@@ -129,10 +142,10 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                decode_wrappers=None,
+                decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
             )
-            self.forward_metadata = (self.decode_wrappers,)
+            self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
@@ -149,11 +162,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
+                prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
             )
-
-            self.forward_metadata = (use_ragged, extend_no_prefix)
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+            )
 
     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(
@@ -194,8 +209,8 @@ class FlashInferAttnBackend(AttentionBackend):
             decode_wrappers=decode_wrappers,
             encoder_lens=encoder_lens,
         )
-        self.cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = (decode_wrappers,)
+        self.decode_cuda_graph_metadata[bs] = decode_wrappers
+        self.forward_metadata = DecodeMetadata(decode_wrappers)
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -209,7 +224,7 @@ class FlashInferAttnBackend(AttentionBackend):
             req_pool_indices[:bs],
             seq_lens[:bs],
             seq_lens_sum,
-            decode_wrappers=self.cuda_graph_metadata[bs],
+            decode_wrappers=self.decode_cuda_graph_metadata[bs],
             encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )
 
@@ -225,18 +240,16 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.prefill_wrappers_paged[
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
-
-        use_ragged, extend_no_prefix = self.forward_metadata
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
 
-        if not use_ragged:
+        if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
@@ -260,7 +273,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 logits_soft_cap=layer.logit_cap,
             )
 
-            if extend_no_prefix:
+            if self.forward_metadata.extend_no_prefix:
                 o = o1
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -287,7 +300,9 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+        decode_wrapper = self.forward_metadata.decode_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
@@ -322,7 +337,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
 class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -340,9 +355,8 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.decode_wrappers = attn_backend.decode_wrappers
 
-        # Dispatch
+        # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
         elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -356,7 +370,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -367,7 +381,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
@@ -385,11 +399,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
@@ -419,11 +431,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Normal attention
@@ -446,7 +456,7 @@ class FlashInferIndicesUpdaterDecode:
 
     def call_begin_forward(
         self,
-        wrapper,
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -486,7 +496,7 @@
 
 class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -505,10 +515,9 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
-        self.wrappers_paged = attn_backend.prefill_wrappers_paged
+        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
 
-        # Dispatch
+        # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
         elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -523,6 +532,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -535,6 +545,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -546,8 +557,8 @@ class FlashInferIndicesUpdaterPrefill:
             paged_kernel_lens_sum = seq_lens_sum
 
         self.call_begin_forward(
-            self.wrapper_ragged,
-            self.wrappers_paged[0],
+            self.prefill_wrapper_ragged,
+            prefill_wrappers[0],
             req_pool_indices,
             paged_kernel_lens,
             paged_kernel_lens_sum,
@@ -565,6 +576,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -584,8 +596,8 @@ class FlashInferIndicesUpdaterPrefill:
             kv_start_idx = seq_lens - paged_kernel_lens
 
             self.call_begin_forward(
-                self.wrapper_ragged,
-                self.wrappers_paged[wrapper_id],
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,
@@ -603,6 +615,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -619,8 +632,8 @@ class FlashInferIndicesUpdaterPrefill:
             paged_kernel_lens_sum = paged_kernel_lens.sum().item()
 
             self.call_begin_forward(
-                self.wrapper_ragged,
-                self.wrappers_paged[wrapper_id],
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,
@@ -634,8 +647,8 @@
 
     def call_begin_forward(
         self,
-        wrapper_ragged,
-        wrapper_paged,
+        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
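The recurring theme in the flashinfer_backend.py hunks above is replacing the ad-hoc tuples previously stored in `forward_metadata` with the new `DecodeMetadata` / `PrefillMetadata` dataclasses, and passing the paged prefill wrappers into the updater methods explicitly instead of reading them off the updater. A minimal sketch of the before/after access pattern, with a placeholder class standing in for flashinfer's `BatchDecodeWithPagedKVCacheWrapper` / `BatchPrefillWithPagedKVCacheWrapper`:

```python
from dataclasses import dataclass
from typing import List, Union


class PlaceholderWrapper:
    """Illustrative stand-in for the flashinfer wrapper objects."""


@dataclass
class DecodeMetadata:
    decode_wrappers: List[PlaceholderWrapper]


@dataclass
class PrefillMetadata:
    prefill_wrappers: List[PlaceholderWrapper]
    use_ragged: bool
    extend_no_prefix: bool


# Before: forward_metadata was a bare tuple and callers had to know its layout,
#   use_ragged, extend_no_prefix = self.forward_metadata
# After: one attribute holds either variant and fields are read by name.
forward_metadata: Union[DecodeMetadata, PrefillMetadata]

forward_metadata = PrefillMetadata(
    prefill_wrappers=[PlaceholderWrapper()],
    use_ragged=False,
    extend_no_prefix=True,
)
if not forward_metadata.use_ragged:
    wrapper = forward_metadata.prefill_wrappers[0]  # indexed by wrapper id in the backend
```

The rename from `cuda_graph_metadata` to `decode_cuda_graph_metadata` follows the same idea: in these hunks only decode batches are captured for CUDA graphs, so the name now says so.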
sglang/srt/layers/logits_processor.py
@@ -24,7 +24,11 @@ from vllm.distributed import (
 )
 
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 
 
 @dataclasses.dataclass
@@ -46,6 +50,10 @@ class LogitsProcessorOutput:
     output_top_logprobs_val: List = None
     output_top_logprobs_idx: List = None
 
+    # Used by speculative decoding (EAGLE)
+    # The output of transformer layers
+    hidden_states: Optional[torch.Tensor] = None
+
 
 @dataclasses.dataclass
 class LogitsMetadata:
@@ -61,6 +69,8 @@ class LogitsMetadata:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
 
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         extend_logprob_pruned_lens_cpu = None
@@ -78,6 +88,11 @@ class LogitsMetadata:
         else:
             return_top_logprob = False
 
+        if forward_batch.spec_info:
+            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+        else:
+            capture_hidden_mode = CaptureHiddenMode.NULL
+
         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
@@ -87,6 +102,7 @@ class LogitsMetadata:
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
+            capture_hidden_mode=capture_hidden_mode,
         )
 
 
@@ -116,7 +132,10 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)
 
         # Get the last hidden states and last logits for the next token prediction
-        if logits_metadata.forward_mode.is_decode():
+        if (
+            logits_metadata.forward_mode.is_decode()
+            or logits_metadata.forward_mode.is_target_verify()
+        ):
             last_index = None
             last_hidden = hidden_states
         else:
@@ -137,6 +156,15 @@ class LogitsProcessor(nn.Module):
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
+                hidden_states=(
+                    hidden_states
+                    if logits_metadata.capture_hidden_mode.is_full()
+                    else (
+                        last_hidden
+                        if logits_metadata.capture_hidden_mode.is_last()
+                        else None
+                    )
+                ),
             )
         else:
             last_logprobs = self.compute_temp_top_p_normalized_logprobs(
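The logits_processor.py hunks thread a `capture_hidden_mode` through `LogitsMetadata` so that `LogitsProcessorOutput.hidden_states` can carry what EAGLE speculative decoding needs. Only `CaptureHiddenMode.NULL` and the `is_full()` / `is_last()` helpers are visible in this diff, so the enum below is an assumed sketch of the selection logic, not the actual forward_batch_info definition:

```python
from enum import IntEnum, auto
from typing import Optional

import torch


class CaptureHiddenMode(IntEnum):
    # Assumed members; only NULL and the is_full()/is_last() helpers appear in the diff.
    NULL = auto()
    FULL = auto()
    LAST = auto()

    def is_full(self) -> bool:
        return self == CaptureHiddenMode.FULL

    def is_last(self) -> bool:
        return self == CaptureHiddenMode.LAST


def captured_hidden_states(
    mode: CaptureHiddenMode,
    hidden_states: torch.Tensor,  # all positions, [num_tokens, hidden_size]
    last_hidden: torch.Tensor,    # last position only
) -> Optional[torch.Tensor]:
    # Same conditional the diff adds when building LogitsProcessorOutput:
    # keep everything, keep only the last position, or keep nothing.
    if mode.is_full():
        return hidden_states
    if mode.is_last():
        return last_hidden
    return None


h = torch.randn(5, 8)
print(captured_hidden_states(CaptureHiddenMode.NULL, h, h[-1:]))        # None
print(captured_hidden_states(CaptureHiddenMode.LAST, h, h[-1:]).shape)  # torch.Size([1, 8])
```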
sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (new file)
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
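The new fused-MoE JSON files map a benchmarked token count (the string keys) to Triton tile parameters for a specific expert count `E`, intermediate size `N`, device, and dtype. The accompanying fused_moe.py change (+46 -26) is not shown in this section, so the lookup below is only a hedged sketch of how such configs are typically consumed: load the file matching the current device and shape, then pick the entry whose key is closest to the actual batch size.

```python
import json
from typing import Dict


def load_moe_config(path: str) -> Dict[int, dict]:
    # JSON keys are strings ("1", "2", ..., "4096"); convert them to ints once.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_kernel_config(configs: Dict[int, dict], num_tokens: int) -> dict:
    # Nearest-bucket selection: use the benchmarked size closest to the real one.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]


# Hypothetical usage against the A100 int8_w8a16 config added above:
# cfg = pick_kernel_config(
#     load_moe_config("configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json"),
#     num_tokens=1000,
# )
# cfg["BLOCK_SIZE_M"], cfg["BLOCK_SIZE_N"], cfg["num_warps"], ...
```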