sglang 0.4.6.post5__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (359)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_offline_throughput.py +10 -4
  4. sglang/bench_one_batch_server.py +67 -11
  5. sglang/bench_serving.py +86 -75
  6. sglang/lang/backend/runtime_endpoint.py +24 -1
  7. sglang/lang/interpreter.py +40 -1
  8. sglang/lang/ir.py +27 -0
  9. sglang/math_utils.py +8 -0
  10. sglang/profiler.py +167 -0
  11. sglang/srt/_custom_ops.py +34 -0
  12. sglang/srt/configs/internvl.py +8 -12
  13. sglang/srt/configs/model_config.py +33 -1
  14. sglang/srt/constrained/base_grammar_backend.py +5 -2
  15. sglang/srt/constrained/llguidance_backend.py +9 -8
  16. sglang/srt/constrained/outlines_backend.py +5 -4
  17. sglang/srt/constrained/xgrammar_backend.py +18 -18
  18. sglang/srt/conversation.py +52 -8
  19. sglang/srt/custom_op.py +38 -3
  20. sglang/srt/debug_utils.py +74 -0
  21. sglang/srt/disaggregation/base/__init__.py +1 -1
  22. sglang/srt/disaggregation/base/conn.py +25 -11
  23. sglang/srt/disaggregation/common/__init__.py +5 -0
  24. sglang/srt/disaggregation/common/conn.py +407 -0
  25. sglang/srt/disaggregation/common/utils.py +42 -0
  26. sglang/srt/disaggregation/decode.py +261 -52
  27. sglang/srt/disaggregation/fake/__init__.py +1 -1
  28. sglang/srt/disaggregation/fake/conn.py +16 -9
  29. sglang/srt/disaggregation/kv_events.py +60 -5
  30. sglang/srt/disaggregation/launch_lb.py +140 -0
  31. sglang/srt/disaggregation/mini_lb.py +29 -48
  32. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  33. sglang/srt/disaggregation/mooncake/conn.py +446 -149
  34. sglang/srt/disaggregation/mooncake/transfer_engine.py +32 -16
  35. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  36. sglang/srt/disaggregation/nixl/conn.py +134 -437
  37. sglang/srt/disaggregation/prefill.py +130 -43
  38. sglang/srt/disaggregation/utils.py +127 -86
  39. sglang/srt/distributed/device_communicators/pymscclpp.py +315 -0
  40. sglang/srt/distributed/parallel_state.py +52 -5
  41. sglang/srt/entrypoints/EngineBase.py +6 -0
  42. sglang/srt/entrypoints/engine.py +116 -5
  43. sglang/srt/entrypoints/http_server.py +28 -4
  44. sglang/srt/eplb_simulator/__init__.py +1 -0
  45. sglang/srt/eplb_simulator/reader.py +51 -0
  46. sglang/srt/function_call/base_format_detector.py +138 -86
  47. sglang/srt/function_call/deepseekv3_detector.py +54 -6
  48. sglang/srt/function_call/ebnf_composer.py +33 -19
  49. sglang/srt/function_call/function_call_parser.py +27 -0
  50. sglang/srt/function_call/llama32_detector.py +33 -14
  51. sglang/srt/function_call/mistral_detector.py +73 -26
  52. sglang/srt/function_call/pythonic_detector.py +86 -20
  53. sglang/srt/function_call/qwen25_detector.py +64 -10
  54. sglang/srt/function_call/utils.py +17 -0
  55. sglang/srt/hf_transformers_utils.py +4 -0
  56. sglang/srt/layers/activation.py +19 -0
  57. sglang/srt/layers/attention/aiter_backend.py +503 -125
  58. sglang/srt/layers/attention/base_attn_backend.py +4 -0
  59. sglang/srt/layers/attention/cutlass_mla_backend.py +40 -34
  60. sglang/srt/layers/attention/flashattention_backend.py +137 -63
  61. sglang/srt/layers/attention/flashinfer_backend.py +46 -3
  62. sglang/srt/layers/attention/flashinfer_mla_backend.py +59 -25
  63. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  64. sglang/srt/layers/attention/intel_amx_backend.py +128 -0
  65. sglang/srt/layers/attention/tbo_backend.py +232 -0
  66. sglang/srt/layers/attention/torch_native_backend.py +3 -0
  67. sglang/srt/layers/attention/triton_backend.py +304 -65
  68. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  69. sglang/srt/layers/attention/triton_ops/extend_attention.py +12 -4
  70. sglang/srt/layers/attention/vision.py +51 -24
  71. sglang/srt/layers/communicator.py +281 -197
  72. sglang/srt/layers/dp_attention.py +6 -5
  73. sglang/srt/layers/layernorm.py +30 -19
  74. sglang/srt/layers/linear.py +0 -4
  75. sglang/srt/layers/logits_processor.py +0 -12
  76. sglang/srt/layers/moe/cutlass_moe.py +170 -7
  77. sglang/srt/layers/moe/cutlass_moe_params.py +169 -0
  78. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  79. sglang/srt/layers/moe/ep_moe/layer.py +136 -72
  80. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +24 -45
  81. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  82. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  84. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  85. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +221 -29
  91. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -4
  92. sglang/srt/layers/moe/topk.py +60 -26
  93. sglang/srt/layers/multimodal.py +3 -3
  94. sglang/srt/layers/pooler.py +56 -0
  95. sglang/srt/layers/quantization/__init__.py +3 -2
  96. sglang/srt/layers/quantization/blockwise_int8.py +3 -0
  97. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +5 -0
  98. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  99. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +69 -127
  100. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  101. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  102. sglang/srt/layers/quantization/fp8.py +28 -23
  103. sglang/srt/layers/quantization/fp8_kernel.py +156 -75
  104. sglang/srt/layers/quantization/fp8_utils.py +250 -69
  105. sglang/srt/layers/quantization/modelopt_quant.py +334 -7
  106. sglang/srt/layers/quantization/moe_wna16.py +3 -0
  107. sglang/srt/layers/quantization/w8a8_fp8.py +3 -0
  108. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  109. sglang/srt/layers/radix_attention.py +2 -3
  110. sglang/srt/layers/rotary_embedding.py +6 -12
  111. sglang/srt/layers/sampler.py +80 -79
  112. sglang/srt/layers/utils.py +6 -0
  113. sglang/srt/lora/layers.py +12 -15
  114. sglang/srt/lora/lora.py +49 -5
  115. sglang/srt/lora/lora_manager.py +98 -39
  116. sglang/srt/lora/mem_pool.py +28 -21
  117. sglang/srt/lora/utils.py +17 -13
  118. sglang/srt/managers/cache_controller.py +2 -1
  119. sglang/srt/managers/data_parallel_controller.py +13 -5
  120. sglang/srt/managers/eplb_algorithms/__init__.py +63 -0
  121. sglang/srt/managers/eplb_algorithms/deepseek.py +223 -0
  122. sglang/srt/managers/{deepseek_eplb.py → eplb_algorithms/deepseek_vec.py} +5 -7
  123. sglang/srt/managers/eplb_manager.py +55 -14
  124. sglang/srt/managers/expert_distribution.py +220 -46
  125. sglang/srt/managers/expert_location.py +110 -56
  126. sglang/srt/managers/expert_location_dispatch.py +23 -6
  127. sglang/srt/managers/io_struct.py +43 -8
  128. sglang/srt/managers/mm_utils.py +88 -38
  129. sglang/srt/managers/multimodal_processors/base_processor.py +190 -18
  130. sglang/srt/managers/multimodal_processors/gemma3.py +4 -31
  131. sglang/srt/managers/multimodal_processors/internvl.py +4 -0
  132. sglang/srt/managers/multimodal_processors/kimi_vl.py +15 -34
  133. sglang/srt/managers/multimodal_processors/minicpm.py +2 -1
  134. sglang/srt/managers/multimodal_processors/phi4mm.py +87 -0
  135. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -64
  136. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  137. sglang/srt/managers/schedule_batch.py +173 -38
  138. sglang/srt/managers/scheduler.py +376 -127
  139. sglang/srt/managers/tokenizer_manager.py +163 -19
  140. sglang/srt/managers/utils.py +0 -4
  141. sglang/srt/mem_cache/chunk_cache.py +1 -0
  142. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  143. sglang/srt/mem_cache/memory_pool.py +111 -407
  144. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  145. sglang/srt/mem_cache/radix_cache.py +36 -12
  146. sglang/srt/metrics/collector.py +9 -0
  147. sglang/srt/model_executor/cuda_graph_runner.py +191 -113
  148. sglang/srt/model_executor/expert_location_updater.py +157 -22
  149. sglang/srt/model_executor/forward_batch_info.py +52 -22
  150. sglang/srt/model_executor/model_runner.py +102 -62
  151. sglang/srt/model_loader/loader.py +8 -1
  152. sglang/srt/model_loader/utils.py +67 -1
  153. sglang/srt/models/bert.py +113 -13
  154. sglang/srt/models/deepseek_nextn.py +1 -1
  155. sglang/srt/models/deepseek_v2.py +623 -290
  156. sglang/srt/models/gemma3_causal.py +7 -0
  157. sglang/srt/models/gemma3_mm.py +19 -14
  158. sglang/srt/models/idefics2.py +342 -0
  159. sglang/srt/models/internvl.py +46 -102
  160. sglang/srt/models/kimi_vl.py +4 -4
  161. sglang/srt/models/llama.py +1 -1
  162. sglang/srt/models/minicpmo.py +2 -5
  163. sglang/srt/models/minicpmv.py +3 -295
  164. sglang/srt/models/phi4mm.py +512 -0
  165. sglang/srt/models/qwen2.py +38 -9
  166. sglang/srt/models/qwen2_5_vl.py +3 -9
  167. sglang/srt/models/qwen2_eagle.py +4 -1
  168. sglang/srt/models/qwen2_moe.py +58 -191
  169. sglang/srt/models/qwen2_vl.py +3 -9
  170. sglang/srt/models/qwen3.py +41 -10
  171. sglang/srt/models/qwen3_moe.py +230 -191
  172. sglang/srt/models/registry.py +9 -1
  173. sglang/srt/models/roberta.py +117 -9
  174. sglang/srt/models/transformers.py +291 -0
  175. sglang/srt/models/vila.py +305 -0
  176. sglang/srt/openai_api/adapter.py +248 -28
  177. sglang/srt/openai_api/protocol.py +68 -3
  178. sglang/srt/openai_api/utils.py +172 -0
  179. sglang/srt/operations.py +37 -2
  180. sglang/srt/operations_strategy.py +200 -24
  181. sglang/srt/sampling/sampling_batch_info.py +37 -1
  182. sglang/srt/sampling/sampling_params.py +4 -1
  183. sglang/srt/server_args.py +381 -209
  184. sglang/srt/speculative/build_eagle_tree.py +9 -9
  185. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +12 -14
  186. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +256 -0
  187. sglang/srt/speculative/eagle_utils.py +440 -200
  188. sglang/srt/speculative/eagle_worker.py +234 -63
  189. sglang/srt/two_batch_overlap.py +637 -0
  190. sglang/srt/utils.py +187 -7
  191. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  192. sglang/test/runners.py +54 -10
  193. sglang/test/send_one.py +4 -0
  194. sglang/test/test_block_fp8.py +1 -0
  195. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  196. sglang/test/test_block_fp8_ep.py +1 -0
  197. sglang/test/test_cutlass_moe.py +3 -3
  198. sglang/test/test_fp4_moe.py +248 -0
  199. sglang/test/test_utils.py +82 -7
  200. sglang/utils.py +9 -0
  201. sglang/version.py +1 -1
  202. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +17 -14
  203. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +359 -321
  204. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +1 -1
  205. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  206. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  207. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  208. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  209. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  210. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  211. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  212. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  213. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  214. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  215. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  216. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  217. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  218. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H200.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json} +0 -0
  219. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  220. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  221. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  222. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  223. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  224. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  225. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  226. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  227. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  228. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  229. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  230. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  231. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  232. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  233. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  234. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  235. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  236. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  237. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  238. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  239. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  240. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  241. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  242. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  243. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  244. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  245. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  246. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  247. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json} +0 -0
  248. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  249. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  250. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  251. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  252. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  253. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  254. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  255. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  256. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  257. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json} +0 -0
  258. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json} +0 -0
  259. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  260. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  261. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  262. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  263. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  264. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  265. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json} +0 -0
  266. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  267. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  268. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json} +0 -0
  269. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  270. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  271. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  272. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json} +0 -0
  273. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  274. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  275. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  276. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  277. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  278. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  279. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  280. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json} +0 -0
  281. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json} +0 -0
  282. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json} +0 -0
  283. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json} +0 -0
  284. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  285. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  286. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json} +0 -0
  287. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json} +0 -0
  288. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json} +0 -0
  289. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json} +0 -0
  290. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  291. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  292. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  293. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  294. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json} +0 -0
  295. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  296. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  297. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  298. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  299. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json} +0 -0
  300. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json} +0 -0
  301. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json} +0 -0
  302. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json} +0 -0
  303. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  304. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  305. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  306. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  307. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  308. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  309. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json} +0 -0
  310. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_L40S.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json} +0 -0
  311. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  312. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  313. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  314. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  315. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  316. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  317. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  318. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json} +0 -0
  319. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json} +0 -0
  320. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json} +0 -0
  321. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json} +0 -0
  322. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  323. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  324. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  325. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  326. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json} +0 -0
  327. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  328. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  329. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  330. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  331. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  332. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  333. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  334. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json} +0 -0
  335. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json} +0 -0
  336. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  337. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  338. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json} +0 -0
  339. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  340. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json} +0 -0
  341. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  342. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  343. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  344. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  345. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json} +0 -0
  346. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  347. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json} +0 -0
  348. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=96,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json} +0 -0
  349. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  350. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  351. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  352. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  353. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  354. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  355. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  356. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  357. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  358. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  359. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
