sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/{eagle_utils.py → eagle_info.py}
@@ -1,236 +1,52 @@
-from __future__ import annotations
-
-import copy
 import logging
-import os
-import time
+from copy import copy
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
-import triton
-import triton.language as tl
 
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import apply_custom_logit_processor
 from sglang.srt.managers.schedule_batch import (
-    Req,
     ScheduleBatch,
     get_last_loc,
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
+from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
+from sglang.srt.speculative.spec_utils import (
+    SIMULATE_ACC_LEN,
+    TREE_SPEC_KERNEL_AVAILABLE,
+    _generate_simulated_accept_index,
+    align_evict_mask_to_page_size,
+    assign_req_to_token_pool,
+    create_accept_length_filter,
+    create_extend_after_decode_spec_info,
+    filter_finished_cache_loc_kernel,
+    get_src_tgt_cache_loc,
+    get_target_cache_loc,
+)
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
 
-logger = logging.getLogger(__name__)
-
 if is_cuda():
     from sgl_kernel import (
-        fast_topk,
         top_k_renorm_prob,
         top_p_renorm_prob,
         tree_speculative_sampling_target_only,
        verify_tree_greedy,
     )
 elif is_hip():
-    from sgl_kernel import fast_topk, verify_tree_greedy
-
+    from sgl_kernel import verify_tree_greedy
 
 logger = logging.getLogger(__name__)
 
 
-# Simulate acceptance length for benchmarking purposes
-SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
-SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
-
-TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
-
-TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
-
-
-@dataclass
-class EagleDraftInput:
-    # The inputs for decode
-    # shape: (b, topk)
-    topk_p: torch.Tensor = None
-    topk_index: torch.Tensor = None
-    # shape: (b, hidden_size)
-    hidden_states: torch.Tensor = None
-    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
-
-    # Inputs for extend
-    # shape: (b,)
-    verified_id: torch.Tensor = None
-    accept_length: torch.Tensor = None
-    accept_length_cpu: List[int] = None
-
-    # Inputs for the attention backends
-    # shape: (b + 1,)
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-
-    # Shape info for padding
-    num_tokens_per_batch: int = -1
-    num_tokens_for_logprob_per_batch: int = -1
-
-    # Inputs for draft extend
-    # shape: (b,)
-    seq_lens_for_draft_extend: torch.Tensor = None
-    req_pool_indices_for_draft_extend: torch.Tensor = None
-
-    def prepare_for_extend(self, batch: ScheduleBatch):
-
-        if batch.forward_mode.is_idle():
-            return
-
-        # Prefill only generate 1 token.
-        assert len(self.verified_id) == len(batch.seq_lens)
-
-        pt = 0
-        for i, extend_len in enumerate(batch.extend_lens):
-            input_ids = batch.input_ids[pt : pt + extend_len]
-            batch.input_ids[pt : pt + extend_len] = torch.cat(
-                (input_ids[1:], self.verified_id[i].reshape(1))
-            )
-            pt += extend_len
-
-    @classmethod
-    def create_idle_input(
-        cls,
-        device: torch.device,
-        hidden_size: int,
-        dtype: torch.dtype,
-        topk: int,
-        capture_hidden_mode: CaptureHiddenMode,
-    ):
-        return cls(
-            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
-            hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
-            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
-            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
-            capture_hidden_mode=capture_hidden_mode,
-            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
-            accept_length_cpu=[],
-        )
-
-    def prepare_extend_after_decode(
-        self,
-        batch: ScheduleBatch,
-        speculative_num_steps: int,
-    ):
-
-        if batch.forward_mode.is_idle():
-            return
-
-        batch.input_ids = self.verified_id
-        batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
-        batch.extend_num_tokens = sum(batch.extend_lens)
-        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
-        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
-        batch.return_logprob = False
-        batch.return_hidden_states = False
-
-        self.capture_hidden_mode = CaptureHiddenMode.LAST
-        self.accept_length.add_(1)
-        self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
-        self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
-
-        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
-            batch.input_ids,
-            batch.seq_lens,
-            self.accept_length,
-            self.positions,
-            self.verified_id,
-            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
-        )
-
-    def generate_attn_arg_prefill(
-        self,
-        req_pool_indices: torch.Tensor,
-        paged_kernel_lens: torch.Tensor,
-        paged_kernel_lens_sum: int,
-        req_to_token: torch.Tensor,
-    ):
-        bs = self.accept_length.numel()
-        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
-        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
-        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-        if paged_kernel_lens_sum is None:
-            paged_kernel_lens_sum = cum_kv_seq_len[-1]
-
-        kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(bs,)](
-            req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            cum_kv_seq_len,
-            None,
-            kv_indices,
-            req_to_token.size(1),
-        )
-        return kv_indices, cum_kv_seq_len, qo_indptr, None
-
-    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
-        if has_been_filtered:
-            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
-            # therefore, we don't need to filter the batch again in scheduler
-            if len(new_indices) != len(self.topk_p):
-                logger.warning(
-                    f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
-                )
-            self.topk_p = self.topk_p[: len(new_indices)]
-            self.topk_index = self.topk_index[: len(new_indices)]
-            self.hidden_states = self.hidden_states[: len(new_indices)]
-            self.verified_id = self.verified_id[: len(new_indices)]
-        else:
-            # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
-            self.topk_p = self.topk_p[new_indices]
-            self.topk_index = self.topk_index[new_indices]
-            self.hidden_states = self.hidden_states[new_indices]
-            self.verified_id = self.verified_id[new_indices]
-
-    def merge_batch(self, spec_info: EagleDraftInput):
-        if self.hidden_states is None:
-            self.hidden_states = spec_info.hidden_states
-            self.verified_id = spec_info.verified_id
-            self.topk_p = spec_info.topk_p
-            self.topk_index = spec_info.topk_index
-            return
-        if spec_info.hidden_states is None:
-            return
-        self.hidden_states = torch.cat(
-            [self.hidden_states, spec_info.hidden_states], axis=0
-        )
-        self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
-        self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
-        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
-
-
 @dataclass
-class EagleVerifyOutput:
-    # Draft input batch
-    draft_input: EagleDraftInput
-    # Logit outputs from target worker
-    logits_output: LogitsProcessorOutput
-    # Accepted token ids including the bonus token
-    verified_id: torch.Tensor
-    # Accepted token length per sequence in a batch in CPU.
-    accept_length_per_req_cpu: List[int]
-    # Accepted indices from logits_output.next_token_logits
-    accepted_indices: torch.Tensor
-
-
-@dataclass
-class EagleVerifyInput:
+class EagleVerifyInput(SpecInput):
     draft_token: torch.Tensor
     custom_mask: torch.Tensor
     positions: torch.Tensor
@@ -246,6 +62,12 @@ class EagleVerifyInput:
     seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None
 
+    def __post_init__(self):
+        super().__init__(SpecInputType.EAGLE_VERIFY)
+
+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        return self.draft_token_num, self.draft_token_num
+
     @classmethod
     def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
         return cls(
@@ -282,14 +104,21 @@ class EagleVerifyInput:
             end_offset = batch.seq_lens + self.draft_token_num
         else:
             prefix_lens = batch.seq_lens
+            prefix_lens_cpu = batch.seq_lens_cpu
             end_offset = prefix_lens + self.draft_token_num
+            end_offset_cpu = prefix_lens_cpu + self.draft_token_num
             last_loc = get_last_loc(
                 batch.req_to_token_pool.req_to_token,
                 batch.req_pool_indices,
                 prefix_lens,
             )
             batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
-                prefix_lens, end_offset, last_loc, len(batch.input_ids)
+                prefix_lens,
+                prefix_lens_cpu,
+                end_offset,
+                end_offset_cpu,
+                last_loc,
+                len(batch.input_ids),
             )
             self.last_loc = last_loc
 
@@ -502,13 +331,12 @@ class EagleVerifyInput:
                 deterministic=True,
             )
 
-        if SIMULATE_ACC_LEN:
+        if SIMULATE_ACC_LEN > 0.0:
            # Do simulation
            accept_index = _generate_simulated_accept_index(
                accept_index=accept_index,
                predict=predict,  # mutable
                accept_length=accept_length,  # mutable
-                simulate_acc_len=SIMULATE_ACC_LEN,
                bs=bs,
                spec_steps=self.spec_steps,
            )
@@ -559,6 +387,10 @@ class EagleVerifyInput:
         verified_id = predict[accept_index]
         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
+        accept_length_cpu = accept_length.cpu()
+        # FIXME: this `tolist()` fixes the numerical calculation consistency
+        # try to unify the tensor representation and list representation
+        accept_length_list = accept_length_cpu.tolist()
 
         if page_size == 1:
             # TODO: boolean array index leads to a device sync. Remove it.
@@ -635,13 +467,15 @@ class EagleVerifyInput:
             else:
                 batch.out_cache_loc = tgt_cache_loc
             batch.seq_lens.add_(accept_length + 1)
+            batch.seq_lens_cpu.add_(accept_length_cpu + 1)
 
             draft_input = EagleDraftInput(
                 hidden_states=batch.spec_info.hidden_states[accept_index],
                 verified_id=verified_id,
                 accept_length=accept_length,
-                accept_length_cpu=accept_length.tolist(),
+                accept_length_cpu=accept_length_list,
                 seq_lens_for_draft_extend=batch.seq_lens,
+                seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu,
                 req_pool_indices_for_draft_extend=batch.req_pool_indices,
             )
 
@@ -664,15 +498,15 @@ class EagleVerifyInput:
                 next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
+            batch.seq_lens_cpu.add_(accept_length_cpu + 1)
 
-            accept_length_cpu = accept_length.tolist()
             if len(unfinished_accept_index) > 0:
                 unfinished_accept_index = torch.cat(unfinished_accept_index)
                 unfinished_index_device = torch.tensor(
                     unfinished_index, dtype=torch.int64, device=predict.device
                 )
                 draft_input_accept_length_cpu = [
-                    accept_length_cpu[i] for i in unfinished_index
+                    accept_length_list[i] for i in unfinished_index
                 ]
                 if page_size == 1 or self.topk == 1:
                     batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
@@ -687,6 +521,7 @@ class EagleVerifyInput:
                         unfinished_index_device,
                         batch.seq_lens,
                     )
+                    batch.seq_lens_cpu.add_(accept_length_cpu + 1)
                     filter_finished_cache_loc_kernel[(bs,)](
                         batch.out_cache_loc,
                         tgt_cache_loc,
@@ -704,6 +539,7 @@ class EagleVerifyInput:
                     accept_length_cpu=draft_input_accept_length_cpu,
                     accept_length=accept_length[unfinished_index_device],
                     seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
+                    seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu[unfinished_index],
                     req_pool_indices_for_draft_extend=batch.req_pool_indices[
                         unfinished_index_device
                     ],
@@ -721,577 +557,191 @@ class EagleVerifyInput:
                 draft_input=draft_input,
                 logits_output=logits_output,
                 verified_id=verified_id,
-                accept_length_per_req_cpu=accept_length_cpu,
+                accept_length_per_req_cpu=accept_length_list,
                 accepted_indices=accept_index,
             )
 
 
-@triton.jit
-def create_extend_after_decode_spec_info(
-    verified_id,
-    seq_lens,
-    accept_lens,
-    positions,
-    new_verified_id,
-    bs_upper: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-    offsets = tl.arange(0, bs_upper)
-    seq_length = tl.load(seq_lens + pid)
-    accept_length = tl.load(accept_lens + pid)
-
-    accept_len_cumsum = tl.sum(
-        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
-    )
-    positions_ptr = positions + accept_len_cumsum
-    mask = offsets < accept_length
-    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
-
-    accept_len_cumsum += accept_length - 1
-    verified_id_data = tl.load(verified_id + accept_len_cumsum)
-    tl.store(new_verified_id + pid, verified_id_data)
-
-
-@triton.jit
-def assign_req_to_token_pool(
-    req_pool_indices,
-    req_to_token,
-    start_offset,
-    end_offset,
-    out_cache_loc,
-    pool_len: tl.constexpr,
-    bs_upper: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 32
-    pid = tl.program_id(axis=0)
-    kv_start = tl.load(start_offset + pid)
-    kv_end = tl.load(end_offset + pid)
-    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-
-    length_offset = tl.arange(0, bs_upper)
-    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
-    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
-    out_offset = tl.sum(end - start, axis=0)
-
-    out_cache_ptr = out_cache_loc + out_offset
-
-    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
-    load_offset = tl.arange(0, BLOCK_SIZE)
-
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = save_offset < kv_end
-        data = tl.load(out_cache_ptr + load_offset, mask=mask)
-        tl.store(token_pool + save_offset, data, mask=mask)
-        save_offset += BLOCK_SIZE
-        load_offset += BLOCK_SIZE
-
-
-@triton.jit
-def assign_draft_cache_locs(
-    req_pool_indices,
-    req_to_token,
-    seq_lens,
-    extend_lens,
-    num_new_pages_per_topk,
-    out_cache_loc,
-    pool_len: tl.constexpr,
-    topk: tl.constexpr,
-    speculative_num_steps: tl.constexpr,
-    page_size: tl.constexpr,
-    bs_upper: tl.constexpr,
-    iter_upper: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 128
-    pid = tl.program_id(axis=0)
-
-    if page_size == 1 or topk == 1:
-        copy_len = topk * speculative_num_steps
-        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
-    else:
-        bs_offset = tl.arange(0, bs_upper)
-        copy_len = tl.load(extend_lens + pid)
-        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
-        out_cache_ptr = out_cache_loc + cum_copy_len
-
-    # Part 1: Copy from out_cache_loc to req_to_token
-    kv_start = tl.load(seq_lens + pid)
-    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
-    for i in range(num_loop):
-        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = copy_offset < copy_len
-        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
-        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
-
-    if page_size == 1 or topk == 1:
-        return
-
-    # Part 2: Copy the indices for the last partial page
-    prefix_len = tl.load(seq_lens + pid)
-    last_page_len = prefix_len % page_size
-    offsets = tl.arange(0, page_size)
-    mask = offsets < last_page_len
-    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
-    prefix_base = token_pool + prefix_len - last_page_len
-
-    for topk_id in range(topk):
-        value = tl.load(prefix_base + offsets, mask=mask)
-        tl.store(
-            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
-            value,
-            mask=mask,
-        )
-
-    # Part 3: Remove the padding in out_cache_loc
-    iter_offest = tl.arange(0, iter_upper)
-    for topk_id in range(topk):
-        indices = tl.load(
-            prefix_base
-            + topk_id * num_new_pages_per_topk_ * page_size
-            + last_page_len
-            + iter_offest,
-            mask=iter_offest < speculative_num_steps,
-        )
-        tl.store(
-            out_cache_loc
-            + pid * topk * speculative_num_steps
-            + topk_id * speculative_num_steps
-            + iter_offest,
-            indices,
-            mask=iter_offest < speculative_num_steps,
-        )
+@dataclass
+class EagleDraftInput(SpecInput):
+    # The inputs for decode
+    # shape: (b, topk)
+    topk_p: torch.Tensor = None
+    topk_index: torch.Tensor = None
+    # shape: (b, hidden_size)
+    hidden_states: torch.Tensor = None
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
 
+    # Inputs for extend
+    # shape: (b,)
+    verified_id: torch.Tensor = None
+    accept_length: torch.Tensor = None
+    accept_length_cpu: List[int] = None
 
-@triton.jit
-def generate_draft_decode_kv_indices(
-    req_pool_indices,
-    req_to_token,
-    paged_kernel_lens,
-    kv_indices,
-    kv_indptr,
-    positions,
-    pool_len: tl.constexpr,
-    kv_indices_stride: tl.constexpr,
-    kv_indptr_stride: tl.constexpr,
-    bs_upper: tl.constexpr,
-    iter_upper: tl.constexpr,
-    num_tokens_upper: tl.constexpr,
-    page_size: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 128
-    iters = tl.program_id(axis=0)
-    bid = tl.program_id(axis=1)
-    topk_id = tl.program_id(axis=2)
-
-    num_steps = tl.num_programs(axis=0)
-    num_seqs = tl.num_programs(axis=1)
-    topk = tl.num_programs(axis=2)
-
-    kv_indices += kv_indices_stride * iters
-    kv_indptr += kv_indptr_stride * iters
-    iters += 1
-
-    load_offset = tl.arange(0, bs_upper)
-    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
-    seq_len = tl.load(paged_kernel_lens + bid)
-    cum_seq_len = tl.sum(seq_lens)
-
-    # Update kv_indices
-    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
-    kv_ptr = kv_indices + kv_offset
-    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
-
-    kv_offset = tl.arange(0, BLOCK_SIZE)
-    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = kv_offset < seq_len
-        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
-        tl.store(kv_ptr + kv_offset, data, mask=mask)
-        kv_offset += BLOCK_SIZE
-
-    extend_offset = tl.arange(0, iter_upper)
-    if page_size == 1 or topk == 1:
-        extend_data = tl.load(
-            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
-            mask=extend_offset < iters,
-        )
-    else:
-        prefix_len = seq_len
-        last_page_len = prefix_len % page_size
-        num_new_pages_per_topk = (
-            last_page_len + num_steps + page_size - 1
-        ) // page_size
-        prefix_base = seq_len // page_size * page_size
-        start = (
-            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
-        )
-        extend_data = tl.load(
-            token_pool_ptr + start + extend_offset,
-            mask=extend_offset < iters,
-        )
+    # Inputs for the attention backends
+    # shape: (b + 1,)
+    kv_indptr: torch.Tensor = None
+    kv_indices: torch.Tensor = None
 
-    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
-
-    # Update kv_indptr
-    bs_offset = tl.arange(0, num_tokens_upper)
-
-    zid = bid * topk + topk_id
-    if zid == 0:
-        zid = num_seqs * topk
-    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
-    base = tl.sum(positions)
-    tl.store(kv_indptr + zid, base + zid * iters)
-
-
-@triton.jit
-def align_evict_mask_to_page_size(
-    seq_lens,
-    evict_mask,
-    page_size: tl.constexpr,
-    num_draft_tokens: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    t_range = tl.arange(0, BLOCK_SIZE)
-
-    bid = tl.program_id(axis=0)
-    seq_len = tl.load(seq_lens + bid)
-    io_mask = t_range < num_draft_tokens
-    mask_row = tl.load(
-        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
-    )
+    # Shape info for padding
+    num_tokens_per_batch: int = -1
+    num_tokens_for_logprob_per_batch: int = -1
 
-    num_trues = tl.sum(mask_row)
-    num_false = num_draft_tokens - num_trues
-
-    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
-    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
-        tl.store(evict_mask + bid * num_draft_tokens + i, False)
-
-
-@triton.jit
-def get_target_cache_loc(
-    tgt_cache_loc,
-    to_free_slots,
-    accept_length,
-    to_free_num_slots,
-    out_cache_loc,
-    num_verify_tokens: tl.constexpr,
-    num_verify_tokens_upper: tl.constexpr,
-    bs_upper: tl.constexpr,
-):
-    bid = tl.program_id(axis=0)
-    offset = tl.arange(0, num_verify_tokens_upper)
-    bs_offset = tl.arange(0, bs_upper)
-
-    # write the first part to tgt_cache_loc
-    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
-    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
-    copy_len = tl.load(accept_length + bid) + 1
-    out_cache_loc_row = tl.load(
-        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
-    )
-    tl.store(
-        tgt_cache_loc + tgt_cache_loc_start + offset,
-        out_cache_loc_row,
-        mask=offset < copy_len,
-    )
+    # Inputs for draft extend
+    # shape: (b,)
+    seq_lens_for_draft_extend: torch.Tensor = None
+    seq_lens_for_draft_extend_cpu: torch.Tensor = None
+    req_pool_indices_for_draft_extend: torch.Tensor = None
 
-    # write the second part to to_free_num_pages
-    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
-    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
-    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
-    to_free_slots_start = tl.sum(to_free_num_slots_all)
+    def __post_init__(self):
+        super().__init__(SpecInputType.EAGLE_DRAFT)
 
-    copy_len = to_free_num_slots_cur
-    out_cache_loc_row = tl.load(
-        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
-        mask=offset < copy_len,
-    )
-    tl.store(
-        to_free_slots + to_free_slots_start + offset,
-        out_cache_loc_row,
-        mask=offset < copy_len,
-    )
+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        return self.num_tokens_per_batch, self.num_tokens_for_logprob_per_batch
 
+    def prepare_for_extend(self, batch: ScheduleBatch):
 
-@torch.compile(dynamic=True)
-def get_src_tgt_cache_loc(
-    seq_lens: torch.Tensor,
-    out_cache_loc: torch.Tensor,
-    accept_index: torch.Tensor,
-    accept_length: torch.Tensor,
-    draft_token_num: int,
-    page_size: int,
-):
-    src_cache_loc = out_cache_loc[accept_index]
-    tgt_cache_loc = torch.empty_like(src_cache_loc)
-    extended_len = seq_lens + draft_token_num
-    keep_len = torch.minimum(
-        (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
-        extended_len,
-    )
-    to_free_num_slots = extended_len - keep_len
-    return src_cache_loc, tgt_cache_loc, to_free_num_slots
-
-
-@triton.jit
-def filter_finished_cache_loc_kernel(
-    out_cache_loc,
-    tgt_cache_loc,
-    accept_length,
-    accept_length_filter,
-    bs_upper: tl.constexpr,
-    num_verify_tokens_upper: tl.constexpr,
-):
-    bid = tl.program_id(0)
-    bs_offset = tl.arange(0, bs_upper)
-
-    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
-    old_start = tl.sum(accept_length_all) + bid
-
-    accept_length_filter_all = tl.load(
-        accept_length_filter + bs_offset, mask=bs_offset < bid
-    )
-    new_start = tl.sum(accept_length_filter_all)
+        if batch.forward_mode.is_idle():
+            return
 
-    copy_len = tl.load(accept_length_filter + bid)
-    copy_offset = tl.arange(0, num_verify_tokens_upper)
-    value = tl.load(
-        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
-    )
-    tl.store(
-        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
-    )
+        # Prefill only generate 1 token.
+        assert len(self.verified_id) == len(batch.seq_lens)
 
+        pt = 0
+        for i, extend_len in enumerate(batch.extend_lens):
+            input_ids = batch.input_ids[pt : pt + extend_len]
+            batch.input_ids[pt : pt + extend_len] = torch.cat(
+                (input_ids[1:], self.verified_id[i].reshape(1))
+            )
+            pt += extend_len
 
-@torch.compile(dynamic=True)
-def create_accept_length_filter(
-    accept_length: torch.Tensor,
-    unfinished_index_device: torch.Tensor,
-    seq_lens: torch.Tensor,
-):
-    accept_length_filter = torch.zeros_like(accept_length)
-    accept_length_filter[unfinished_index_device] = (
-        accept_length[unfinished_index_device] + 1
-    )
-    seq_lens.add_(accept_length + 1)
-    return accept_length_filter
-
-
-@torch.compile(dynamic=True)
-def select_top_k_tokens(
-    i: int,
-    topk_p: torch.Tensor,
-    topk_index: torch.Tensor,
-    hidden_states: torch.Tensor,
-    scores: torch.Tensor,
-    topk: int,
-):
-    if i == 0:
-        # The first step after extend
-        input_ids = topk_index.flatten()
-        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
-        scores = topk_p  # shape: (b, topk)
-
-        tree_info = (
-            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
-            topk_index,  # shape: (b, topk)
-            torch.arange(-1, topk, dtype=torch.long, device="cuda")
-            .unsqueeze(0)
-            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
-        )
-    else:
-        # The later decode steps
-        expand_scores = torch.mul(
-            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
-        )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
-        topk_cs_p, topk_cs_index = fast_topk(
-            expand_scores.flatten(start_dim=1), topk, dim=-1
-        )  # (b, topk)
-        scores = topk_cs_p  # shape: (b, topk)
-
-        topk_index = topk_index.reshape(-1, topk**2)
-        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
-
-        if hidden_states.shape[0] > 0:
-            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
-                0, hidden_states.shape[0], step=topk, device="cuda"
-            ).repeat_interleave(topk)
-            hidden_states = hidden_states[selected_input_index, :]
-
-        tree_info = (
-            expand_scores,  # shape: (b, topk, topk)
-            topk_index,  # shape: (b, topk * topk)
-            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
+    @classmethod
+    def create_idle_input(
+        cls,
+        device: torch.device,
+        hidden_size: int,
+        dtype: torch.dtype,
+        topk: int,
+        capture_hidden_mode: CaptureHiddenMode,
+    ):
+        return cls(
+            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
+            hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
+            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
+            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
+            capture_hidden_mode=capture_hidden_mode,
+            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
+            accept_length_cpu=[],
         )
 
-    return input_ids, hidden_states, scores, tree_info
-
-
-def _generate_simulated_accept_index(
-    accept_index,
-    predict,
-    accept_length,
-    simulate_acc_len,
-    bs,
-    spec_steps,
-):
-    simulate_acc_len_float = float(simulate_acc_len)
-    if SIMULATE_ACC_METHOD == "multinomial":
-        simulated_values = torch.normal(
-            mean=simulate_acc_len_float,
-            std=1.0,
-            size=(1,),
-            device="cpu",
-        )
-        # clamp simulated values to be between 1 and self.spec_steps
-        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
-        simulate_acc_len = int(simulated_values.round().item())
-    elif SIMULATE_ACC_METHOD == "match-expected":
-        # multinomial sampling does not match the expected length
-        # we keep it for the sake of compatibility of existing tests
-        # but it's better to use "match-expected" for the cases that need to
-        # match the expected length, One caveat is that this will only sample
-        # either round down or round up of the expected length
-        simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
-        lower = int(simulate_acc_len_float // 1)
-        upper = lower + 1 if lower < spec_steps + 1 else lower
-        if lower == upper:
-            simulate_acc_len = lower
-        else:
-            weight_upper = simulate_acc_len_float - lower
-            weight_lower = 1.0 - weight_upper
-            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
-            sampled_index = torch.multinomial(probs, num_samples=1)
-            simulate_acc_len = lower if sampled_index == 0 else upper
-    else:
-        raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
-
-    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
-    sim_accept_index = torch.full(
-        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
-    )
-    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
-        simulate_acc_len, device=accept_index.device
-    )
-    accept_length.fill_(simulate_acc_len - 1)
-    predict.fill_(100)  # some legit token id
-    return sim_accept_index
-
-
-def traverse_tree(
-    retrieve_next_token: torch.Tensor,
-    retrieve_next_sibling: torch.Tensor,
-    draft_tokens: torch.Tensor,
-    grammar: BaseGrammarObject,
-    allocate_token_bitmask: torch.Tensor,
-):
-    """
-    Traverse the tree constructed by the draft model to generate the logits mask.
-    """
-    assert (
-        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
-    )
+    def prepare_extend_after_decode(
+        self,
+        batch: ScheduleBatch,
+        speculative_num_steps: int,
+    ):
+
+        if batch.forward_mode.is_idle():
+            return
 
-    allocate_token_bitmask.fill_(0)
+        batch.input_ids = self.verified_id
+        batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
+        batch.extend_num_tokens = sum(batch.extend_lens)
+        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        batch.seq_lens_cpu = batch.spec_info.seq_lens_for_draft_extend_cpu
+        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
+        batch.return_logprob = False
+        batch.return_hidden_states = False
 
-    def dfs(
-        curr: int,
-        retrieve_next_token: torch.Tensor,
-        retrieve_next_sibling: torch.Tensor,
-        parent_pos: int,
+        self.capture_hidden_mode = CaptureHiddenMode.LAST
+        self.accept_length.add_(1)
+        self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
+        self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
+
+        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
+            batch.input_ids,
+            batch.seq_lens,
+            self.accept_length,
+            self.positions,
+            self.verified_id,
+            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
+        )
+
+    def generate_attn_arg_prefill(
+        self,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        req_to_token: torch.Tensor,
     ):
-        if curr == 0:
-            # the first token generated by the target model, and thus it is always
-            # accepted from the previous iteration
-            accepted = True
-        else:
-            parent_bitmask = allocate_token_bitmask[parent_pos]
-            curr_token_id = draft_tokens[curr]
-            # 32 boolean bitmask values are packed into 32-bit integers
-            accepted = (
-                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
-            ) != 0
-
-        if accepted:
-            if curr != 0:
-                # Accept the current token
-                grammar.accept_token(draft_tokens[curr])
-            if not grammar.is_terminated():
-                # Generate the bitmask for the current token
-                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
-            if retrieve_next_token[curr] != -1:
-                # Visit the child node
-                dfs(
-                    retrieve_next_token[curr],
-                    retrieve_next_token,
-                    retrieve_next_sibling,
-                    curr,
-                )
+        bs = self.accept_length.numel()
+        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
+        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
-        if curr != 0:
-            # Rollback the current token
-            grammar.rollback(1)
-
-        if retrieve_next_sibling[curr] != -1:
-            # Visit the sibling node
-            dfs(
-                retrieve_next_sibling[curr],
-                retrieve_next_token,
-                retrieve_next_sibling,
-                parent_pos,
-            )
+        if paged_kernel_lens_sum is None:
+            paged_kernel_lens_sum = cum_kv_seq_len[-1]
 
-    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
-
-
-def generate_token_bitmask(
-    reqs: List[Req],
-    verify_input: EagleVerifyInput,
-    retrieve_next_token_cpu: torch.Tensor,
-    retrieve_next_sibling_cpu: torch.Tensor,
-    draft_tokens_cpu: torch.Tensor,
-    vocab_size: int,
-):
-    """
-    Generate the logit mask for structured output.
-    Draft model's token can be either valid or invalid with respect to the grammar.
-    We need to perform DFS to
-    1. figure out which tokens are accepted by the grammar.
-    2. if so, what is the corresponding logit mask.
-    """
-
-    num_draft_tokens = draft_tokens_cpu.shape[-1]
-
-    allocate_token_bitmask = None
-    assert len(reqs) == retrieve_next_token_cpu.shape[0]
-    grammar = None
-    for i, req in enumerate(reqs):
-        if req.grammar is not None:
-            if allocate_token_bitmask is None:
-                allocate_token_bitmask = req.grammar.allocate_vocab_mask(
-                    vocab_size=vocab_size,
-                    batch_size=draft_tokens_cpu.numel(),
-                    device="cpu",
-                )
-            grammar = req.grammar
-            s = time.perf_counter()
-            traverse_tree(
-                retrieve_next_token_cpu[i],
-                retrieve_next_sibling_cpu[i],
-                draft_tokens_cpu[i],
-                req.grammar,
-                allocate_token_bitmask[
-                    i * num_draft_tokens : (i + 1) * num_draft_tokens
-                ],
-            )
-            tree_traverse_time = time.perf_counter() - s
-            if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            cum_kv_seq_len,
+            None,
+            kv_indices,
+            req_to_token.size(1),
+        )
+        return kv_indices, cum_kv_seq_len, qo_indptr, None
+
+    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+        if has_been_filtered:
+            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
+            # therefore, we don't need to filter the batch again in scheduler
+            if len(new_indices) != len(self.topk_p):
                 logger.warning(
-                    f"Bit mask generation took {tree_traverse_time} seconds with "
-                    f"grammar: {req.grammar}"
+                    f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
                 )
+            self.topk_p = self.topk_p[: len(new_indices)]
+            self.topk_index = self.topk_index[: len(new_indices)]
+            self.hidden_states = self.hidden_states[: len(new_indices)]
+            self.verified_id = self.verified_id[: len(new_indices)]
+        else:
+            # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
+            self.topk_p = self.topk_p[new_indices]
+            self.topk_index = self.topk_index[new_indices]
+            self.hidden_states = self.hidden_states[new_indices]
+            self.verified_id = self.verified_id[new_indices]
+
+    def merge_batch(self, spec_info: "EagleDraftInput"):
+        if self.hidden_states is None:
+            self.hidden_states = spec_info.hidden_states
+            self.verified_id = spec_info.verified_id
+            self.topk_p = spec_info.topk_p
+            self.topk_index = spec_info.topk_index
+            return
+        if spec_info.hidden_states is None:
+            return
+        self.hidden_states = torch.cat(
+            [self.hidden_states, spec_info.hidden_states], axis=0
+        )
+        self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
+        self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
+        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
+
 
-    verify_input.grammar = grammar
-    return allocate_token_bitmask
+@dataclass
+class EagleVerifyOutput:
+    # Draft input batch
+    draft_input: EagleDraftInput
+    # Logit outputs from target worker
+    logits_output: LogitsProcessorOutput
+    # Accepted token ids including the bonus token
+    verified_id: torch.Tensor
+    # Accepted token length per sequence in a batch in CPU.
+    accept_length_per_req_cpu: List[int]
+    # Accepted indices from logits_output.next_token_logits
+    accepted_indices: torch.Tensor
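
Note on the refactor shown above: the former eagle_utils.py module becomes eagle_info.py, its Triton kernels and helpers move to sglang.srt.speculative.spec_utils, and both EagleDraftInput and EagleVerifyInput now derive from a shared SpecInput base that reports per-request token budgets through get_spec_adjust_token_coefficient(). The sketch below only illustrates that pattern; the SpecInput and SpecInputType definitions are assumptions modeled on the imports visible in the diff, not the actual contents of spec_info.py, and VerifyInputSketch is a hypothetical stand-in.

# Minimal sketch of the SpecInput pattern, under the assumption that the real
# sglang.srt.speculative.spec_info module looks roughly like this.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Tuple


class SpecInputType(Enum):
    EAGLE_DRAFT = auto()
    EAGLE_VERIFY = auto()


class SpecInput:
    def __init__(self, spec_input_type: SpecInputType):
        self.spec_input_type = spec_input_type

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        # Subclasses return the (num_tokens, num_tokens_for_logprob) multipliers.
        raise NotImplementedError


@dataclass
class VerifyInputSketch(SpecInput):  # hypothetical stand-in for EagleVerifyInput
    draft_token_num: int = 0

    def __post_init__(self):
        # A dataclass generates its own __init__, so the base class is tagged here,
        # mirroring EagleVerifyInput.__post_init__ in the hunks above.
        super().__init__(SpecInputType.EAGLE_VERIFY)

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        # Verification budgets scale with the number of draft tokens per request.
        return self.draft_token_num, self.draft_token_num


# Example: VerifyInputSketch(draft_token_num=4).get_spec_adjust_token_coefficient()
# returns (4, 4), matching the EagleVerifyInput behavior shown in the diff.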