sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,786 @@
1
+ import logging
2
+ from copy import copy
3
+ from dataclasses import dataclass
4
+ from typing import ClassVar, List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
10
+ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
11
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
12
+ from sglang.srt.layers.sampler import apply_custom_logit_processor
13
+ from sglang.srt.managers.overlap_utils import FutureIndices
14
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
15
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
16
+ from sglang.srt.mem_cache.common import (
17
+ alloc_paged_token_slots_extend,
18
+ alloc_token_slots,
19
+ get_last_loc,
20
+ )
21
+ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
22
+ from sglang.srt.server_args import get_global_server_args
23
+ from sglang.srt.speculative.eagle_info_v2 import (
24
+ EagleDraftInputV2Mixin,
25
+ EagleVerifyInputV2Mixin,
26
+ )
27
+ from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
28
+ from sglang.srt.speculative.spec_utils import (
29
+ SIMULATE_ACC_LEN,
30
+ TREE_SPEC_KERNEL_AVAILABLE,
31
+ align_evict_mask_to_page_size,
32
+ assign_req_to_token_pool,
33
+ create_accept_length_filter,
34
+ create_extend_after_decode_spec_info,
35
+ filter_finished_cache_loc_kernel,
36
+ generate_simulated_accept_index,
37
+ get_src_tgt_cache_loc,
38
+ get_target_cache_loc,
39
+ )
40
+ from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
41
+
42
+ if is_cuda():
43
+ from sgl_kernel import (
44
+ top_k_renorm_prob,
45
+ top_p_renorm_prob,
46
+ tree_speculative_sampling_target_only,
47
+ verify_tree_greedy,
48
+ )
49
+ elif is_hip():
50
+ from sgl_kernel import verify_tree_greedy
51
+
52
+ logger = logging.getLogger(__name__)
53
+
54
+
55
@dataclass
class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
    """Inputs for the target-model verification step of EAGLE speculative decoding.

    Holds the flattened draft-token tree for one batch together with the
    retrieval tables (`retrive_*`, spelling matches the sgl_kernel API) that
    encode the tree topology consumed by the verification kernels.
    """

    # Flattened draft tokens for the whole batch, length bs * draft_token_num.
    draft_token: torch.Tensor
    # Tree attention mask used by the prefill attention backends.
    custom_mask: torch.Tensor
    positions: torch.Tensor
    # (bs, draft_token_num) tables describing the draft tree structure.
    retrive_index: torch.Tensor
    retrive_next_token: torch.Tensor
    retrive_next_sibling: torch.Tensor
    retrive_cum_len: torch.Tensor
    spec_steps: int
    topk: int
    draft_token_num: int
    capture_hidden_mode: CaptureHiddenMode
    seq_lens_sum: int
    seq_lens_cpu: torch.Tensor
    grammar: BaseGrammarObject = None

    def __post_init__(self):
        super().__init__(SpecInputType.EAGLE_VERIFY)

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        """Both token budgets scale with the number of draft tokens per request."""
        return self.draft_token_num, self.draft_token_num

    @classmethod
    def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
        """Build an empty (batch size 0) verify input for idle forward passes."""
        return cls(
            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
            retrive_index=torch.full(
                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
            ),
            retrive_next_token=torch.full(
                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
            ),
            retrive_next_sibling=torch.full(
                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
            ),
            retrive_cum_len=None,
            topk=topk,
            draft_token_num=num_verify_tokens,
            spec_steps=spec_steps,
            capture_hidden_mode=CaptureHiddenMode.FULL,
            seq_lens_sum=0,
            seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
        )

    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
        """Allocate KV-cache slots for all draft tokens and map them into the
        request-to-token pool so the target model can run verification.

        Mutates `batch.input_ids` and `batch.out_cache_loc` in place.
        """
        if batch.forward_mode.is_idle():
            return

        batch.input_ids = self.draft_token

        if page_size == 1:
            batch.out_cache_loc = alloc_token_slots(
                batch.tree_cache,
                len(batch.input_ids),
            )
            end_offset = batch.seq_lens + self.draft_token_num
        else:
            # Paged allocation: extend each request from its current prefix
            # length to prefix + draft_token_num, respecting page boundaries.
            prefix_lens = batch.seq_lens
            prefix_lens_cpu = batch.seq_lens_cpu
            end_offset = prefix_lens + self.draft_token_num
            end_offset_cpu = prefix_lens_cpu + self.draft_token_num
            last_loc = get_last_loc(
                batch.req_to_token_pool.req_to_token,
                batch.req_pool_indices,
                prefix_lens,
            )
            batch.out_cache_loc = alloc_paged_token_slots_extend(
                batch.tree_cache,
                prefix_lens,
                prefix_lens_cpu,
                end_offset,
                end_offset_cpu,
                last_loc,
                len(batch.input_ids),
            )
            self.last_loc = last_loc

        bs = batch.batch_size()
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            end_offset,
            batch.out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )

    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
        """Build (kv_indices, kv_indptr, qo_indptr, custom_mask) for the
        prefill attention backend over prefix + draft tokens."""
        batch_size = len(req_pool_indices)
        # Each request contributes exactly draft_token_num query tokens.
        qo_indptr = torch.arange(
            0,
            (1 + batch_size) * self.draft_token_num,
            step=self.draft_token_num,
            dtype=torch.int32,
            device="cuda",
        )
        cum_kv_seq_len = torch.zeros(
            (batch_size + 1,), dtype=torch.int32, device="cuda"
        )

        # KV length per request = prefix length + draft tokens.
        paged_kernel_lens = paged_kernel_lens + self.draft_token_num
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        kv_indices = torch.empty(
            paged_kernel_lens_sum + self.draft_token_num * batch_size,
            dtype=torch.int32,
            device="cuda",
        )
        create_flashinfer_kv_indices_triton[(batch_size,)](
            req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            cum_kv_seq_len,
            None,
            kv_indices,
            req_to_token.size(1),
        )
        return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask

    def verify(
        self,
        batch: ScheduleBatch,
        logits_output: LogitsProcessorOutput,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        vocab_mask: Optional[torch.Tensor] = None,  # For grammar
    ) -> torch.Tensor:
        """
        Verify and find accepted tokens based on logits output and batch
        (which contains spec decoding information).

        WARNING: This API in-place modifies the states of logits_output

        This API updates values inside logits_output based on the accepted
        tokens. I.e., logits_output.next_token_logits only contains
        accepted token logits.
        """
        if batch.forward_mode.is_idle():
            return EagleVerifyOutput(
                draft_input=EagleDraftInput.create_idle_input(
                    device=batch.device,
                    hidden_size=batch.model_config.hidden_size,
                    dtype=batch.model_config.dtype,
                    topk=self.topk,
                    capture_hidden_mode=CaptureHiddenMode.LAST,
                ),
                logits_output=logits_output,
                verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
                accept_length_per_req_cpu=[],
                accepted_indices=torch.full(
                    (0, self.spec_steps + 1),
                    -1,
                    dtype=torch.int32,
                    device=batch.device,
                ),
            )

        bs = self.retrive_index.shape[0]
        candidates = self.draft_token.reshape(bs, self.draft_token_num)
        sampling_info = batch.sampling_info

        # `predict` holds one extra slot for the bonus token per request.
        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
        predict_shape[-1] += 1
        predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
        accept_index = torch.full(
            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
        )
        accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")

        if bs != len(sampling_info):
            # BUG FIX: this module imports `from copy import copy` (the
            # shallow-copy *function*), so the previous `copy.deepcopy(...)`
            # raised AttributeError at runtime. Use `deepcopy` directly.
            from copy import deepcopy

            sampling_info = deepcopy(sampling_info)
            # NOTE: retrive_index are the indices of the requests that are kept.
            sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)

        # Apply the custom logit processors if registered in the sampling info.
        if sampling_info.has_custom_logit_processor:
            apply_custom_logit_processor(
                logits_output.next_token_logits,
                sampling_info,
                num_tokens_in_batch=self.draft_token_num,
            )

        # Apply penalty
        if (
            sampling_info.penalizer_orchestrator.is_required
            or sampling_info.logit_bias is not None
        ):
            # This is a relaxed version of penalties for speculative decoding.
            # The same per-request penalty row is broadcast to all of the
            # request's draft tokens.
            linear_penalty = torch.zeros(
                (bs, logits_output.next_token_logits.shape[1]),
                dtype=torch.float32,
                device="cuda",
            )
            sampling_info.apply_logits_bias(linear_penalty)
            logits_output.next_token_logits.add_(
                torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
            )

        # Apply grammar mask
        if vocab_mask is not None:
            assert self.grammar is not None
            self.grammar.apply_vocab_mask(
                logits=logits_output.next_token_logits, vocab_mask=vocab_mask
            )

        # Sample tokens. Force greedy sampling on AMD
        is_all_greedy = sampling_info.is_all_greedy
        if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
            logger.warning(
                "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
                "Falling back to greedy verification."
            )

        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
            # Greedy path: accept a draft token iff it matches the target
            # model's argmax at that tree position.
            target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
            target_predict = target_predict.reshape(bs, self.draft_token_num)

            verify_tree_greedy(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                target_predict=target_predict,
            )
        else:
            # Rejection-sampling path: apply temperature and get target probs.
            expanded_temperature = torch.repeat_interleave(
                sampling_info.temperatures, self.draft_token_num, dim=0
            )  # (bs * draft_token_num, 1)

            target_probs = F.softmax(
                logits_output.next_token_logits / expanded_temperature, dim=-1
            )  # (bs * draft_token_num, vocab_size)
            target_probs = top_k_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ks, self.draft_token_num, dim=0
                ),
            )  # (bs * draft_token_num, vocab_size)
            if not torch.all(sampling_info.top_ps == 1.0):
                target_probs = top_p_renorm_prob(
                    target_probs,
                    torch.repeat_interleave(
                        sampling_info.top_ps, self.draft_token_num, dim=0
                    ),
                )
            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)

            draft_probs = torch.zeros(
                target_probs.shape, dtype=torch.float32, device="cuda"
            )

            # coins for rejection sampling
            coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
            # coins for final sampling
            coins_for_final_sampling = torch.rand(
                (bs,), dtype=torch.float32, device="cuda"
            )
            tree_speculative_sampling_target_only(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                uniform_samples=coins,
                uniform_samples_for_final_sampling=coins_for_final_sampling,
                target_probs=target_probs,
                draft_probs=draft_probs,
                threshold_single=get_global_server_args().speculative_accept_threshold_single,
                threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
                deterministic=True,
            )

        if SIMULATE_ACC_LEN > 0.0:
            # Do simulation
            accept_index = generate_simulated_accept_index(
                accept_index=accept_index,
                predict=predict,  # mutable
                accept_length=accept_length,  # mutable
                bs=bs,
                spec_steps=self.spec_steps,
            )

        unfinished_index = []
        unfinished_accept_index = []
        accept_index_cpu = accept_index.tolist()
        predict_cpu = predict.tolist()
        has_finished = False

        # Iterate every accepted token and check if req has finished after append the token
        # should be checked BEFORE free kv cache slots
        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
            for j, idx in enumerate(accept_index_row):
                if idx == -1:
                    break
                token_id = predict_cpu[idx]
                req.output_ids.append(token_id)
                req.check_finished()
                if req.finished():
                    has_finished = True
                    # set all tokens after finished token to -1 and break
                    accept_index[i, j + 1 :] = -1
                    break
                else:
                    if req.grammar is not None:
                        try:
                            req.grammar.accept_token(token_id)
                        except ValueError:
                            logger.info(
                                f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
                            )
                            raise
            if not req.finished():
                unfinished_index.append(i)
                # `j`/`idx` leak from the inner loop: if the row ended at a -1
                # sentinel, keep only the accepted prefix of the row.
                if idx == -1:
                    unfinished_accept_index.append(accept_index[i, :j])
                else:
                    unfinished_accept_index.append(accept_index[i])
            req.spec_verify_ct += 1
            # -1 excludes the bonus token from the accepted-token count.
            req.spec_accepted_tokens += (
                sum(1 for t in accept_index_row if t != -1) - 1
            )

        if has_finished:
            # Recompute lengths after truncating rows at the finish token.
            accept_length = (accept_index != -1).sum(dim=1) - 1

        # Free the KV cache for unaccepted tokens
        # TODO: fuse them
        accept_index = accept_index[accept_index != -1]
        verified_id = predict[accept_index]
        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
        evict_mask[accept_index] = False
        accept_length_cpu = accept_length.cpu()
        # FIXME: this `tolist()` fixes the numerical calculation consistency
        # try to unify the tensor representation and list representation
        accept_length_list = accept_length_cpu.tolist()

        if page_size == 1:
            # TODO: boolean array index leads to a device sync. Remove it.
            token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
        else:
            if self.topk == 1:
                # Only evict full empty page. Do not evict partial empty page
                align_evict_mask_to_page_size[(len(batch.seq_lens),)](
                    batch.seq_lens,
                    evict_mask,
                    page_size,
                    self.draft_token_num,
                    next_power_of_2(self.draft_token_num),
                )
                token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
            else:
                # Shift the accepted tokens to the beginning.
                # Only evict the last part
                src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
                    batch.seq_lens,
                    batch.out_cache_loc,
                    accept_index,
                    accept_length,
                    self.draft_token_num,
                    page_size,
                )
                to_free_slots = torch.empty(
                    (to_free_num_slots.sum().item(),),
                    dtype=torch.int64,
                    device=to_free_num_slots.device,
                )

                # out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
                # accept_index:  [0 -1 2, 3 4 -1, 6 -1 -1]
                # tgt_cache_loc: [0 1   , 3 4   , 6      ]
                # to_free_slots: [    2,      5,   7 8]
                # to_free_slots also needs to be page-aligned without the first partial page
                #
                # split each row of out_cache_loc into two parts.
                # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
                # 2. the second part goes to to_free_slots.
                get_target_cache_loc[(bs,)](
                    tgt_cache_loc,
                    to_free_slots,
                    accept_length,
                    to_free_num_slots,
                    batch.out_cache_loc,
                    self.draft_token_num,
                    next_power_of_2(self.draft_token_num),
                    next_power_of_2(bs),
                )

                # Free the kv cache
                token_to_kv_pool_allocator.free(to_free_slots)

                # Copy the kv cache
                batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
                    tgt_cache_loc, src_cache_loc
                )

        # Construct EagleVerifyOutput
        if not has_finished:
            if page_size == 1 or self.topk == 1:
                batch.out_cache_loc = batch.out_cache_loc[accept_index]
                assign_req_to_token_pool[(bs,)](
                    batch.req_pool_indices,
                    batch.req_to_token_pool.req_to_token,
                    batch.seq_lens,
                    batch.seq_lens + accept_length + 1,
                    batch.out_cache_loc,
                    batch.req_to_token_pool.req_to_token.shape[1],
                    next_power_of_2(bs),
                )
            else:
                batch.out_cache_loc = tgt_cache_loc
            batch.seq_lens.add_(accept_length + 1)
            batch.seq_lens_cpu.add_(accept_length_cpu + 1)

            draft_input = EagleDraftInput(
                hidden_states=batch.spec_info.hidden_states[accept_index],
                verified_id=verified_id,
                accept_length=accept_length,
                accept_length_cpu=accept_length_list,
                seq_lens_for_draft_extend=batch.seq_lens,
                seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu,
                req_pool_indices_for_draft_extend=batch.req_pool_indices,
            )

            return EagleVerifyOutput(
                draft_input=draft_input,
                logits_output=logits_output,
                verified_id=verified_id,
                accept_length_per_req_cpu=draft_input.accept_length_cpu,
                accepted_indices=accept_index,
            )
        else:
            if page_size == 1 or self.topk == 1:
                assign_req_to_token_pool[(bs,)](
                    batch.req_pool_indices,
                    batch.req_to_token_pool.req_to_token,
                    batch.seq_lens,
                    batch.seq_lens + accept_length + 1,
                    batch.out_cache_loc[accept_index],
                    batch.req_to_token_pool.req_to_token.shape[1],
                    next_power_of_2(bs),
                )
                batch.seq_lens.add_(accept_length + 1)
                batch.seq_lens_cpu.add_(accept_length_cpu + 1)

            if len(unfinished_accept_index) > 0:
                unfinished_accept_index = torch.cat(unfinished_accept_index)
                unfinished_index_device = torch.tensor(
                    unfinished_index, dtype=torch.int64, device=predict.device
                )
                draft_input_accept_length_cpu = [
                    accept_length_list[i] for i in unfinished_index
                ]
                if page_size == 1 or self.topk == 1:
                    batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
                else:
                    batch.out_cache_loc = torch.empty(
                        len(unfinished_index) + sum(draft_input_accept_length_cpu),
                        dtype=torch.int64,
                        device=predict.device,
                    )
                    accept_length_filter = create_accept_length_filter(
                        accept_length,
                        unfinished_index_device,
                        batch.seq_lens,
                    )
                    batch.seq_lens_cpu.add_(accept_length_cpu + 1)
                    filter_finished_cache_loc_kernel[(bs,)](
                        batch.out_cache_loc,
                        tgt_cache_loc,
                        accept_length,
                        accept_length_filter,
                        next_power_of_2(bs),
                        next_power_of_2(self.draft_token_num),
                    )

                draft_input = EagleDraftInput(
                    hidden_states=batch.spec_info.hidden_states[
                        unfinished_accept_index
                    ],
                    verified_id=predict[unfinished_accept_index],
                    accept_length_cpu=draft_input_accept_length_cpu,
                    accept_length=accept_length[unfinished_index_device],
                    seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
                    seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu[unfinished_index],
                    req_pool_indices_for_draft_extend=batch.req_pool_indices[
                        unfinished_index_device
                    ],
                )
            else:
                draft_input = EagleDraftInput.create_idle_input(
                    device=batch.device,
                    hidden_size=batch.model_config.hidden_size,
                    dtype=batch.model_config.dtype,
                    topk=self.topk,
                    capture_hidden_mode=CaptureHiddenMode.LAST,
                )

            return EagleVerifyOutput(
                draft_input=draft_input,
                logits_output=logits_output,
                verified_id=verified_id,
                accept_length_per_req_cpu=accept_length_list,
                accepted_indices=accept_index,
            )