3
4
  import os
5
+ import time
4
6
  from dataclasses import dataclass
5
- from typing import TYPE_CHECKING, List, Optional
7
+ from typing import List, Optional
6
8
 
7
9
  import torch
8
10
  import torch.nn.functional as F
@@ -12,6 +14,7 @@ import triton.language as tl
12
14
  from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
13
15
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
14
16
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
17
+ from sglang.srt.layers.sampler import apply_custom_logit_processor
15
18
  from sglang.srt.managers.schedule_batch import (
16
19
  Req,
17
20
  ScheduleBatch,
@@ -19,10 +22,8 @@ from sglang.srt.managers.schedule_batch import (
19
22
  global_server_args_dict,
20
23
  )
21
24
  from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
22
- from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
23
- from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
24
- from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
25
- from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
25
+ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
26
+ from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
26
27
 
27
28
  if is_cuda():
28
29
  from sgl_kernel import (
@@ -31,18 +32,19 @@ if is_cuda():
31
32
  tree_speculative_sampling_target_only,
32
33
  verify_tree_greedy,
33
34
  )
35
+ from sgl_kernel.top_k import fast_topk
34
36
  elif is_hip():
35
37
  from sgl_kernel import verify_tree_greedy
36
38
 
37
- if TYPE_CHECKING:
38
- from sglang.srt.managers.schedule_batch import ScheduleBatch
39
-
40
- import logging
41
39
 
42
40
  logger = logging.getLogger(__name__)
43
41
 
44
42
 
43
+ # Simulate acceptance length for benchmarking purposes
45
44
  SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
45
+ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
46
+
47
+ TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
46
48
 
47
49
 
48
50
  @dataclass
@@ -66,8 +68,6 @@ class EagleDraftInput:
66
68
  kv_indptr: torch.Tensor = None
67
69
  kv_indices: torch.Tensor = None
68
70
 
69
- all_padding_lens: Optional[torch.Tensor] = None
70
-
71
71
  def prepare_for_extend(self, batch: ScheduleBatch):
72
72
  # Prefill only generate 1 token.
73
73
  assert len(self.verified_id) == len(batch.seq_lens)
@@ -85,32 +85,29 @@ class EagleDraftInput:
85
85
  batch: ScheduleBatch,
86
86
  speculative_num_steps: int,
87
87
  ):
