sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
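Items 425-434 and 478-479 above fold a number of flat sglang/srt/*.py helper modules into a sglang/srt/utils/ package (with a new utils/__init__.py, item 424). A minimal sketch of how downstream import paths change under that layout; the module names are taken from the rename entries, and the package-level re-export is inferred from the unchanged "from sglang.srt.utils import is_cuda, is_hip" line in the diff below, so treat it as an assumption rather than a guaranteed API:

# 0.5.3rc0 (flat modules), shown for contrast:
#   import sglang.srt.patch_torch
#   import sglang.srt.hf_transformers_utils
# 0.5.4 (package layout, per rename items 425-434):
import sglang.srt.utils.patch_torch
import sglang.srt.utils.hf_transformers_utils

# Package-level helpers remain importable from sglang.srt.utils itself,
# as the unchanged import line in the eagle_utils.py hunk below shows:
from sglang.srt.utils import is_cuda, is_hip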
@@ -1,1295 +1,138 @@
1
- from __future__ import annotations
2
-
3
- import copy
4
- import logging
5
- import os
6
- import time
7
- from dataclasses import dataclass
1
+ import math
2
+ from enum import IntEnum
8
3
  from typing import List, Optional
9
4
 
10
5
  import torch
11
- import torch.nn.functional as F
12
- import triton
13
- import triton.language as tl
14
6
 
15
- from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
16
- from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
17
- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
18
- from sglang.srt.layers.sampler import apply_custom_logit_processor
19
- from sglang.srt.managers.schedule_batch import (
20
- Req,
21
- ScheduleBatch,
22
- get_last_loc,
23
- global_server_args_dict,
24
- )
25
- from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
26
- from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
27
- from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
7
+ from sglang.srt.utils import is_cuda, is_hip
28
8
 
29
- if is_cuda():
9
+ if is_cuda() or is_hip():
30
10
  from sgl_kernel import (
31
- fast_topk,
32
- top_k_renorm_prob,
33
- top_p_renorm_prob,
34
- tree_speculative_sampling_target_only,
35
- verify_tree_greedy,
36
- )
37
- elif is_hip():
38
- from sgl_kernel import fast_topk, verify_tree_greedy
39
-
40
-
41
- logger = logging.getLogger(__name__)
42
-
43
-
44
- # Simulate acceptance length for benchmarking purposes
45
- SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
46
- SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
47
-
48
- TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
49
-
50
- TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
51
-
52
-
53
- @dataclass
54
- class EagleDraftInput:
55
- # The inputs for decode
56
- # shape: (b, topk)
57
- topk_p: torch.Tensor = None
58
- topk_index: torch.Tensor = None
59
- # shape: (b, hidden_size)
60
- hidden_states: torch.Tensor = None
61
- capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
62
-
63
- # Inputs for extend
64
- # shape: (b,)
65
- verified_id: torch.Tensor = None
66
- accept_length: torch.Tensor = None
67
- accept_length_cpu: List[int] = None
68
-
69
- # Inputs for the attention backends
70
- # shape: (b + 1,)
71
- kv_indptr: torch.Tensor = None
72
- kv_indices: torch.Tensor = None
73
-
74
- # Shape info for padding
75
- num_tokens_per_batch: int = -1
76
- num_tokens_for_logprob_per_batch: int = -1
77
-
78
- # Inputs for draft extend
79
- # shape: (b,)
80
- seq_lens_for_draft_extend: torch.Tensor = None
81
- req_pool_indices_for_draft_extend: torch.Tensor = None
82
-
83
- def prepare_for_extend(self, batch: ScheduleBatch):
84
-
85
- if batch.forward_mode.is_idle():
86
- return
87
-
88
- # Prefill only generate 1 token.
89
- assert len(self.verified_id) == len(batch.seq_lens)
90
-
91
- pt = 0
92
- for i, extend_len in enumerate(batch.extend_lens):
93
- input_ids = batch.input_ids[pt : pt + extend_len]
94
- batch.input_ids[pt : pt + extend_len] = torch.cat(
95
- (input_ids[1:], self.verified_id[i].reshape(1))
96
- )
97
- pt += extend_len
98
-
99
- @classmethod
100
- def create_idle_input(
101
- cls,
102
- device: torch.device,
103
- hidden_size: int,
104
- dtype: torch.dtype,
105
- topk: int,
106
- capture_hidden_mode: CaptureHiddenMode,
107
- ):
108
- return cls(
109
- verified_id=torch.empty((0,), device=device, dtype=torch.int32),
110
- hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
111
- topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
112
- topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
113
- capture_hidden_mode=capture_hidden_mode,
114
- accept_length=torch.empty((0,), device=device, dtype=torch.int32),
115
- accept_length_cpu=[],
116
- )
117
-
118
- def prepare_extend_after_decode(
119
- self,
120
- batch: ScheduleBatch,
121
- speculative_num_steps: int,
122
- ):
123
-
124
- if batch.forward_mode.is_idle():
125
- return
126
-
127
- batch.input_ids = self.verified_id
128
- batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
129
- batch.extend_num_tokens = sum(batch.extend_lens)
130
- batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
131
- batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
132
- batch.return_logprob = False
133
- batch.return_hidden_states = False
134
-
135
- self.capture_hidden_mode = CaptureHiddenMode.LAST
136
- self.accept_length.add_(1)
137
- self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
138
- self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
139
-
140
- create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
141
- batch.input_ids,
142
- batch.seq_lens,
143
- self.accept_length,
144
- self.positions,
145
- self.verified_id,
146
- next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
147
- )
148
-
149
- def generate_attn_arg_prefill(
150
- self,
151
- req_pool_indices: torch.Tensor,
152
- paged_kernel_lens: torch.Tensor,
153
- paged_kernel_lens_sum: int,
154
- req_to_token: torch.Tensor,
155
- ):
156
- bs = self.accept_length.numel()
157
- qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
158
- qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
159
- cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
160
- cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
161
-
162
- if paged_kernel_lens_sum is None:
163
- paged_kernel_lens_sum = cum_kv_seq_len[-1]
164
-
165
- kv_indices = torch.empty(
166
- paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
167
- )
168
-
169
- create_flashinfer_kv_indices_triton[(bs,)](
170
- req_to_token,
171
- req_pool_indices,
172
- paged_kernel_lens,
173
- cum_kv_seq_len,
174
- None,
175
- kv_indices,
176
- req_to_token.size(1),
177
- )
178
- return kv_indices, cum_kv_seq_len, qo_indptr, None
179
-
180
- def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
181
- if has_been_filtered:
182
- # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
183
- # therefore, we don't need to filter the batch again in scheduler
184
- if len(new_indices) != len(self.topk_p):
185
- logger.warning(
186
- f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
187
- )
188
- self.topk_p = self.topk_p[: len(new_indices)]
189
- self.topk_index = self.topk_index[: len(new_indices)]
190
- self.hidden_states = self.hidden_states[: len(new_indices)]
191
- self.verified_id = self.verified_id[: len(new_indices)]
192
- else:
193
- # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
194
- self.topk_p = self.topk_p[new_indices]
195
- self.topk_index = self.topk_index[new_indices]
196
- self.hidden_states = self.hidden_states[new_indices]
197
- self.verified_id = self.verified_id[new_indices]
198
-
199
- def merge_batch(self, spec_info: EagleDraftInput):
200
- if self.hidden_states is None:
201
- self.hidden_states = spec_info.hidden_states
202
- self.verified_id = spec_info.verified_id
203
- self.topk_p = spec_info.topk_p
204
- self.topk_index = spec_info.topk_index
205
- return
206
- if spec_info.hidden_states is None:
207
- return
208
- self.hidden_states = torch.cat(
209
- [self.hidden_states, spec_info.hidden_states], axis=0
210
- )
211
- self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
212
- self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
213
- self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
214
-
215
-
216
- @dataclass
217
- class EagleVerifyOutput:
218
- # Draft input batch
219
- draft_input: EagleDraftInput
220
- # Logit outputs from target worker
221
- logits_output: LogitsProcessorOutput
222
- # Accepted token ids including the bonus token
223
- verified_id: torch.Tensor
224
- # Accepted token length per sequence in a batch in CPU.
225
- accept_length_per_req_cpu: List[int]
226
- # Accepted indices from logits_output.next_token_logits
227
- accepted_indices: torch.Tensor
228
-
229
-
230
- @dataclass
231
- class EagleVerifyInput:
232
- draft_token: torch.Tensor
233
- custom_mask: torch.Tensor
234
- positions: torch.Tensor
235
- retrive_index: torch.Tensor
236
- retrive_next_token: torch.Tensor
237
- retrive_next_sibling: torch.Tensor
238
- retrive_cum_len: torch.Tensor
239
- spec_steps: int
240
- topk: int
241
- draft_token_num: int
242
- capture_hidden_mode: CaptureHiddenMode
243
- seq_lens_sum: int
244
- seq_lens_cpu: torch.Tensor
245
- grammar: BaseGrammarObject = None
246
-
247
- @classmethod
248
- def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
249
- return cls(
250
- draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
251
- custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
252
- positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
253
- retrive_index=torch.full(
254
- (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
255
- ),
256
- retrive_next_token=torch.full(
257
- (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
258
- ),
259
- retrive_next_sibling=torch.full(
260
- (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
261
- ),
262
- retrive_cum_len=None,
263
- topk=topk,
264
- draft_token_num=num_verify_tokens,
265
- spec_steps=spec_steps,
266
- capture_hidden_mode=CaptureHiddenMode.FULL,
267
- seq_lens_sum=0,
268
- seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
269
- )
270
-
271
- def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
272
-
273
- if batch.forward_mode.is_idle():
274
- return
275
-
276
- batch.input_ids = self.draft_token
277
-
278
- if page_size == 1:
279
- batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
280
- end_offset = batch.seq_lens + self.draft_token_num
281
- else:
282
- prefix_lens = batch.seq_lens
283
- end_offset = prefix_lens + self.draft_token_num
284
- last_loc = get_last_loc(
285
- batch.req_to_token_pool.req_to_token,
286
- batch.req_pool_indices,
287
- prefix_lens,
288
- )
289
- batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
290
- prefix_lens, end_offset, last_loc, len(batch.input_ids)
291
- )
292
- self.last_loc = last_loc
293
-
294
- bs = batch.batch_size()
295
- assign_req_to_token_pool[(bs,)](
296
- batch.req_pool_indices,
297
- batch.req_to_token_pool.req_to_token,
298
- batch.seq_lens,
299
- end_offset,
300
- batch.out_cache_loc,
301
- batch.req_to_token_pool.req_to_token.shape[1],
302
- next_power_of_2(bs),
303
- )
304
-
305
- def generate_attn_arg_prefill(
306
- self,
307
- req_pool_indices: torch.Tensor,
308
- paged_kernel_lens: torch.Tensor,
309
- paged_kernel_lens_sum: int,
310
- req_to_token: torch.Tensor,
311
- ):
312
- batch_size = len(req_pool_indices)
313
- qo_indptr = torch.arange(
314
- 0,
315
- (1 + batch_size) * self.draft_token_num,
316
- step=self.draft_token_num,
317
- dtype=torch.int32,
318
- device="cuda",
319
- )
320
- cum_kv_seq_len = torch.zeros(
321
- (batch_size + 1,), dtype=torch.int32, device="cuda"
322
- )
323
-
324
- paged_kernel_lens = paged_kernel_lens + self.draft_token_num
325
- cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
326
-
327
- kv_indices = torch.empty(
328
- paged_kernel_lens_sum + self.draft_token_num * batch_size,
329
- dtype=torch.int32,
330
- device="cuda",
331
- )
332
- create_flashinfer_kv_indices_triton[(batch_size,)](
333
- req_to_token,
334
- req_pool_indices,
335
- paged_kernel_lens,
336
- cum_kv_seq_len,
337
- None,
338
- kv_indices,
339
- req_to_token.size(1),
340
- )
341
- return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
342
-
343
- def verify(
344
- self,
345
- batch: ScheduleBatch,
346
- logits_output: LogitsProcessorOutput,
347
- token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
348
- page_size: int,
349
- vocab_mask: Optional[torch.Tensor] = None, # For grammar
350
- ) -> torch.Tensor:
351
- """
352
- Verify and find accepted tokens based on logits output and batch
353
- (which contains spec decoding information).
354
-
355
- WARNING: This API in-place modifies the states of logits_output
356
-
357
- This API updates values inside logits_output based on the accepted
358
- tokens. I.e., logits_output.next_token_logits only contains
359
- accepted token logits.
360
- """
361
- if batch.forward_mode.is_idle():
362
- return EagleVerifyOutput(
363
- draft_input=EagleDraftInput.create_idle_input(
364
- device=batch.device,
365
- hidden_size=batch.model_config.hidden_size,
366
- dtype=batch.model_config.dtype,
367
- topk=self.topk,
368
- capture_hidden_mode=CaptureHiddenMode.LAST,
369
- ),
370
- logits_output=logits_output,
371
- verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
372
- accept_length_per_req_cpu=[],
373
- accepted_indices=torch.full(
374
- (0, self.spec_steps + 1),
375
- -1,
376
- dtype=torch.int32,
377
- device=batch.device,
378
- ),
379
- )
380
-
381
- bs = self.retrive_index.shape[0]
382
- candidates = self.draft_token.reshape(bs, self.draft_token_num)
383
- sampling_info = batch.sampling_info
384
-
385
- predict_shape = list(logits_output.next_token_logits.shape)[:-1]
386
- predict_shape[-1] += 1
387
- predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
388
- accept_index = torch.full(
389
- (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
390
- )
391
- accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
392
-
393
- if bs != len(sampling_info):
394
- sampling_info = copy.deepcopy(sampling_info)
395
- # NOTE: retrive_index are the indices of the requests that are kept.
396
- sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
397
-
398
- # Apply the custom logit processors if registered in the sampling info.
399
- if sampling_info.has_custom_logit_processor:
400
- apply_custom_logit_processor(
401
- logits_output.next_token_logits,
402
- sampling_info,
403
- num_tokens_in_batch=self.draft_token_num,
404
- )
405
-
406
- # Apply penalty
407
- if sampling_info.penalizer_orchestrator.is_required:
408
- # This is a relaxed version of penalties for speculative decoding.
409
- linear_penalty = torch.zeros(
410
- (bs, logits_output.next_token_logits.shape[1]),
411
- dtype=torch.float32,
412
- device="cuda",
413
- )
414
- sampling_info.apply_logits_bias(linear_penalty)
415
- logits_output.next_token_logits.add_(
416
- torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
417
- )
418
-
419
- # Apply grammar mask
420
- if vocab_mask is not None:
421
- assert self.grammar is not None
422
- self.grammar.apply_vocab_mask(
423
- logits=logits_output.next_token_logits, vocab_mask=vocab_mask
424
- )
425
-
426
- # Sample tokens. Force greedy sampling on AMD
427
- is_all_greedy = sampling_info.is_all_greedy
428
- if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
429
- logger.warning(
430
- "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
431
- "Falling back to greedy verification."
432
- )
433
-
434
- if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
435
- target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
436
- target_predict = target_predict.reshape(bs, self.draft_token_num)
437
-
438
- verify_tree_greedy(
439
- predicts=predict, # mutable
440
- accept_index=accept_index, # mutable
441
- accept_token_num=accept_length, # mutable
442
- candidates=candidates,
443
- retrive_index=self.retrive_index,
444
- retrive_next_token=self.retrive_next_token,
445
- retrive_next_sibling=self.retrive_next_sibling,
446
- target_predict=target_predict,
447
- )
448
- else:
449
- # apply temperature and get target probs
450
- expanded_temperature = torch.repeat_interleave(
451
- sampling_info.temperatures, self.draft_token_num, dim=0
452
- ) # (bs * draft_token_num, 1)
453
-
454
- target_probs = F.softmax(
455
- logits_output.next_token_logits / expanded_temperature, dim=-1
456
- ) # (bs * draft_token_num, vocab_size)
457
- target_probs = top_k_renorm_prob(
458
- target_probs,
459
- torch.repeat_interleave(
460
- sampling_info.top_ks, self.draft_token_num, dim=0
461
- ),
462
- ) # (bs * draft_token_num, vocab_size)
463
- if not torch.all(sampling_info.top_ps == 1.0):
464
- target_probs = top_p_renorm_prob(
465
- target_probs,
466
- torch.repeat_interleave(
467
- sampling_info.top_ps, self.draft_token_num, dim=0
468
- ),
469
- )
470
- target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
471
-
472
- draft_probs = torch.zeros(
473
- target_probs.shape, dtype=torch.float32, device="cuda"
474
- )
475
-
476
- # coins for rejection sampling
477
- coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
478
- # coins for final sampling
479
- coins_for_final_sampling = torch.rand(
480
- (bs,), dtype=torch.float32, device="cuda"
481
- )
482
- tree_speculative_sampling_target_only(
483
- predicts=predict, # mutable
484
- accept_index=accept_index, # mutable
485
- accept_token_num=accept_length, # mutable
486
- candidates=candidates,
487
- retrive_index=self.retrive_index,
488
- retrive_next_token=self.retrive_next_token,
489
- retrive_next_sibling=self.retrive_next_sibling,
490
- uniform_samples=coins,
491
- uniform_samples_for_final_sampling=coins_for_final_sampling,
492
- target_probs=target_probs,
493
- draft_probs=draft_probs,
494
- threshold_single=global_server_args_dict[
495
- "speculative_accept_threshold_single"
496
- ],
497
- threshold_acc=global_server_args_dict[
498
- "speculative_accept_threshold_acc"
499
- ],
500
- deterministic=True,
501
- )
502
-
503
- if SIMULATE_ACC_LEN:
504
- # Do simulation
505
- accept_index = _generate_simulated_accept_index(
506
- accept_index=accept_index,
507
- predict=predict, # mutable
508
- accept_length=accept_length, # mutable
509
- simulate_acc_len=SIMULATE_ACC_LEN,
510
- bs=bs,
511
- spec_steps=self.spec_steps,
512
- )
513
-
514
- unfinished_index = []
515
- unfinished_accept_index = []
516
- accept_index_cpu = accept_index.tolist()
517
- predict_cpu = predict.tolist()
518
- has_finished = False
519
-
520
- # Iterate every accepted token and check if req has finished after append the token
521
- # should be checked BEFORE free kv cache slots
522
- for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
523
- for j, idx in enumerate(accept_index_row):
524
- if idx == -1:
525
- break
526
- id = predict_cpu[idx]
527
- req.output_ids.append(id)
528
- req.check_finished()
529
- if req.finished():
530
- has_finished = True
531
- # set all tokens after finished token to -1 and break
532
- accept_index[i, j + 1 :] = -1
533
- break
534
- else:
535
- if req.grammar is not None:
536
- try:
537
- req.grammar.accept_token(id)
538
- except ValueError as e:
539
- logger.info(
540
- f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
541
- )
542
- raise e
543
- if not req.finished():
544
- unfinished_index.append(i)
545
- if idx == -1:
546
- unfinished_accept_index.append(accept_index[i, :j])
547
- else:
548
- unfinished_accept_index.append(accept_index[i])
549
- req.spec_verify_ct += 1
550
-
551
- if has_finished:
552
- accept_length = (accept_index != -1).sum(dim=1) - 1
553
-
554
- # Free the KV cache for unaccepted tokens
555
- # TODO: fuse them
556
- accept_index = accept_index[accept_index != -1]
557
- verified_id = predict[accept_index]
558
- evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
559
- evict_mask[accept_index] = False
560
-
561
- if page_size == 1:
562
- # TODO: boolean array index leads to a device sync. Remove it.
563
- token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
564
- else:
565
- if self.topk == 1:
566
- # Only evict full empty page. Do not evict partial empty page
567
- align_evict_mask_to_page_size[len(batch.seq_lens),](
568
- batch.seq_lens,
569
- evict_mask,
570
- page_size,
571
- self.draft_token_num,
572
- next_power_of_2(self.draft_token_num),
573
- )
574
- token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
575
- else:
576
- # Shift the accepted tokens to the beginning.
577
- # Only evict the last part
578
- src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
579
- batch.seq_lens,
580
- batch.out_cache_loc,
581
- accept_index,
582
- accept_length,
583
- self.draft_token_num,
584
- page_size,
585
- )
586
- to_free_slots = torch.empty(
587
- (to_free_num_slots.sum().item(),),
588
- dtype=torch.int64,
589
- device=to_free_num_slots.device,
590
- )
591
-
592
- # out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
593
- # accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
594
- # tgt_cache_loc: [0 1 , 3 4 , 6 ]
595
- # to_free_slots: [ 2, 5, 7 8]
596
- # to_free_slots also needs to be page-aligned without the first partial page
597
- #
598
- # split each row of out_cache_loc into two parts.
599
- # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
600
- # 2. the second part goes to to_free_slots.
601
- get_target_cache_loc[(bs,)](
602
- tgt_cache_loc,
603
- to_free_slots,
604
- accept_length,
605
- to_free_num_slots,
606
- batch.out_cache_loc,
607
- self.draft_token_num,
608
- next_power_of_2(self.draft_token_num),
609
- next_power_of_2(bs),
610
- )
611
-
612
- # Free the kv cache
613
- token_to_kv_pool_allocator.free(to_free_slots)
614
-
615
- # Copy the kv cache
616
- batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
617
- tgt_cache_loc, src_cache_loc
618
- )
619
-
620
- # Construct EagleVerifyOutput
621
- if not has_finished:
622
- if page_size == 1 or self.topk == 1:
623
- batch.out_cache_loc = batch.out_cache_loc[accept_index]
624
- assign_req_to_token_pool[(bs,)](
625
- batch.req_pool_indices,
626
- batch.req_to_token_pool.req_to_token,
627
- batch.seq_lens,
628
- batch.seq_lens + accept_length + 1,
629
- batch.out_cache_loc,
630
- batch.req_to_token_pool.req_to_token.shape[1],
631
- next_power_of_2(bs),
632
- )
633
- else:
634
- batch.out_cache_loc = tgt_cache_loc
635
- batch.seq_lens.add_(accept_length + 1)
636
-
637
- draft_input = EagleDraftInput(
638
- hidden_states=batch.spec_info.hidden_states[accept_index],
639
- verified_id=verified_id,
640
- accept_length=accept_length,
641
- accept_length_cpu=accept_length.tolist(),
642
- seq_lens_for_draft_extend=batch.seq_lens,
643
- req_pool_indices_for_draft_extend=batch.req_pool_indices,
644
- )
645
-
646
- return EagleVerifyOutput(
647
- draft_input=draft_input,
648
- logits_output=logits_output,
649
- verified_id=verified_id,
650
- accept_length_per_req_cpu=draft_input.accept_length_cpu,
651
- accepted_indices=accept_index,
652
- )
653
- else:
654
- if page_size == 1 or self.topk == 1:
655
- assign_req_to_token_pool[(bs,)](
656
- batch.req_pool_indices,
657
- batch.req_to_token_pool.req_to_token,
658
- batch.seq_lens,
659
- batch.seq_lens + accept_length + 1,
660
- batch.out_cache_loc[accept_index],
661
- batch.req_to_token_pool.req_to_token.shape[1],
662
- next_power_of_2(bs),
663
- )
664
- batch.seq_lens.add_(accept_length + 1)
665
-
666
- accept_length_cpu = accept_length.tolist()
667
- if len(unfinished_accept_index) > 0:
668
- unfinished_accept_index = torch.cat(unfinished_accept_index)
669
- unfinished_index_device = torch.tensor(
670
- unfinished_index, dtype=torch.int64, device=predict.device
671
- )
672
- draft_input_accept_length_cpu = [
673
- accept_length_cpu[i] for i in unfinished_index
674
- ]
675
- if page_size == 1 or self.topk == 1:
676
- batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
677
- else:
678
- batch.out_cache_loc = torch.empty(
679
- len(unfinished_index) + sum(draft_input_accept_length_cpu),
680
- dtype=torch.int64,
681
- device=predict.device,
682
- )
683
- accept_length_filter = create_accept_length_filter(
684
- accept_length,
685
- unfinished_index_device,
686
- batch.seq_lens,
687
- )
688
- filter_finished_cache_loc_kernel[(bs,)](
689
- batch.out_cache_loc,
690
- tgt_cache_loc,
691
- accept_length,
692
- accept_length_filter,
693
- next_power_of_2(bs),
694
- next_power_of_2(self.draft_token_num),
695
- )
696
-
697
- draft_input = EagleDraftInput(
698
- hidden_states=batch.spec_info.hidden_states[
699
- unfinished_accept_index
700
- ],
701
- verified_id=predict[unfinished_accept_index],
702
- accept_length_cpu=draft_input_accept_length_cpu,
703
- accept_length=accept_length[unfinished_index_device],
704
- seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
705
- req_pool_indices_for_draft_extend=batch.req_pool_indices[
706
- unfinished_index_device
707
- ],
708
- )
709
- else:
710
- draft_input = EagleDraftInput.create_idle_input(
711
- device=batch.device,
712
- hidden_size=batch.model_config.hidden_size,
713
- dtype=batch.model_config.dtype,
714
- topk=self.topk,
715
- capture_hidden_mode=CaptureHiddenMode.LAST,
716
- )
717
-
718
- return EagleVerifyOutput(
719
- draft_input=draft_input,
720
- logits_output=logits_output,
721
- verified_id=verified_id,
722
- accept_length_per_req_cpu=accept_length_cpu,
723
- accepted_indices=accept_index,
724
- )
725
-
726
-
727
- @triton.jit
728
- def create_extend_after_decode_spec_info(
729
- verified_id,
730
- seq_lens,
731
- accept_lens,
732
- positions,
733
- new_verified_id,
734
- bs_upper: tl.constexpr,
735
- ):
736
- pid = tl.program_id(axis=0)
737
- offsets = tl.arange(0, bs_upper)
738
- seq_length = tl.load(seq_lens + pid)
739
- accept_length = tl.load(accept_lens + pid)
740
-
741
- accept_len_cumsum = tl.sum(
742
- tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
11
+ build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
743
12
  )
744
- positions_ptr = positions + accept_len_cumsum
745
- mask = offsets < accept_length
746
- tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
747
-
748
- accept_len_cumsum += accept_length - 1
749
- verified_id_data = tl.load(verified_id + accept_len_cumsum)
750
- tl.store(new_verified_id + pid, verified_id_data)
751
13
 
752
14
 
753
- @triton.jit
754
- def assign_req_to_token_pool(
755
- req_pool_indices,
756
- req_to_token,
757
- start_offset,
758
- end_offset,
759
- out_cache_loc,
760
- pool_len: tl.constexpr,
761
- bs_upper: tl.constexpr,
15
+ def organize_draft_results(
+     score_list: List[torch.Tensor],
+     token_list: List[torch.Tensor],
+     parents_list: List[torch.Tensor],
+     num_draft_token: int,
  ):
-     BLOCK_SIZE: tl.constexpr = 32
-     pid = tl.program_id(axis=0)
-     kv_start = tl.load(start_offset + pid)
-     kv_end = tl.load(end_offset + pid)
-     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-
-     length_offset = tl.arange(0, bs_upper)
-     start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
-     end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
-     out_offset = tl.sum(end - start, axis=0)
-
-     out_cache_ptr = out_cache_loc + out_offset
-
-     save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
-     load_offset = tl.arange(0, BLOCK_SIZE)
-
-     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-     for _ in range(num_loop):
-         mask = save_offset < kv_end
-         data = tl.load(out_cache_ptr + load_offset, mask=mask)
-         tl.store(token_pool + save_offset, data, mask=mask)
-         save_offset += BLOCK_SIZE
-         load_offset += BLOCK_SIZE
-
-
- @triton.jit
- def assign_draft_cache_locs(
-     req_pool_indices,
-     req_to_token,
-     seq_lens,
-     extend_lens,
-     num_new_pages_per_topk,
-     out_cache_loc,
-     pool_len: tl.constexpr,
-     topk: tl.constexpr,
-     speculative_num_steps: tl.constexpr,
-     page_size: tl.constexpr,
-     bs_upper: tl.constexpr,
-     iter_upper: tl.constexpr,
- ):
-     BLOCK_SIZE: tl.constexpr = 128
-     pid = tl.program_id(axis=0)
-
-     if page_size == 1 or topk == 1:
-         copy_len = topk * speculative_num_steps
-         out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+     score_list = torch.cat(score_list, dim=1).flatten(1)
+     ss_token_list = torch.cat(token_list, dim=1)
+     top_scores = torch.topk(score_list, num_draft_token - 1, dim=-1)
+     top_scores_index = top_scores.indices
+     top_scores_index = torch.sort(top_scores_index).values
+     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
+
+     if len(parents_list) > 1:
+         parent_list = torch.cat(parents_list[:-1], dim=1)
      else:
-         bs_offset = tl.arange(0, bs_upper)
-         copy_len = tl.load(extend_lens + pid)
-         cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
-         out_cache_ptr = out_cache_loc + cum_copy_len
-
-     # Part 1: Copy from out_cache_loc to req_to_token
-     kv_start = tl.load(seq_lens + pid)
-     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-     num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
-     for i in range(num_loop):
-         copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-         mask = copy_offset < copy_len
-         data = tl.load(out_cache_ptr + copy_offset, mask=mask)
-         tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
-
-     if page_size == 1 or topk == 1:
-         return
+         batch_size = parents_list[0].shape[0]
+         parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)

-     # Part 2: Copy the indices for the last partial page
-     prefix_len = tl.load(seq_lens + pid)
-     last_page_len = prefix_len % page_size
-     offsets = tl.arange(0, page_size)
-     mask = offsets < last_page_len
-     num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
-     prefix_base = token_pool + prefix_len - last_page_len
+     return parent_list, top_scores_index, draft_tokens
 
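For orientation, a toy run of the selection that organize_draft_results performs (assumed shapes: batch 1, topk 2, two draft steps, num_draft_token 4; illustrative only, the real score tensors are cumulative tree scores):

import torch

score_list = [torch.tensor([[0.6, 0.4]]), torch.tensor([[0.3, 0.2]])]  # per-step scores
token_list = [torch.tensor([[11, 12]]), torch.tensor([[13, 14]])]      # per-step draft ids

scores = torch.cat(score_list, dim=1).flatten(1)  # (1, 4): all candidates in tree order
tokens = torch.cat(token_list, dim=1)
top = torch.topk(scores, 3, dim=-1)               # keep num_draft_token - 1 = 3 drafts
index = torch.sort(top.indices).values            # restore tree order for the kernel
print(torch.gather(tokens, index=index, dim=1))   # tensor([[11, 12, 13]])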
-     for topk_id in range(topk):
-         value = tl.load(prefix_base + offsets, mask=mask)
-         tl.store(
-             prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
-             value,
-             mask=mask,
-         )

-     # Part 3: Remove the padding in out_cache_loc
-     iter_offest = tl.arange(0, iter_upper)
-     for topk_id in range(topk):
-         indices = tl.load(
-             prefix_base
-             + topk_id * num_new_pages_per_topk_ * page_size
-             + last_page_len
-             + iter_offest,
-             mask=iter_offest < speculative_num_steps,
-         )
-         tl.store(
-             out_cache_loc
-             + pid * topk * speculative_num_steps
-             + topk_id * speculative_num_steps
-             + iter_offest,
-             indices,
-             mask=iter_offest < speculative_num_steps,
-         )
+ class TreeMaskMode(IntEnum):
+     FULL_MASK = 0
+     QLEN_ONLY = 1
+     QLEN_ONLY_BITPACKING = 2

 
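The QLEN_ONLY_BITPACKING mode defined above stores one packed integer word per draft token. A minimal sketch (a hypothetical helper, mirroring the dtype choice made in build_tree_kernel_efficient below) of how the word width follows from num_verify_tokens:

import math
import torch

def packed_mask_dtype(num_verify_tokens: int) -> torch.dtype:
    packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
    num_bytes = (num_verify_tokens + 7) // 8  # bytes needed for one row of bits
    return packed_dtypes[int(math.ceil(math.log2(num_bytes)))]

assert packed_mask_dtype(8) == torch.uint8    # up to 8 verify tokens fit in 1 byte
assert packed_mask_dtype(16) == torch.uint16  # 9-16 tokens need 2 bytes
assert packed_mask_dtype(32) == torch.uint32  # 17-32 tokens need 4 bytes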
- @triton.jit
- def generate_draft_decode_kv_indices(
-     req_pool_indices,
-     req_to_token,
-     paged_kernel_lens,
-     kv_indices,
-     kv_indptr,
-     positions,
-     pool_len: tl.constexpr,
-     kv_indices_stride: tl.constexpr,
-     kv_indptr_stride: tl.constexpr,
-     bs_upper: tl.constexpr,
-     iter_upper: tl.constexpr,
-     num_tokens_upper: tl.constexpr,
-     page_size: tl.constexpr,
- ):
-     BLOCK_SIZE: tl.constexpr = 128
-     iters = tl.program_id(axis=0)
-     bid = tl.program_id(axis=1)
-     topk_id = tl.program_id(axis=2)
-
-     num_steps = tl.num_programs(axis=0)
-     num_seqs = tl.num_programs(axis=1)
-     topk = tl.num_programs(axis=2)
-
-     kv_indices += kv_indices_stride * iters
-     kv_indptr += kv_indptr_stride * iters
-     iters += 1
-
-     load_offset = tl.arange(0, bs_upper)
-     seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
-     seq_len = tl.load(paged_kernel_lens + bid)
-     cum_seq_len = tl.sum(seq_lens)
-
-     # Update kv_indices
-     kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
-     kv_ptr = kv_indices + kv_offset
-     token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
-
-     kv_offset = tl.arange(0, BLOCK_SIZE)
-     num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
-     for _ in range(num_loop):
-         mask = kv_offset < seq_len
-         data = tl.load(token_pool_ptr + kv_offset, mask=mask)
-         tl.store(kv_ptr + kv_offset, data, mask=mask)
-         kv_offset += BLOCK_SIZE
-
-     extend_offset = tl.arange(0, iter_upper)
-     if page_size == 1 or topk == 1:
-         extend_data = tl.load(
-             token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
-             mask=extend_offset < iters,
-         )
-     else:
-         prefix_len = seq_len
-         last_page_len = prefix_len % page_size
-         num_new_pages_per_topk = (
-             last_page_len + num_steps + page_size - 1
-         ) // page_size
-         prefix_base = seq_len // page_size * page_size
-         start = (
-             prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
-         )
-         extend_data = tl.load(
-             token_pool_ptr + start + extend_offset,
-             mask=extend_offset < iters,
-         )
-
-     tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
-
-     # Update kv_indptr
-     bs_offset = tl.arange(0, num_tokens_upper)
-
-     zid = bid * topk + topk_id
-     if zid == 0:
-         zid = num_seqs * topk
-     positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
-     base = tl.sum(positions)
-     tl.store(kv_indptr + zid, base + zid * iters)
-
-
- @triton.jit
- def align_evict_mask_to_page_size(
-     seq_lens,
-     evict_mask,
-     page_size: tl.constexpr,
-     num_draft_tokens: tl.constexpr,
-     BLOCK_SIZE: tl.constexpr,
- ):
-     t_range = tl.arange(0, BLOCK_SIZE)
-
-     bid = tl.program_id(axis=0)
-     seq_len = tl.load(seq_lens + bid)
-     io_mask = t_range < num_draft_tokens
-     mask_row = tl.load(
-         evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
-     )
-
-     num_trues = tl.sum(mask_row)
-     num_false = num_draft_tokens - num_trues
-
-     start = (seq_len + num_false - 1) // page_size * page_size - seq_len
-     for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
-         tl.store(evict_mask + bid * num_draft_tokens + i, False)
-
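Worked numbers (hypothetical) for the alignment rule in align_evict_mask_to_page_size: with seq_len=10, page_size=4, and 6 draft tokens of which num_false=3 survive verification, the new length is 13, so the page [12, 16) is partially filled and the draft slots that map into it must not be evicted:

seq_len, page_size, num_draft_tokens = 10, 4, 6
num_false = 3  # draft tokens kept after verification

start = (seq_len + num_false - 1) // page_size * page_size - seq_len  # 2
for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
    print(f"draft slot {i} stays allocated")  # slots 2, 3, 4, 5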
- @triton.jit
- def get_target_cache_loc(
-     tgt_cache_loc,
-     to_free_slots,
-     accept_length,
-     to_free_num_slots,
-     out_cache_loc,
-     num_verify_tokens: tl.constexpr,
-     num_verify_tokens_upper: tl.constexpr,
-     bs_upper: tl.constexpr,
- ):
-     bid = tl.program_id(axis=0)
-     offset = tl.arange(0, num_verify_tokens_upper)
-     bs_offset = tl.arange(0, bs_upper)
-
-     # write the first part to tgt_cache_loc
-     accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
-     tgt_cache_loc_start = tl.sum(accept_len_all) + bid
-     copy_len = tl.load(accept_length + bid) + 1
-     out_cache_loc_row = tl.load(
-         out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
-     )
-     tl.store(
-         tgt_cache_loc + tgt_cache_loc_start + offset,
-         out_cache_loc_row,
-         mask=offset < copy_len,
-     )
-
-     # write the second part to to_free_slots
-     to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
-     to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
-     out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
-     to_free_slots_start = tl.sum(to_free_num_slots_all)
-
-     copy_len = to_free_num_slots_cur
-     out_cache_loc_row = tl.load(
-         out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
-         mask=offset < copy_len,
-     )
-     tl.store(
-         to_free_slots + to_free_slots_start + offset,
-         out_cache_loc_row,
-         mask=offset < copy_len,
-     )
-
-
- @torch.compile(dynamic=True)
- def get_src_tgt_cache_loc(
-     seq_lens: torch.Tensor,
-     out_cache_loc: torch.Tensor,
-     accept_index: torch.Tensor,
-     accept_length: torch.Tensor,
-     draft_token_num: int,
-     page_size: int,
- ):
-     src_cache_loc = out_cache_loc[accept_index]
-     tgt_cache_loc = torch.empty_like(src_cache_loc)
-     extended_len = seq_lens + draft_token_num
-     keep_len = torch.minimum(
-         (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
-         extended_len,
-     )
-     to_free_num_slots = extended_len - keep_len
-     return src_cache_loc, tgt_cache_loc, to_free_num_slots
-
-
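A worked instance (hypothetical numbers) of the keep_len arithmetic in get_src_tgt_cache_loc above: the accepted prefix is rounded up to a page boundary, capped at the fully extended length, and everything beyond it is freed:

import torch

seq_lens = torch.tensor([100])     # tokens already in the pool
accept_length = torch.tensor([1])  # draft tokens accepted this round
draft_token_num, page_size = 8, 4

extended_len = seq_lens + draft_token_num  # 108 slots currently claimed
keep_len = torch.minimum(
    (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
    extended_len,
)  # 102 rounded up to 104, below the 108 cap
print(extended_len - keep_len)  # tensor([4]): four slots go back to the pool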
- @triton.jit
- def filter_finished_cache_loc_kernel(
-     out_cache_loc,
-     tgt_cache_loc,
-     accept_length,
-     accept_length_filter,
-     bs_upper: tl.constexpr,
-     num_verify_tokens_upper: tl.constexpr,
- ):
-     bid = tl.program_id(0)
-     bs_offset = tl.arange(0, bs_upper)
-
-     accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
-     old_start = tl.sum(accept_length_all) + bid
-
-     accept_length_filter_all = tl.load(
-         accept_length_filter + bs_offset, mask=bs_offset < bid
-     )
-     new_start = tl.sum(accept_length_filter_all)
-
-     copy_len = tl.load(accept_length_filter + bid)
-     copy_offset = tl.arange(0, num_verify_tokens_upper)
-     value = tl.load(
-         tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
-     )
-     tl.store(
-         out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
-     )
-
-
- @torch.compile(dynamic=True)
- def create_accept_length_filter(
-     accept_length: torch.Tensor,
-     unfinished_index_device: torch.Tensor,
+ def build_tree_kernel_efficient(
+     verified_id: torch.Tensor,
+     parent_list: List[torch.Tensor],
+     top_scores_index: torch.Tensor,
+     draft_tokens: torch.Tensor,
      seq_lens: torch.Tensor,
- ):
-     accept_length_filter = torch.zeros_like(accept_length)
-     accept_length_filter[unfinished_index_device] = (
-         accept_length[unfinished_index_device] + 1
-     )
-     seq_lens.add_(accept_length + 1)
-     return accept_length_filter
-
-
- @torch.compile(dynamic=True)
- def select_top_k_tokens(
-     i: int,
-     topk_p: torch.Tensor,
-     topk_index: torch.Tensor,
-     hidden_states: torch.Tensor,
-     scores: torch.Tensor,
+     seq_lens_sum: int,
      topk: int,
+     spec_steps: int,
+     num_verify_tokens: int,
+     tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
+     tree_mask_buf: Optional[torch.Tensor] = None,
+     position_buf: Optional[torch.Tensor] = None,
  ):
-     if i == 0:
-         # The first step after extend
-         input_ids = topk_index.flatten()
-         hidden_states = hidden_states.repeat_interleave(topk, dim=0)
-         scores = topk_p  # shape: (b, topk)
-
-         tree_info = (
-             topk_p.unsqueeze(1),  # shape: (b, 1, topk)
-             topk_index,  # shape: (b, topk)
-             torch.arange(-1, topk, dtype=torch.long, device="cuda")
-             .unsqueeze(0)
-             .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
-         )
-     else:
-         # The later decode steps
-         expand_scores = torch.mul(
-             scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
-         )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
-         topk_cs_p, topk_cs_index = fast_topk(
-             expand_scores.flatten(start_dim=1), topk, dim=-1
-         )  # (b, topk)
-         scores = topk_cs_p  # shape: (b, topk)
-
-         topk_index = topk_index.reshape(-1, topk**2)
-         input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
-
-         if hidden_states.shape[0] > 0:
-             selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
-                 0, hidden_states.shape[0], step=topk, device="cuda"
-             ).repeat_interleave(topk)
-             hidden_states = hidden_states[selected_input_index, :]
-
-         tree_info = (
-             expand_scores,  # shape: (b, topk, topk)
-             topk_index,  # shape: (b, topk * topk)
-             topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
-         )
-
-     return input_ids, hidden_states, scores, tree_info
-
-
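A toy run of the i > 0 branch of select_top_k_tokens above (torch.topk stands in for fast_topk; batch 1 and topk 2 are assumed shapes): cumulative path scores expand to (b, topk, topk) joint scores, and the best topk paths are re-selected across all parents:

import torch

scores = torch.tensor([[0.6, 0.4]])                # cumulative score per kept path
topk_p = torch.tensor([[[0.7, 0.3], [0.9, 0.1]]])  # child probs per parent

expand_scores = scores.unsqueeze(2) * topk_p       # joint scores, shape (b, topk, topk)
topk_cs_p, topk_cs_index = torch.topk(expand_scores.flatten(start_dim=1), 2, dim=-1)
print(topk_cs_p)      # tensor([[0.4200, 0.3600]])
print(topk_cs_index)  # tensor([[0, 2]]): child 0 of parent 0, child 0 of parent 1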
- def _generate_simulated_accept_index(
-     accept_index,
-     predict,
-     accept_length,
-     simulate_acc_len,
-     bs,
-     spec_steps,
- ):
-     simulate_acc_len_float = float(simulate_acc_len)
-     if SIMULATE_ACC_METHOD == "multinomial":
-         simulated_values = torch.normal(
-             mean=simulate_acc_len_float,
-             std=1.0,
-             size=(1,),
-             device="cpu",
-         )
-         # clamp the simulated values to be between 1 and spec_steps + 1
-         simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
-         simulate_acc_len = int(simulated_values.round().item())
-     elif SIMULATE_ACC_METHOD == "match-expected":
-         # "multinomial" sampling does not match the expected length; we keep it
-         # for compatibility with existing tests, but "match-expected" is better
-         # for cases that need to match the expected length. One caveat: it only
-         # samples the floor or the ceiling of the expected length.
-         simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
-         lower = int(simulate_acc_len_float // 1)
-         upper = lower + 1 if lower < spec_steps + 1 else lower
-         if lower == upper:
-             simulate_acc_len = lower
+     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
+
+     # seq_lens_sum == sum(seq_lens); seq_lens holds the sequence lengths without draft tokens
+     bs = seq_lens.numel()
+     device = seq_lens.device
+     # e.g. for bs=1, tree_mask has shape (num_draft_token, seq_lens_sum + num_draft_token)
+     # (flattened), where each row is the attention pattern of one draft token;
+     # with QLEN_ONLY_BITPACKING, tree_mask is one packed integer word per draft token (flattened)
+     if tree_mask_buf is not None:
+         tree_mask = tree_mask_buf
+         if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
+             tree_mask.fill_(True)
+         elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
+             tree_mask.fill_(0)
+         elif tree_mask_mode == TreeMaskMode.FULL_MASK:
+             tree_mask.fill_(True)
          else:
-             weight_upper = simulate_acc_len_float - lower
-             weight_lower = 1.0 - weight_upper
-             probs = torch.tensor([weight_lower, weight_upper], device="cpu")
-             sampled_index = torch.multinomial(probs, num_samples=1)
-             simulate_acc_len = lower if sampled_index == 0 else upper
+             raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
+     elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
+         tree_mask = torch.full(
+             (num_verify_tokens * bs * num_verify_tokens,),
+             True,
+             dtype=torch.bool,
+             device=device,
+         )
+     elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
+         packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
+         packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
+         tree_mask = torch.zeros(
+             (num_verify_tokens * bs,),
+             dtype=packed_dtypes[packed_dtype_idx],
+             device=device,
+         )
+     elif tree_mask_mode == TreeMaskMode.FULL_MASK:
+         tree_mask = torch.full(
+             (
+                 seq_lens_sum * num_verify_tokens
+                 + num_verify_tokens * num_verify_tokens * bs,
+             ),
+             True,
+             device=device,
+         )
      else:
-         raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
+         raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
 
-     accept_indx_first_col = accept_index[:, 0].view(-1, 1)
-     sim_accept_index = torch.full(
-         (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+     # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
+     retrive_buf = torch.full(
+         (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
      )
-     sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
-         simulate_acc_len, device=accept_index.device
+     retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf
+     # positions: the position id each draft token occupies in its sequence;
+     # e.g. if the draft-token depths are [0, 1, 1, 2] and the prompt length is 7,
+     # then positions = [7, 8, 8, 9]
+     if position_buf is not None:
+         positions = position_buf
+     else:
+         positions = torch.empty(
+             (bs * num_verify_tokens,), device=device, dtype=torch.long
+         )
+
+     sgl_build_tree_kernel_efficient(
+         parent_list,
+         top_scores_index,
+         seq_lens,
+         tree_mask,
+         positions,
+         retrive_index,
+         retrive_next_token,
+         retrive_next_sibling,
+         topk,
+         spec_steps,
+         num_verify_tokens,
+         tree_mask_mode,
      )
-     accept_length.fill_(simulate_acc_len - 1)
-     predict.fill_(100)  # an arbitrary legal token id
-     return sim_accept_index
-
-
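A small standalone check, using the same floor/ceiling weighting as the "match-expected" branch of _generate_simulated_accept_index above, that the draw matches the expected length:

import torch

expected = 2.4
lower = int(expected)  # 2, drawn with probability 0.6
probs = torch.tensor([1.0 - (expected - lower), expected - lower])
draws = torch.multinomial(probs, num_samples=10_000, replacement=True) + lower
print(draws.float().mean())  # ~2.4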
- def traverse_tree(
-     retrieve_next_token: torch.Tensor,
-     retrieve_next_sibling: torch.Tensor,
-     draft_tokens: torch.Tensor,
-     grammar: BaseGrammarObject,
-     allocate_token_bitmask: torch.Tensor,
- ):
-     """
-     Traverse the tree constructed by the draft model to generate the logits mask.
-     """
-     assert (
-         retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
+     return (
+         tree_mask,
+         positions,
+         retrive_index,
+         retrive_next_token,
+         retrive_next_sibling,
+         draft_tokens,
      )
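A quick size check (hypothetical numbers: two sequences of total length 12, four verify tokens) for the three TreeMaskMode layouts allocated in build_tree_kernel_efficient above:

seq_lens_sum, bs, num_verify_tokens = 12, 2, 4

full_mask = seq_lens_sum * num_verify_tokens + num_verify_tokens * num_verify_tokens * bs
qlen_only = num_verify_tokens * bs * num_verify_tokens
bitpacked = num_verify_tokens * bs  # one packed word per draft token row
print(full_mask, qlen_only, bitpacked)  # 80 32 8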
-
-     allocate_token_bitmask.fill_(0)
-
-     def dfs(
-         curr: int,
-         retrieve_next_token: torch.Tensor,
-         retrieve_next_sibling: torch.Tensor,
-         parent_pos: int,
-     ):
-         if curr == 0:
-             # the first token generated by the target model, and thus it is always
-             # accepted from the previous iteration
-             accepted = True
-         else:
-             parent_bitmask = allocate_token_bitmask[parent_pos]
-             curr_token_id = draft_tokens[curr]
-             # 32 boolean bitmask values are packed into 32-bit integers
-             accepted = (
-                 parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
-             ) != 0
-
-         if accepted:
-             if curr != 0:
-                 # Accept the current token
-                 grammar.accept_token(draft_tokens[curr])
-             if not grammar.is_terminated():
-                 # Generate the bitmask for the current token
-                 grammar.fill_vocab_mask(allocate_token_bitmask, curr)
-                 if retrieve_next_token[curr] != -1:
-                     # Visit the child node
-                     dfs(
-                         retrieve_next_token[curr],
-                         retrieve_next_token,
-                         retrieve_next_sibling,
-                         curr,
-                     )
-
-             if curr != 0:
-                 # Rollback the current token
-                 grammar.rollback(1)
-
-         if retrieve_next_sibling[curr] != -1:
-             # Visit the sibling node
-             dfs(
-                 retrieve_next_sibling[curr],
-                 retrieve_next_token,
-                 retrieve_next_sibling,
-                 parent_pos,
-             )
-
-     dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
-
-
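A minimal sketch of the packed-bitmask membership test that dfs above applies (toy layout: 32 flags per int32 word; the real mask comes from grammar.allocate_vocab_mask):

import torch

bitmask = torch.zeros(4, dtype=torch.int32)  # room for 128 token ids
token_id = 70
bitmask[token_id // 32] |= 1 << (token_id % 32)  # mark token 70 as grammar-legal

accepted = (bitmask[token_id // 32] & (1 << (token_id % 32))) != 0
print(bool(accepted))  # True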
- def generate_token_bitmask(
-     reqs: List[Req],
-     verify_input: EagleVerifyInput,
-     retrieve_next_token_cpu: torch.Tensor,
-     retrieve_next_sibling_cpu: torch.Tensor,
-     draft_tokens_cpu: torch.Tensor,
-     vocab_size: int,
- ):
-     """
-     Generate the logit mask for structured output.
-     A draft model token can be either valid or invalid with respect to the grammar.
-     We perform a DFS to
-     1. figure out which tokens are accepted by the grammar, and
-     2. compute the corresponding logit mask for each accepted token.
-
-     """
-
-     num_draft_tokens = draft_tokens_cpu.shape[-1]
-
-     allocate_token_bitmask = None
-     assert len(reqs) == retrieve_next_token_cpu.shape[0]
-     grammar = None
-     for i, req in enumerate(reqs):
-         if req.grammar is not None:
-             if allocate_token_bitmask is None:
-                 allocate_token_bitmask = req.grammar.allocate_vocab_mask(
-                     vocab_size=vocab_size,
-                     batch_size=draft_tokens_cpu.numel(),
-                     device="cpu",
-                 )
-             grammar = req.grammar
-             s = time.perf_counter()
-             traverse_tree(
-                 retrieve_next_token_cpu[i],
-                 retrieve_next_sibling_cpu[i],
-                 draft_tokens_cpu[i],
-                 req.grammar,
-                 allocate_token_bitmask[
-                     i * num_draft_tokens : (i + 1) * num_draft_tokens
-                 ],
-             )
-             tree_traverse_time = time.perf_counter() - s
-             if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
-                 logger.warning(
-                     f"Bit mask generation took {tree_traverse_time} seconds with "
-                     f"grammar: {req.grammar}"
-                 )
-
-     verify_input.grammar = grammar