576
+
577
+
578
@dataclass
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
    """Inputs for the draft-model side of EAGLE speculative decoding.

    Carries the top-k draft proposals for the next decode step, the
    accepted-token bookkeeping produced by verification, and the metadata
    needed by the attention backends for the draft-extend forward pass.
    """

    # Constant: alloc length per decode step
    ALLOC_LEN_PER_DECODE: ClassVar[Optional[int]] = None

    # The inputs for decode
    # shape: (b, topk)
    topk_p: torch.Tensor = None
    topk_index: torch.Tensor = None
    # shape: (b, hidden_size)
    hidden_states: torch.Tensor = None
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL

    # Inputs for extend
    # shape: (b,)
    verified_id: torch.Tensor = None
    accept_length: torch.Tensor = None
    accept_length_cpu: List[int] = None

    # Inputs for the attention backends
    # shape: (b + 1,)
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None

    # Shape info for padding
    num_tokens_per_batch: int = -1
    num_tokens_for_logprob_per_batch: int = -1

    # Inputs for draft extend
    # shape: (b,)
    seq_lens_for_draft_extend: torch.Tensor = None
    seq_lens_for_draft_extend_cpu: torch.Tensor = None
    req_pool_indices_for_draft_extend: torch.Tensor = None

    # Inputs for V2 overlap worker
    future_indices: Optional[FutureIndices] = None
    allocate_lens: Optional[torch.Tensor] = None
    new_seq_lens: Optional[torch.Tensor] = None
    verify_done: Optional[torch.cuda.Event] = None

    def __post_init__(self):
        super().__init__(SpecInputType.EAGLE_DRAFT)

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        """Return the per-batch padding sizes used to budget tokens/logprobs."""
        return self.num_tokens_per_batch, self.num_tokens_for_logprob_per_batch

    def prepare_for_extend(self, batch: ScheduleBatch):
        """Shift each request's input ids left by one and append its verified
        token, so the draft model predicts one step ahead of the target.

        Mutates `batch.input_ids` in place, one segment per request.
        """

        if batch.forward_mode.is_idle():
            return

        # Prefill only generate 1 token.
        assert len(self.verified_id) == len(batch.seq_lens)

        pt = 0  # running offset into the flattened batch.input_ids
        for i, extend_len in enumerate(batch.extend_lens):
            input_ids = batch.input_ids[pt : pt + extend_len]
            batch.input_ids[pt : pt + extend_len] = torch.cat(
                (input_ids[1:], self.verified_id[i].reshape(1))
            )
            pt += extend_len

    @classmethod
    def create_idle_input(
        cls,
        device: torch.device,
        hidden_size: int,
        dtype: torch.dtype,
        topk: int,
        capture_hidden_mode: CaptureHiddenMode,
    ):
        """Build an empty (batch size 0) draft input for idle forward passes."""
        return cls(
            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
            hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
            capture_hidden_mode=capture_hidden_mode,
            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
            accept_length_cpu=[],
        )

    def prepare_extend_after_decode(
        self,
        batch: ScheduleBatch,
        speculative_num_steps: int,
    ):
        """Rewrite `batch` for the draft-extend pass that follows verification:
        the accepted tokens become the extend input and positions/verified ids
        are recomputed by a triton kernel.
        """

        if batch.forward_mode.is_idle():
            return

        batch.input_ids = self.verified_id
        # +1 accounts for the bonus token appended to each accepted run.
        batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
        batch.extend_num_tokens = sum(batch.extend_lens)
        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
        batch.seq_lens_cpu = batch.spec_info.seq_lens_for_draft_extend_cpu
        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
        batch.return_logprob = False
        batch.return_hidden_states = False

        self.capture_hidden_mode = CaptureHiddenMode.LAST
        self.accept_length.add_(1)
        self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
        self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)

        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
            batch.input_ids,
            batch.seq_lens,
            self.accept_length,
            self.positions,
            self.verified_id,
            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
        )

    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
        """Build (kv_indices, kv_indptr, qo_indptr, mask) for the draft-extend
        prefill; query lengths come from `accept_length`, no custom mask."""
        bs = self.accept_length.numel()
        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        if paged_kernel_lens_sum is None:
            paged_kernel_lens_sum = cum_kv_seq_len[-1]

        kv_indices = torch.empty(
            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
        )

        create_flashinfer_kv_indices_triton[(bs,)](
            req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            cum_kv_seq_len,
            None,
            kv_indices,
            req_to_token.size(1),
        )
        return kv_indices, cum_kv_seq_len, qo_indptr, None

    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
        """Keep only the requests selected by `new_indices`.

        On the V2 overlap path (future_indices set) only the future indices
        and allocate_lens are filtered; otherwise the per-request tensors are
        sliced, either by count (already filtered upstream) or by index.
        """
        if self.future_indices is not None:
            self.future_indices.indices = self.future_indices.indices[new_indices]
            self.allocate_lens = self.allocate_lens[new_indices]
            return

        if has_been_filtered:
            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
            # therefore, we don't need to filter the batch again in scheduler
            if len(new_indices) != len(self.topk_p):
                logger.warning(
                    f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
                )
            self.topk_p = self.topk_p[: len(new_indices)]
            self.topk_index = self.topk_index[: len(new_indices)]
            self.hidden_states = self.hidden_states[: len(new_indices)]
            self.verified_id = self.verified_id[: len(new_indices)]
        else:
            # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
            self.topk_p = self.topk_p[new_indices]
            self.topk_index = self.topk_index[new_indices]
            self.hidden_states = self.hidden_states[new_indices]
            self.verified_id = self.verified_id[new_indices]

    def merge_batch(self, spec_info: "EagleDraftInput"):
        """Concatenate another draft input onto this one (batch merge).

        V2 overlap path merges future indices / allocate_lens only; otherwise
        tensors are concatenated, treating a None side as empty.
        """
        if self.future_indices is not None:
            assert spec_info.future_indices is not None
            self.future_indices = FutureIndices(
                indices=torch.cat(
                    [self.future_indices.indices, spec_info.future_indices.indices]
                )
            )
            self.allocate_lens = torch.cat(
                [self.allocate_lens, spec_info.allocate_lens]
            )
            return

        if self.hidden_states is None:
            self.hidden_states = spec_info.hidden_states
            self.verified_id = spec_info.verified_id
            self.topk_p = spec_info.topk_p
            self.topk_index = spec_info.topk_index
            return
        if spec_info.hidden_states is None:
            return
        self.hidden_states = torch.cat(
            [self.hidden_states, spec_info.hidden_states], axis=0
        )
        self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
        self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
773
+
774
+
775
@dataclass
class EagleVerifyOutput:
    """Result bundle returned by `EagleVerifyInput.verify` for one batch."""

    # Draft input batch
    draft_input: EagleDraftInput
    # Logit outputs from target worker
    logits_output: LogitsProcessorOutput
    # Accepted token ids including the bonus token
    verified_id: torch.Tensor
    # Accepted token length per sequence in a batch in CPU.
    accept_length_per_req_cpu: List[int]
    # Accepted indices from logits_output.next_token_logits
    accepted_indices: torch.Tensor