88
- assert len(self.verified_id) == len(batch.out_cache_loc)
89
- accept_length_cpu = batch.spec_info.accept_length_cpu
90
- batch.extend_lens = [x + 1 for x in accept_length_cpu]
88
+ batch.forward_mode = ForwardMode.DRAFT_EXTEND
89
+ batch.input_ids = self.verified_id
90
+ batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
91
91
  batch.extend_num_tokens = sum(batch.extend_lens)
92
92
  batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
93
93
  batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
94
- seq_lens_cpu = batch.seq_lens.tolist()
94
+ batch.return_logprob = False
95
+ batch.return_hidden_states = False
95
96
 
96
- self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
97
- new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
97
+ self.capture_hidden_mode = CaptureHiddenMode.LAST
98
98
  self.accept_length.add_(1)
99
+ self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
100
+ self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
99
101
 
100
- create_extend_spec_info[(self.accept_length.numel(),)](
101
- self.verified_id,
102
+ create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
103
+ batch.input_ids,
102
104
  batch.seq_lens,
103
105
  self.accept_length,
104
- torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
105
106
  self.positions,
106
- new_verified_id,
107
- next_power_of_2(speculative_num_steps + 1),
107
+ self.verified_id,
108
+ next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
108
109
  )
109
110
 
110
- batch.seq_lens_sum = sum(seq_lens_cpu)
111
- batch.input_ids = self.verified_id
112
- self.verified_id = new_verified_id
113
-
114
111
  def generate_attn_arg_prefill(
115
112
  self,
116
113
  req_pool_indices: torch.Tensor,
@@ -119,15 +116,17 @@ class EagleDraftInput:
119
116
  req_to_token: torch.Tensor,
120
117
  ):
121
118
  bs = self.accept_length.numel()
122
-
123
119
  qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
124
120
  qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
125
-
126
121
  cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
127
122
  cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
128
123
 
129
- # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
130
- kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
124
+ if paged_kernel_lens_sum is None:
125
+ paged_kernel_lens_sum = cum_kv_seq_len[-1]
126
+
127
+ kv_indices = torch.empty(
128
+ paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
129
+ )
131
130
 
132
131
  create_flashinfer_kv_indices_triton[(bs,)](
133
132
  req_to_token,
@@ -138,7 +137,6 @@ class EagleDraftInput:
138
137
  kv_indices,
139
138
  req_to_token.size(1),
140
139
  )
141
-
142
140
  return kv_indices, cum_kv_seq_len, qo_indptr, None
143
141
 
144
142
  def filter_batch(self, new_indices: torch.Tensor):
@@ -187,56 +185,14 @@ class EagleVerifyInput:
187
185
  retrive_next_token: torch.Tensor
188
186
  retrive_next_sibling: torch.Tensor
189
187
  retrive_cum_len: torch.Tensor
190
- draft_token_num: int
191
188
  spec_steps: int
189
+ topk: int
190
+ draft_token_num: int
192
191
  capture_hidden_mode: CaptureHiddenMode
192
+ seq_lens_sum: int
193
+ seq_lens_cpu: torch.Tensor
193
194
  grammar: BaseGrammarObject = None
194
195
 
195
- @classmethod
196
- def create(
197
- cls,
198
- verified_id: torch.Tensor,
199
- score_list: List[torch.Tensor],
200
- token_list: List[torch.Tensor],
201
- parents_list: List[torch.Tensor],
202
- seq_lens: torch.Tensor,
203
- seq_lens_sum: int,
204
- topk: int,
205
- spec_steps: int,
206
- num_verify_tokens: int,
207
- ):
208
- (
209
- tree_mask,
210
- position,
211
- retrive_index,
212
- retrive_next_token,
213
- retrive_next_sibling,
214
- draft_tokens,
215
- ) = build_tree_kernel_efficient(
216
- verified_id,
217
- score_list,
218
- token_list,
219
- parents_list,
220
- seq_lens,
221
- seq_lens_sum,
222
- topk,
223
- spec_steps,
224
- num_verify_tokens,
225
- )
226
-
227
- return cls(
228
- draft_tokens,
229
- tree_mask,
230
- position,
231
- retrive_index,
232
- retrive_next_token,
233
- retrive_next_sibling,
234
- None,
235
- num_verify_tokens,
236
- spec_steps,
237
- CaptureHiddenMode.FULL,
238
- )
239
-
240
196
  def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
241
197
  batch.input_ids = self.draft_token
242
198
 
@@ -311,7 +267,7 @@ class EagleVerifyInput:
311
267
  logits_output: torch.Tensor,
312
268
  token_to_kv_pool_allocator: TokenToKVPoolAllocator,
313
269
  page_size: int,
314
- vocab_mask: Optional[torch.Tensor] = None,
270
+ vocab_mask: Optional[torch.Tensor] = None, # For grammar
315
271
  ) -> torch.Tensor:
316
272
  """
317
273
  Verify and find accepted tokens based on logits output and batch
@@ -335,6 +291,14 @@ class EagleVerifyInput:
335
291
  )
336
292
  accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
337
293
 
294
+ # Apply the custom logit processors if registered in the sampling info.
295
+ if sampling_info.has_custom_logit_processor:
296
+ apply_custom_logit_processor(
297
+ logits_output.next_token_logits,
298
+ sampling_info,
299
+ num_tokens_in_batch=self.draft_token_num,
300
+ )
301
+
338
302
  # Apply penalty
339
303
  if sampling_info.penalizer_orchestrator.is_required:
340
304
  # This is a relaxed version of penalties for speculative decoding.
@@ -364,11 +328,11 @@ class EagleVerifyInput:
364
328
  predicts=predict, # mutable
365
329
  accept_index=accept_index, # mutable
366
330
  accept_token_num=accept_length, # mutable
367
- candidates=candidates.to(torch.int32),
368
- retrive_index=self.retrive_index.to(torch.int32),
369
- retrive_next_token=self.retrive_next_token.to(torch.int32),
370
- retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
371
- target_predict=target_predict.to(torch.int32),
331
+ candidates=candidates,
332
+ retrive_index=self.retrive_index,
333
+ retrive_next_token=self.retrive_next_token,
334
+ retrive_next_sibling=self.retrive_next_sibling,
335
+ target_predict=target_predict,
372
336
  )
373
337
  else:
374
338
  # apply temperature and get target probs
@@ -396,16 +360,23 @@ class EagleVerifyInput:
396
360
  draft_probs = torch.zeros(
397
361
  target_probs.shape, dtype=torch.float32, device="cuda"
398
362
  )
363
+
364
+ # coins for rejection sampling
399
365
  coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
366
+ # coins for final sampling
367
+ coins_for_final_sampling = torch.rand(
368
+ (bs,), dtype=torch.float32, device="cuda"
369
+ )
400
370
  tree_speculative_sampling_target_only(
401
371
  predicts=predict, # mutable
402
372
  accept_index=accept_index, # mutable
403
373
  accept_token_num=accept_length, # mutable
404
- candidates=candidates.to(torch.int32),
405
- retrive_index=self.retrive_index.to(torch.int32),
406
- retrive_next_token=self.retrive_next_token.to(torch.int32),
407
- retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
374
+ candidates=candidates,
375
+ retrive_index=self.retrive_index,
376
+ retrive_next_token=self.retrive_next_token,
377
+ retrive_next_sibling=self.retrive_next_sibling,
408
378
  uniform_samples=coins,
379
+ uniform_samples_for_final_sampling=coins_for_final_sampling,
409
380
  target_probs=target_probs,
410
381
  draft_probs=draft_probs,
411
382
  threshold_single=global_server_args_dict[
@@ -428,8 +399,8 @@ class EagleVerifyInput:
428
399
  spec_steps=self.spec_steps,
429
400
  )
430
401
 
431
- new_accept_index = []
432
402
  unfinished_index = []
403
+ unfinished_accept_index = []
433
404
  accept_index_cpu = accept_index.tolist()
434
405
  predict_cpu = predict.tolist()
435
406
  has_finished = False
@@ -437,12 +408,10 @@ class EagleVerifyInput:
437
408
  # Iterate every accepted token and check if req has finished after append the token
438
409
  # should be checked BEFORE free kv cache slots
439
410
  for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
440
- new_accept_index_ = []
441
411
  for j, idx in enumerate(accept_index_row):
442
412
  if idx == -1:
443
413
  break
444
414
  id = predict_cpu[idx]
445
- # if not found_finished:
446
415
  req.output_ids.append(id)
447
416
  req.check_finished()
448
417
  if req.finished():
@@ -451,8 +420,6 @@ class EagleVerifyInput:
451
420
  accept_index[i, j + 1 :] = -1
452
421
  break
453
422
  else:
454
- new_accept_index_.append(idx)
455
- # update grammar state
456
423
  if req.grammar is not None:
457
424
  try:
458
425
  req.grammar.accept_token(id)
@@ -462,50 +429,104 @@ class EagleVerifyInput:
462
429
  )
463
430
  raise e
464
431
  if not req.finished():
465
- new_accept_index.extend(new_accept_index_)
466
432
  unfinished_index.append(i)
433
+ if idx == -1:
434
+ unfinished_accept_index.append(accept_index[i, :j])
435
+ else:
436
+ unfinished_accept_index.append(accept_index[i])
467
437
  req.spec_verify_ct += 1
468
438
 
469
439
  if has_finished:
470
440
  accept_length = (accept_index != -1).sum(dim=1) - 1
471
441
 
472
442
  # Free the KV cache for unaccepted tokens
443
+ # TODO: fuse them
473
444
  accept_index = accept_index[accept_index != -1]
474
445
  verified_id = predict[accept_index]
475
446
  evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
476
447
  evict_mask[accept_index] = False
477
448
 
478
- if page_size != 1:
479
- align_evict_mask_to_page_size[len(batch.seq_lens),](
480
- batch.seq_lens,
481
- evict_mask,
482
- page_size,
483
- self.draft_token_num,
484
- next_power_of_2(self.draft_token_num),
485
- )
449
+ if page_size == 1:
450
+ # TODO: boolean array index leads to a device sync. Remove it.
451
+ token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
452
+ else:
453
+ if self.topk == 1:
454
+ # Only evict full empty page. Do not evict partial empty page
455
+ align_evict_mask_to_page_size[len(batch.seq_lens),](
456
+ batch.seq_lens,
457
+ evict_mask,
458
+ page_size,
459
+ self.draft_token_num,
460
+ next_power_of_2(self.draft_token_num),
461
+ )
462
+ token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
463
+ else:
464
+ # Shift the accepted tokens to the beginning.
465
+ # Only evict the last part
466
+ src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
467
+ batch.seq_lens,
468
+ batch.out_cache_loc,
469
+ accept_index,
470
+ accept_length,
471
+ self.draft_token_num,
472
+ page_size,
473
+ )
474
+ to_free_slots = torch.empty(
475
+ (to_free_num_slots.sum().item(),),
476
+ dtype=torch.int64,
477
+ device=to_free_num_slots.device,
478
+ )
486
479
 
487
- token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
480
+ # out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
481
+ # accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
482
+ # tgt_cache_loc: [0 1 , 3 4 , 6 ]
483
+ # to_free_slots: [ 2, 5, 7 8]
484
+ # to_free_slots also needs to be page-aligned without the first partial page
485
+ #
486
+ # split each row of out_cache_loc into two parts.
487
+ # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
488
+ # 2. the second part goes to to_free_slots.
489
+ get_target_cache_loc[(bs,)](
490
+ tgt_cache_loc,
491
+ to_free_slots,
492
+ accept_length,
493
+ to_free_num_slots,
494
+ batch.out_cache_loc,
495
+ self.draft_token_num,
496
+ next_power_of_2(self.draft_token_num),
497
+ next_power_of_2(bs),
498
+ )
499
+
500
+ # Free the kv cache
501
+ token_to_kv_pool_allocator.free(to_free_slots)
502
+
503
+ # Copy the kv cache
504
+ batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
505
+ tgt_cache_loc, src_cache_loc
506
+ )
488
507
 
489
508
  # Construct EagleVerifyOutput
490
509
  if not has_finished:
491
- batch.out_cache_loc = batch.out_cache_loc[accept_index]
492
- assign_req_to_token_pool[(bs,)](
493
- batch.req_pool_indices,
494
- batch.req_to_token_pool.req_to_token,
495
- batch.seq_lens,
496
- batch.seq_lens + accept_length + 1,
497
- batch.out_cache_loc,
498
- batch.req_to_token_pool.req_to_token.shape[1],
499
- next_power_of_2(bs),
500
- )
510
+ if page_size == 1 or self.topk == 1:
511
+ batch.out_cache_loc = batch.out_cache_loc[accept_index]
512
+ assign_req_to_token_pool[(bs,)](
513
+ batch.req_pool_indices,
514
+ batch.req_to_token_pool.req_to_token,
515
+ batch.seq_lens,
516
+ batch.seq_lens + accept_length + 1,
517
+ batch.out_cache_loc,
518
+ batch.req_to_token_pool.req_to_token.shape[1],
519
+ next_power_of_2(bs),
520
+ )
521
+ else:
522
+ batch.out_cache_loc = tgt_cache_loc
501
523
  batch.seq_lens.add_(accept_length + 1)
502
- accept_length_cpu = accept_length.tolist()
503
524
 
504
525
  draft_input = EagleDraftInput()
505
526
  draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
506
527
  draft_input.verified_id = verified_id
507
528
  draft_input.accept_length = accept_length
508
- draft_input.accept_length_cpu = accept_length_cpu
529
+ draft_input.accept_length_cpu = accept_length.tolist()
509
530
  draft_input.seq_lens_for_draft_extend = batch.seq_lens
510
531
  draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
511
532
 
@@ -513,47 +534,66 @@ class EagleVerifyInput:
513
534
  draft_input=draft_input,
514
535
  logits_output=logits_output,
515
536
  verified_id=verified_id,
516
- accept_length_per_req_cpu=accept_length_cpu,
537
+ accept_length_per_req_cpu=draft_input.accept_length_cpu,
517
538
  accepted_indices=accept_index,
518
539
  )
519
540
  else:
520
- assign_req_to_token_pool[(bs,)](
521
- batch.req_pool_indices,
522
- batch.req_to_token_pool.req_to_token,
523
- batch.seq_lens,
524
- batch.seq_lens + accept_length + 1,
525
- batch.out_cache_loc[accept_index],
526
- batch.req_to_token_pool.req_to_token.shape[1],
527
- next_power_of_2(bs),
528
- )
529
- batch.seq_lens.add_(accept_length + 1)
530
- accept_length_cpu = accept_length.tolist()
541
+ if page_size == 1 or self.topk == 1:
542
+ assign_req_to_token_pool[(bs,)](
543
+ batch.req_pool_indices,
544
+ batch.req_to_token_pool.req_to_token,
545
+ batch.seq_lens,
546
+ batch.seq_lens + accept_length + 1,
547
+ batch.out_cache_loc[accept_index],
548
+ batch.req_to_token_pool.req_to_token.shape[1],
549
+ next_power_of_2(bs),
550
+ )
551
+ batch.seq_lens.add_(accept_length + 1)
531
552
 
553
+ accept_length_cpu = accept_length.tolist()
532
554
  draft_input = EagleDraftInput()
533
- if len(new_accept_index) > 0:
534
- new_accept_index = torch.tensor(new_accept_index, device="cuda")
535
- unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
536
- draft_input.hidden_states = batch.spec_info.hidden_states[
537
- new_accept_index
538
- ]
539
- draft_input.verified_id = predict[new_accept_index]
540
- draft_input.accept_length_cpu = [
555
+ if len(unfinished_accept_index) > 0:
556
+ unfinished_accept_index = torch.cat(unfinished_accept_index)
557
+ unfinished_index_device = torch.tensor(
558
+ unfinished_index, dtype=torch.int64, device=predict.device
559
+ )
560
+ draft_input_accept_length_cpu = [
541
561
  accept_length_cpu[i] for i in unfinished_index
542
562
  ]
543
- draft_input.accept_length = accept_length[unfinished_index_device]
544
- if has_finished:
545
- draft_input.seq_lens_for_draft_extend = batch.seq_lens[
546
- unfinished_index_device
547
- ]
548
- draft_input.req_pool_indices_for_draft_extend = (
549
- batch.req_pool_indices[unfinished_index_device]
550
- )
563
+ if page_size == 1 or self.topk == 1:
564
+ batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
551
565
  else:
552
- draft_input.seq_lens_for_draft_extend = batch.seq_lens
553
- draft_input.req_pool_indices_for_draft_extend = (
554
- batch.req_pool_indices
566
+ batch.out_cache_loc = torch.empty(
567
+ len(unfinished_index) + sum(draft_input_accept_length_cpu),
568
+ dtype=torch.int64,
569
+ device=predict.device,
570
+ )
571
+ accept_length_filter = create_accept_length_filter(
572
+ accept_length,
573
+ unfinished_index_device,
574
+ batch.seq_lens,
555
575
  )
556
- batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
576
+ filter_finished_cache_loc_kernel[(bs,)](
577
+ batch.out_cache_loc,
578
+ tgt_cache_loc,
579
+ accept_length,
580
+ accept_length_filter,
581
+ next_power_of_2(bs),
582
+ next_power_of_2(self.draft_token_num),
583
+ )
584
+
585
+ draft_input.hidden_states = batch.spec_info.hidden_states[
586
+ unfinished_accept_index
587
+ ]
588
+ draft_input.verified_id = predict[unfinished_accept_index]
589
+ draft_input.accept_length_cpu = draft_input_accept_length_cpu
590
+ draft_input.accept_length = accept_length[unfinished_index_device]
591
+ draft_input.seq_lens_for_draft_extend = batch.seq_lens[
592
+ unfinished_index_device
593
+ ]
594
+ draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
595
+ unfinished_index_device
596
+ ]
557
597
 
558
598
  return EagleVerifyOutput(
559
599
  draft_input=draft_input,
@@ -565,26 +605,28 @@ class EagleVerifyInput:
565
605
 
566
606
 
567
607
  @triton.jit
568
- def create_extend_spec_info(
608
+ def create_extend_after_decode_spec_info(
569
609
  verified_id,
570
- seq_len,
571
- accept_len,
572
- accept_len_cum,
610
+ seq_lens,
611
+ accept_lens,
573
612
  positions,
574
613
  new_verified_id,
575
- accept_len_upper: tl.constexpr,
614
+ bs_upper: tl.constexpr,
576
615
  ):
577
616
  pid = tl.program_id(axis=0)
578
- offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
579
- seq_length = tl.load(seq_len + pid)
580
- accept_length = tl.load(accept_len + pid)
581
- positions_ptr = positions + offset
582
- data = tl.arange(0, accept_len_upper)
583
- mask = data < accept_length
584
- tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
585
-
586
- offset = tl.load(accept_len_cum + pid) - 1
587
- verified_id_data = tl.load(verified_id + offset)
617
+ offsets = tl.arange(0, bs_upper)
618
+ seq_length = tl.load(seq_lens + pid)
619
+ accept_length = tl.load(accept_lens + pid)
620
+
621
+ accept_len_cumsum = tl.sum(
622
+ tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
623
+ )
624
+ positions_ptr = positions + accept_len_cumsum
625
+ mask = offsets < accept_length
626
+ tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
627
+
628
+ accept_len_cumsum += accept_length - 1
629
+ verified_id_data = tl.load(verified_id + accept_len_cumsum)
588
630
  tl.store(new_verified_id + pid, verified_id_data)
589
631
 
590
632
 
@@ -605,8 +647,8 @@ def assign_req_to_token_pool(
605
647
  token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
606
648
 
607
649
  length_offset = tl.arange(0, bs_upper)
608
- start = tl.load(start_offset + length_offset, mask=length_offset < pid)
609
- end = tl.load(end_offset + length_offset, mask=length_offset < pid)
650
+ start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
651
+ end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
610
652
  out_offset = tl.sum(end - start, axis=0)
611
653
 
612
654
  out_cache_ptr = out_cache_loc + out_offset
@@ -628,36 +670,75 @@ def assign_draft_cache_locs(
628
670
  req_pool_indices,
629
671
  req_to_token,
630
672
  seq_lens,
673
+ extend_lens,
674
+ num_new_pages_per_topk,
631
675
  out_cache_loc,
632
676
  pool_len: tl.constexpr,
633
677
  topk: tl.constexpr,
634
678
  speculative_num_steps: tl.constexpr,
635
679
  page_size: tl.constexpr,
680
+ bs_upper: tl.constexpr,
681
+ iter_upper: tl.constexpr,
636
682
  ):
637
- BLOCK_SIZE: tl.constexpr = 32
683
+ BLOCK_SIZE: tl.constexpr = 128
638
684
  pid = tl.program_id(axis=0)
639
- kv_start = tl.load(seq_lens + pid)
640
685
 
641
686
  if page_size == 1 or topk == 1:
642
- kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
687
+ copy_len = topk * speculative_num_steps
643
688
  out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
644
689
  else:
645
- prefix_len = tl.load(seq_lens + pid)
646
- last_page_len = prefix_len % page_size
647
- num_new_page = (
648
- last_page_len + speculative_num_steps + page_size - 1
649
- ) // page_size
650
- kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
690
+ bs_offset = tl.arange(0, bs_upper)
691
+ copy_len = tl.load(extend_lens + pid)
692
+ cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
693
+ out_cache_ptr = out_cache_loc + cum_copy_len
651
694
 
695
+ # Part 1: Copy from out_cache_loc to req_to_token
696
+ kv_start = tl.load(seq_lens + pid)
652
697
  token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
653
-
654
- num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
698
+ num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
655
699
  for i in range(num_loop):
656
- save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
657
- load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
658
- mask = save_offset < kv_end
659
- data = tl.load(out_cache_ptr + load_offset, mask=mask)
660
- tl.store(token_pool + save_offset, data, mask=mask)
700
+ copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
701
+ mask = copy_offset < copy_len
702
+ data = tl.load(out_cache_ptr + copy_offset, mask=mask)
703
+ tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
704
+
705
+ if page_size == 1 or topk == 1:
706
+ return
707
+
708
+ # Part 2: Copy the indices for the last partial page
709
+ prefix_len = tl.load(seq_lens + pid)
710
+ last_page_len = prefix_len % page_size
711
+ offsets = tl.arange(0, page_size)
712
+ mask = offsets < last_page_len
713
+ num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
714
+ prefix_base = token_pool + prefix_len - last_page_len
715
+
716
+ for topk_id in range(topk):
717
+ value = tl.load(prefix_base + offsets, mask=mask)
718
+ tl.store(
719
+ prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
720
+ value,
721
+ mask=mask,
722
+ )
723
+
724
+ # Part 3: Remove the padding in out_cache_loc
725
+ iter_offest = tl.arange(0, iter_upper)
726
+ for topk_id in range(topk):
727
+ indices = tl.load(
728
+ prefix_base
729
+ + topk_id * num_new_pages_per_topk_ * page_size
730
+ + last_page_len
731
+ + iter_offest,
732
+ mask=iter_offest < speculative_num_steps,
733
+ )
734
+ tl.store(
735
+ out_cache_loc
736
+ + pid * topk * speculative_num_steps
737
+ + topk_id * speculative_num_steps
738
+ + iter_offest,
739
+ indices,
740
+ mask=iter_offest < speculative_num_steps,
741
+ )
661
742
 
662
743
 
663
744
  @triton.jit
@@ -668,29 +749,33 @@ def generate_draft_decode_kv_indices(
668
749
  kv_indices,
669
750
  kv_indptr,
670
751
  positions,
671
- num_seqs: tl.constexpr,
672
- topk: tl.constexpr,
673
752
  pool_len: tl.constexpr,
674
753
  kv_indices_stride: tl.constexpr,
675
754
  kv_indptr_stride: tl.constexpr,
676
755
  bs_upper: tl.constexpr,
677
756
  iter_upper: tl.constexpr,
678
757
  num_tokens_upper: tl.constexpr,
758
+ page_size: tl.constexpr,
679
759
  ):
680
760
  BLOCK_SIZE: tl.constexpr = 128
681
761
  iters = tl.program_id(axis=0)
682
762
  bid = tl.program_id(axis=1)
683
763
  topk_id = tl.program_id(axis=2)
684
764
 
765
+ num_steps = tl.num_programs(axis=0)
766
+ num_seqs = tl.num_programs(axis=1)
767
+ topk = tl.num_programs(axis=2)
768
+
685
769
  kv_indices += kv_indices_stride * iters
686
770
  kv_indptr += kv_indptr_stride * iters
687
771
  iters += 1
688
772
 
689
773
  load_offset = tl.arange(0, bs_upper)
690
- seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
774
+ seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
691
775
  seq_len = tl.load(paged_kernel_lens + bid)
692
776
  cum_seq_len = tl.sum(seq_lens)
693
777
 
778
+ # Update kv_indices
694
779
  kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
695
780
  kv_ptr = kv_indices + kv_offset
696
781
  token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
@@ -704,10 +789,26 @@ def generate_draft_decode_kv_indices(
704
789
  kv_offset += BLOCK_SIZE
705
790
 
706
791
  extend_offset = tl.arange(0, iter_upper)
707
- extend_data = tl.load(
708
- token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
709
- mask=extend_offset < iters,
710
- )
792
+ if page_size == 1 or topk == 1:
793
+ extend_data = tl.load(
794
+ token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
795
+ mask=extend_offset < iters,
796
+ )
797
+ else:
798
+ prefix_len = seq_len
799
+ last_page_len = prefix_len % page_size
800
+ num_new_pages_per_topk = (
801
+ last_page_len + num_steps + page_size - 1
802
+ ) // page_size
803
+ prefix_base = seq_len // page_size * page_size
804
+ start = (
805
+ prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
806
+ )
807
+ extend_data = tl.load(
808
+ token_pool_ptr + start + extend_offset,
809
+ mask=extend_offset < iters,
810
+ )
811
+
711
812
  tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
712
813
 
713
814
  # Update kv_indptr
@@ -716,7 +817,7 @@ def generate_draft_decode_kv_indices(
716
817
  zid = bid * topk + topk_id
717
818
  if zid == 0:
718
819
  zid = num_seqs * topk
719
- positions = tl.load(positions + bs_offset, mask=bs_offset < zid)
820
+ positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
720
821
  base = tl.sum(positions)
721
822
  tl.store(kv_indptr + zid, base + zid * iters)
722
823
 
@@ -734,7 +835,9 @@ def align_evict_mask_to_page_size(
734
835
  bid = tl.program_id(axis=0)
735
836
  seq_len = tl.load(seq_lens + bid)
736
837
  io_mask = t_range < num_draft_tokens
737
- mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
838
+ mask_row = tl.load(
839
+ evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
840
+ )
738
841
 
739
842
  num_trues = tl.sum(mask_row)
740
843
  num_false = num_draft_tokens - num_trues
@@ -744,6 +847,116 @@ def align_evict_mask_to_page_size(
744
847
  tl.store(evict_mask + bid * num_draft_tokens + i, False)
745
848
 
746
849
 
850
+ @triton.jit
851
+ def get_target_cache_loc(
852
+ tgt_cache_loc,
853
+ to_free_slots,
854
+ accept_length,
855
+ to_free_num_slots,
856
+ out_cache_loc,
857
+ num_verify_tokens: tl.constexpr,
858
+ num_verify_tokens_upper: tl.constexpr,
859
+ bs_upper: tl.constexpr,
860
+ ):
861
+ bid = tl.program_id(axis=0)
862
+ offset = tl.arange(0, num_verify_tokens_upper)
863
+ bs_offset = tl.arange(0, bs_upper)
864
+
865
+ # write the first part to tgt_cache_loc
866
+ accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
867
+ tgt_cache_loc_start = tl.sum(accept_len_all) + bid
868
+ copy_len = tl.load(accept_length + bid) + 1
869
+ out_cache_loc_row = tl.load(
870
+ out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
871
+ )
872
+ tl.store(
873
+ tgt_cache_loc + tgt_cache_loc_start + offset,
874
+ out_cache_loc_row,
875
+ mask=offset < copy_len,
876
+ )
877
+
878
+ # write the second part to to_free_num_pages
879
+ to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
880
+ to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
881
+ out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
882
+ to_free_slots_start = tl.sum(to_free_num_slots_all)
883
+
884
+ copy_len = to_free_num_slots_cur
885
+ out_cache_loc_row = tl.load(
886
+ out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
887
+ mask=offset < copy_len,
888
+ )
889
+ tl.store(
890
+ to_free_slots + to_free_slots_start + offset,
891
+ out_cache_loc_row,
892
+ mask=offset < copy_len,
893
+ )
894
+
895
+
896
+ @torch.compile(dynamic=True)
897
+ def get_src_tgt_cache_loc(
898
+ seq_lens: torch.Tensor,
899
+ out_cache_loc: torch.Tensor,
900
+ accept_index: torch.Tensor,
901
+ accept_length: torch.Tensor,
902
+ draft_token_num: int,
903
+ page_size: int,
904
+ ):
905
+ src_cache_loc = out_cache_loc[accept_index]
906
+ tgt_cache_loc = torch.empty_like(src_cache_loc)
907
+ extended_len = seq_lens + draft_token_num
908
+ keep_len = torch.minimum(
909
+ (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
910
+ extended_len,
911
+ )
912
+ to_free_num_slots = extended_len - keep_len
913
+ return src_cache_loc, tgt_cache_loc, to_free_num_slots
914
+
915
+
916
+ @triton.jit
917
+ def filter_finished_cache_loc_kernel(
918
+ out_cache_loc,
919
+ tgt_cache_loc,
920
+ accept_length,
921
+ accept_length_filter,
922
+ bs_upper: tl.constexpr,
923
+ num_verify_tokens_upper: tl.constexpr,
924
+ ):
925
+ bid = tl.program_id(0)
926
+ bs_offset = tl.arange(0, bs_upper)
927
+
928
+ accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
929
+ old_start = tl.sum(accept_length_all) + bid
930
+
931
+ accept_length_filter_all = tl.load(
932
+ accept_length_filter + bs_offset, mask=bs_offset < bid
933
+ )
934
+ new_start = tl.sum(accept_length_filter_all)
935
+
936
+ copy_len = tl.load(accept_length_filter + bid)
937
+ copy_offset = tl.arange(0, num_verify_tokens_upper)
938
+ value = tl.load(
939
+ tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
940
+ )
941
+ tl.store(
942
+ out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
943
+ )
944
+
945
+
946
+ @torch.compile(dynamic=True)
947
+ def create_accept_length_filter(
948
+ accept_length: torch.Tensor,
949
+ unfinished_index_device: torch.Tensor,
950
+ seq_lens: torch.Tensor,
951
+ ):
952
+ accept_length_filter = torch.zeros_like(accept_length)
953
+ accept_length_filter[unfinished_index_device] = (
954
+ accept_length[unfinished_index_device] + 1
955
+ )
956
+ seq_lens.add_(accept_length + 1)
957
+ return accept_length_filter
958
+
959
+
747
960
  @torch.compile(dynamic=True)
748
961
  def select_top_k_tokens(
749
962
  i: int,
@@ -802,15 +1015,35 @@ def _generate_simulated_accept_index(
802
1015
  spec_steps,
803
1016
  ):
804
1017
  simulate_acc_len_float = float(simulate_acc_len)
805
- simulated_values = torch.normal(
806
- mean=simulate_acc_len_float,
807
- std=1.0,
808
- size=(1,),
809
- device="cpu",
810
- )
811
- # clamp simulated values to be between 1 and self.spec_steps
812
- simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
813
- simulate_acc_len = int(simulated_values.round().item())
1018
+ if SIMULATE_ACC_METHOD == "multinomial":
1019
+ simulated_values = torch.normal(
1020
+ mean=simulate_acc_len_float,
1021
+ std=1.0,
1022
+ size=(1,),
1023
+ device="cpu",
1024
+ )
1025
+ # clamp simulated values to be between 1 and self.spec_steps
1026
+ simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
1027
+ simulate_acc_len = int(simulated_values.round().item())
1028
+ elif SIMULATE_ACC_METHOD == "match-expected":
1029
+ # multinomial sampling does not match the expected length
1030
+ # we keep it for the sake of compatibility of existing tests
1031
+ # but it's better to use "match-expected" for the cases that need to
1032
+ # match the expected length, One caveat is that this will only sample
1033
+ # either round down or round up of the expected length
1034
+ simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
1035
+ lower = int(simulate_acc_len_float // 1)
1036
+ upper = lower + 1 if lower < spec_steps + 1 else lower
1037
+ if lower == upper:
1038
+ simulate_acc_len = lower
1039
+ else:
1040
+ weight_upper = simulate_acc_len_float - lower
1041
+ weight_lower = 1.0 - weight_upper
1042
+ probs = torch.tensor([weight_lower, weight_upper], device="cpu")
1043
+ sampled_index = torch.multinomial(probs, num_samples=1)
1044
+ simulate_acc_len = lower if sampled_index == 0 else upper
1045
+ else:
1046
+ raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
814
1047
 
815
1048
  accept_indx_first_col = accept_index[:, 0].view(-1, 1)
816
1049
  sim_accept_index = torch.full(
@@ -901,9 +1134,9 @@ def generate_token_bitmask(
901
1134
  """
902
1135
  Generate the logit mask for structured output.
903
1136
  Draft model's token can be either valid or invalid with respect to the grammar.
904
- We need to perform DFS to figure out:
905
- 1. which tokens are accepted by the grammar
906
- 2. what is the corresponding logit mask.
1137
+ We need to perform DFS to
1138
+ 1. figure out which tokens are accepted by the grammar.
1139
+ 2. if so, what is the corresponding logit mask.
907
1140
  """
908
1141
 
909
1142
  num_draft_tokens = draft_tokens_cpu.shape[-1]
@@ -920,6 +1153,7 @@ def generate_token_bitmask(
920
1153
  device="cpu",
921
1154
  )
922
1155
  grammar = req.grammar
1156
+ s = time.perf_counter()
923
1157
  traverse_tree(
924
1158
  retrieve_next_token_cpu[i],
925
1159
  retrieve_next_sibling_cpu[i],
@@ -929,6 +1163,12 @@ def generate_token_bitmask(
929
1163
  i * num_draft_tokens : (i + 1) * num_draft_tokens
930
1164
  ],
931
1165
  )
1166
+ tree_traverse_time = time.perf_counter() - s
1167
+ if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
1168
+ logger.warning(
1169
+ f"Bit mask generation took {tree_traverse_time} seconds with "
1170
+ f"grammar: {req.grammar}"
1171
+ )
932
1172
 
933
1173
  verify_input.grammar = grammar
934
1174
  return allocate_token_bitmask