sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from copy import copy
3
3
  from dataclasses import dataclass
4
- from typing import List, Optional, Tuple
4
+ from typing import ClassVar, List, Optional, Tuple
5
5
 
6
6
  import torch
7
7
  import torch.nn.functional as F
@@ -10,23 +10,30 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
10
10
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
11
11
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
12
12
  from sglang.srt.layers.sampler import apply_custom_logit_processor
13
- from sglang.srt.managers.schedule_batch import (
14
- ScheduleBatch,
13
+ from sglang.srt.managers.overlap_utils import FutureIndices
14
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
15
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
16
+ from sglang.srt.mem_cache.common import (
17
+ alloc_paged_token_slots_extend,
18
+ alloc_token_slots,
15
19
  get_last_loc,
16
- global_server_args_dict,
17
20
  )
18
- from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
19
21
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
22
+ from sglang.srt.server_args import get_global_server_args
23
+ from sglang.srt.speculative.eagle_info_v2 import (
24
+ EagleDraftInputV2Mixin,
25
+ EagleVerifyInputV2Mixin,
26
+ )
20
27
  from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
21
28
  from sglang.srt.speculative.spec_utils import (
22
29
  SIMULATE_ACC_LEN,
23
30
  TREE_SPEC_KERNEL_AVAILABLE,
24
- _generate_simulated_accept_index,
25
31
  align_evict_mask_to_page_size,
26
32
  assign_req_to_token_pool,
27
33
  create_accept_length_filter,
28
34
  create_extend_after_decode_spec_info,
29
35
  filter_finished_cache_loc_kernel,
36
+ generate_simulated_accept_index,
30
37
  get_src_tgt_cache_loc,
31
38
  get_target_cache_loc,
32
39
  )
@@ -46,7 +53,7 @@ logger = logging.getLogger(__name__)
46
53
 
47
54
 
48
55
  @dataclass
49
- class EagleVerifyInput(SpecInput):
56
+ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
50
57
  draft_token: torch.Tensor
51
58
  custom_mask: torch.Tensor
52
59
  positions: torch.Tensor
@@ -100,7 +107,10 @@ class EagleVerifyInput(SpecInput):
100
107
  batch.input_ids = self.draft_token
101
108
 
102
109
  if page_size == 1:
103
- batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
110
+ batch.out_cache_loc = alloc_token_slots(
111
+ batch.tree_cache,
112
+ len(batch.input_ids),
113
+ )
104
114
  end_offset = batch.seq_lens + self.draft_token_num
105
115
  else:
106
116
  prefix_lens = batch.seq_lens
@@ -112,7 +122,8 @@ class EagleVerifyInput(SpecInput):
112
122
  batch.req_pool_indices,
113
123
  prefix_lens,
114
124
  )
115
- batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
125
+ batch.out_cache_loc = alloc_paged_token_slots_extend(
126
+ batch.tree_cache,
116
127
  prefix_lens,
117
128
  prefix_lens_cpu,
118
129
  end_offset,
@@ -235,7 +246,10 @@ class EagleVerifyInput(SpecInput):
235
246
  )
236
247
 
237
248
  # Apply penalty
238
- if sampling_info.penalizer_orchestrator.is_required:
249
+ if (
250
+ sampling_info.penalizer_orchestrator.is_required
251
+ or sampling_info.logit_bias is not None
252
+ ):
239
253
  # This is a relaxed version of penalties for speculative decoding.
240
254
  linear_penalty = torch.zeros(
241
255
  (bs, logits_output.next_token_logits.shape[1]),
@@ -322,18 +336,14 @@ class EagleVerifyInput(SpecInput):
322
336
  uniform_samples_for_final_sampling=coins_for_final_sampling,
323
337
  target_probs=target_probs,
324
338
  draft_probs=draft_probs,
325
- threshold_single=global_server_args_dict[
326
- "speculative_accept_threshold_single"
327
- ],
328
- threshold_acc=global_server_args_dict[
329
- "speculative_accept_threshold_acc"
330
- ],
339
+ threshold_single=get_global_server_args().speculative_accept_threshold_single,
340
+ threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
331
341
  deterministic=True,
332
342
  )
333
343
 
334
344
  if SIMULATE_ACC_LEN > 0.0:
335
345
  # Do simulation
336
- accept_index = _generate_simulated_accept_index(
346
+ accept_index = generate_simulated_accept_index(
337
347
  accept_index=accept_index,
338
348
  predict=predict, # mutable
339
349
  accept_length=accept_length, # mutable
@@ -377,6 +387,9 @@ class EagleVerifyInput(SpecInput):
377
387
  else:
378
388
  unfinished_accept_index.append(accept_index[i])
379
389
  req.spec_verify_ct += 1
390
+ req.spec_accepted_tokens += (
391
+ sum(1 for idx in accept_index_row if idx != -1) - 1
392
+ )
380
393
 
381
394
  if has_finished:
382
395
  accept_length = (accept_index != -1).sum(dim=1) - 1
@@ -563,7 +576,10 @@ class EagleVerifyInput(SpecInput):
563
576
 
564
577
 
565
578
  @dataclass
566
- class EagleDraftInput(SpecInput):
579
+ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
580
+ # Constant: alloc length per decode step
581
+ ALLOC_LEN_PER_DECODE: ClassVar[int] = None
582
+
567
583
  # The inputs for decode
568
584
  # shape: (b, topk)
569
585
  topk_p: torch.Tensor = None
@@ -593,6 +609,12 @@ class EagleDraftInput(SpecInput):
593
609
  seq_lens_for_draft_extend_cpu: torch.Tensor = None
594
610
  req_pool_indices_for_draft_extend: torch.Tensor = None
595
611
 
612
+ # Inputs for V2 overlap worker
613
+ future_indices: Optional[FutureIndices] = None
614
+ allocate_lens: Optional[torch.Tensor] = None
615
+ new_seq_lens: Optional[torch.Tensor] = None
616
+ verify_done: Optional[torch.cuda.Event] = None
617
+
596
618
  def __post_init__(self):
597
619
  super().__init__(SpecInputType.EAGLE_DRAFT)
598
620
 
@@ -698,6 +720,11 @@ class EagleDraftInput(SpecInput):
698
720
  return kv_indices, cum_kv_seq_len, qo_indptr, None
699
721
 
700
722
  def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
723
+ if self.future_indices is not None:
724
+ self.future_indices.indices = self.future_indices.indices[new_indices]
725
+ self.allocate_lens = self.allocate_lens[new_indices]
726
+ return
727
+
701
728
  if has_been_filtered:
702
729
  # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
703
730
  # therefore, we don't need to filter the batch again in scheduler
@@ -717,6 +744,18 @@ class EagleDraftInput(SpecInput):
717
744
  self.verified_id = self.verified_id[new_indices]
718
745
 
719
746
  def merge_batch(self, spec_info: "EagleDraftInput"):
747
+ if self.future_indices is not None:
748
+ assert spec_info.future_indices is not None
749
+ self.future_indices = FutureIndices(
750
+ indices=torch.cat(
751
+ [self.future_indices.indices, spec_info.future_indices.indices]
752
+ )
753
+ )
754
+ self.allocate_lens = torch.cat(
755
+ [self.allocate_lens, spec_info.allocate_lens]
756
+ )
757
+ return
758
+
720
759
  if self.hidden_states is None:
721
760
  self.hidden_states = spec_info.hidden_states
722
761
  self.verified_id = spec_info.verified_id
@@ -0,0 +1,458 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
12
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
13
+ from sglang.srt.mem_cache.common import (
14
+ alloc_paged_token_slots_extend,
15
+ alloc_token_slots,
16
+ get_last_loc,
17
+ )
18
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
19
+ from sglang.srt.model_executor.forward_batch_info import (
20
+ CaptureHiddenMode,
21
+ ForwardBatch,
22
+ ForwardMode,
23
+ )
24
+ from sglang.srt.model_executor.model_runner import ModelRunner
25
+ from sglang.srt.server_args import get_global_server_args
26
+ from sglang.srt.speculative.spec_utils import (
27
+ SIMULATE_ACC_LEN,
28
+ generate_simulated_accept_index,
29
+ )
30
+ from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
31
+
32
+ if TYPE_CHECKING:
33
+ from sglang.srt.managers.tp_worker import TpModelWorker
34
+ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
35
+ EAGLEDraftCudaGraphRunner,
36
+ )
37
+ from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
38
+
39
+ if is_cuda():
40
+ from sgl_kernel import (
41
+ top_k_renorm_prob,
42
+ top_p_renorm_prob,
43
+ tree_speculative_sampling_target_only,
44
+ verify_tree_greedy,
45
+ )
46
+ from sgl_kernel.top_k import fast_topk
47
+ elif is_hip():
48
+ from sgl_kernel import verify_tree_greedy
49
+
50
+
51
@triton.jit
def assign_draft_cache_locs_page_size_1(
    req_pool_indices,
    req_to_token,
    seq_lens,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
):
    # Copy, for each request, the `topk * speculative_num_steps` cache
    # locations that start at the request's current seq_len from its row of
    # `req_to_token` into the flat `out_cache_loc` buffer.
    # One program instance per request (grid = (bs,)).
    # NOTE(review): per the name, this assumes page_size == 1, i.e. the slots
    # after seq_len in req_to_token were pre-assigned contiguously — confirm
    # against the caller (prepare_for_v2_draft).
    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)  # request index within the batch

    copy_len = topk * speculative_num_steps
    # Destination slice of out_cache_loc owned by this request.
    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps

    # Copy from req_to_token to out_cache_loc
    kv_start = tl.load(seq_lens + pid)
    # Base pointer of this request's row in the req-to-token table.
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = copy_offset < copy_len
        data = tl.load(token_pool + kv_start + copy_offset, mask=mask)
        tl.store(out_cache_ptr + copy_offset, data, mask=mask)
+
77
+
78
@dataclass
class EagleDraftInputV2Mixin:
    """V2 (overlap-scheduler) draft-side helpers mixed into ``EagleDraftInput``.

    The methods assume ``self`` is an ``EagleDraftInput`` that carries the V2
    fields (``allocate_lens``, ``ALLOC_LEN_PER_DECODE``, ``positions``, ...).
    """

    def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
        """Pre-allocate KV-cache slots for the next speculative decode round.

        Grows every request's allocation to ``seq_lens + ALLOC_LEN_PER_DECODE``
        tokens, records the newly allocated slots in the req-to-token pool,
        and refreshes ``batch.seq_lens_cpu`` / ``batch.seq_lens_sum``.
        """
        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool

        bs = batch.batch_size()

        # TODO(lsyin): implement over-allocation
        # Wait until the previous verify has finished so that seq_lens and
        # allocate_lens are final before sizing the new allocation.
        batch.maybe_wait_verify_done()

        page_size = batch.token_to_kv_pool_allocator.page_size

        if page_size == 1:
            new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
            num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
            out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
        else:
            # Paged allocator: extend from each request's last allocated slot
            # so that the new tokens stay page-aligned.
            last_loc = get_last_loc(
                batch.req_to_token_pool.req_to_token,
                batch.req_pool_indices,
                self.allocate_lens,
            )
            new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
            new_allocate_lens_cpu = new_allocate_lens.cpu()
            allocate_lens_cpu = self.allocate_lens.cpu()
            # Use a tensor reduction; the builtin sum() over a tensor iterates
            # element-by-element at Python speed.
            extend_num_tokens = (
                (new_allocate_lens_cpu - allocate_lens_cpu).sum().item()
            )
            out_cache_loc = alloc_paged_token_slots_extend(
                batch.tree_cache,
                self.allocate_lens,
                allocate_lens_cpu,
                new_allocate_lens,
                new_allocate_lens_cpu,
                last_loc,
                extend_num_tokens,
            )

        # Write the newly allocated locations into the request-to-token map,
        # covering rows [allocate_lens, new_allocate_lens) per request.
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            self.allocate_lens,
            new_allocate_lens,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )
        self.allocate_lens = new_allocate_lens

        # FIXME(lsyin): make this sync optional
        batch.seq_lens_cpu = batch.seq_lens.cpu()
        batch.seq_lens_sum = batch.seq_lens_cpu.sum().item()

    def prepare_for_v2_draft(
        self: EagleDraftInput,
        req_to_token_pool: ReqToTokenPool,
        batch: ModelWorkerBatch,
        cuda_graph_runner: EAGLEDraftCudaGraphRunner,
        draft_model_runner: ModelRunner,
        topk: int,
        num_steps: int,
    ):
        """Build the draft-model forward batch for ``num_steps`` rounds of
        top-``topk`` speculative expansion.

        Returns:
            (forward_batch, can_cuda_graph) where ``can_cuda_graph`` is truthy
            when ``cuda_graph_runner`` exists and can replay this batch.
        """
        bs = len(batch.seq_lens)

        # Assign cache locations
        batch.out_cache_loc = torch.empty(
            (bs * topk * num_steps,),
            dtype=torch.int64,
            device=batch.input_ids.device,
        )
        # FIXME(lsyin): align with the default code path
        assign_draft_cache_locs_page_size_1[(bs,)](
            batch.req_pool_indices,
            req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
            topk,
            num_steps,
        )

        # Get a forward batch; each request contributes `topk` parallel
        # branches, all starting at the same position.
        batch.capture_hidden_mode = CaptureHiddenMode.LAST
        self.positions = batch.seq_lens.repeat_interleave(topk, dim=0)
        forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
        can_cuda_graph = cuda_graph_runner and cuda_graph_runner.can_run(forward_batch)
        return forward_batch, can_cuda_graph

    def prepare_for_extend_to_fill_draft_kvcache(
        self,
        batch: ModelWorkerBatch,
        predict: torch.Tensor,
        num_draft_tokens: int,
        draft_model_runner: Any,
    ):
        """Turn ``batch`` into a DRAFT_EXTEND_V2 forward batch that writes the
        predicted tokens into the draft model's KV cache.

        Mutates ``batch`` in place (input_ids, seq_lens, extend metadata,
        forward mode) and returns the initialized ``ForwardBatch`` with the
        attention-backend metadata already planned.
        """
        # Keep the pre-extend lengths; they become the extend prefix lengths.
        seq_lens_cpu_ = batch.seq_lens_cpu
        extend_num_tokens = len(batch.seq_lens) * num_draft_tokens

        batch.spec_info = self
        batch.input_ids = predict
        batch.seq_lens = batch.seq_lens + num_draft_tokens
        batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
        batch.seq_lens_sum += extend_num_tokens
        batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
        batch.extend_prefix_lens = seq_lens_cpu_.tolist()
        batch.extend_num_tokens = extend_num_tokens
        batch.capture_hidden_mode = CaptureHiddenMode.FULL
        batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
        forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
        draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
        return forward_batch
188
+
189
+
190
@dataclass
class EagleVerifyInputV2Mixin:
    """V2 (overlap-scheduler) verify-side helpers mixed into ``EagleVerifyInput``."""

    def prepare_for_v2_verify(
        self: EagleVerifyInput,
        req_to_token_pool: ReqToTokenPool,
        batch: ModelWorkerBatch,
        target_worker: TpModelWorker,
    ):
        """Build the TARGET_VERIFY forward batch for the draft tokens.

        Allocates ``out_cache_loc`` for all draft tokens, gathers their cache
        locations from the req-to-token pool, then either prepares a cuda-graph
        replay or plans the attention backend metadata.

        Returns:
            (verify_forward_batch, can_run_cuda_graph)
        """
        # Assign cache locations
        bs = len(batch.req_pool_indices)
        batch.input_ids = self.draft_token
        device = batch.input_ids.device
        batch.out_cache_loc = torch.empty(
            (bs * self.draft_token_num,),
            dtype=torch.int64,
            device=device,
        )

        # Gather the cache slots of rows [seq_lens, seq_lens + draft_token_num)
        # of each request into the flat out_cache_loc buffer.
        assign_extend_cache_locs[(bs,)](
            batch.req_pool_indices,
            req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.seq_lens + self.draft_token_num,
            batch.out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )

        # Get a forward batch
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.capture_hidden_mode = CaptureHiddenMode.FULL
        verify_forward_batch = ForwardBatch.init_new(batch, target_worker.model_runner)

        # Run attention backend plan and cuda graph preparation
        can_run_cuda_graph = bool(
            target_worker.model_runner.graph_runner
            and target_worker.model_runner.graph_runner.can_run(verify_forward_batch)
        )
        if can_run_cuda_graph:
            target_worker.model_runner.graph_runner.replay_prepare(verify_forward_batch)
        else:
            target_worker.model_runner.attn_backend.init_forward_metadata(
                verify_forward_batch
            )

        return verify_forward_batch, can_run_cuda_graph

    def sample(
        self: EagleVerifyInput,
        batch: ModelWorkerBatch,
        logits_output: LogitsProcessorOutput,
    ):
        """
        Verify and find accepted tokens based on logits output and batch
        (which contains spec decoding information).

        Returns:
            predict: flat (bs * (spec_steps + 1),) tensor of sampled token ids.
            accept_length: (bs,) count of accepted tokens per request,
                including the bonus token.
            accept_index: (bs, spec_steps + 1) indices into the flat draft
                token buffer; -1 marks rejected slots.
        """
        bs = len(batch.seq_lens)
        sampling_info = batch.sampling_info
        next_token_logits = logits_output.next_token_logits
        device = batch.input_ids.device

        candidates = self.draft_token.reshape(bs, self.draft_token_num)
        predict = torch.zeros(
            (bs * (self.spec_steps + 1),), dtype=torch.int32, device=device
        )
        accept_index = torch.full(
            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=device
        )
        accept_length = torch.empty((bs,), dtype=torch.int32, device=device)

        # Sample tokens
        if sampling_info.is_all_greedy:
            # Greedy path: compare draft candidates against the target model's
            # argmax prediction by walking the draft tree.
            target_predict = torch.argmax(next_token_logits, dim=-1)
            target_predict = target_predict.reshape(bs, self.draft_token_num)

            verify_tree_greedy(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                target_predict=target_predict,
            )
        else:
            # Apply temperature and get target probs
            expanded_temperature = torch.repeat_interleave(
                sampling_info.temperatures, self.draft_token_num, dim=0
            )  # (bs * num_draft_tokens, 1)

            target_probs = F.softmax(
                next_token_logits / expanded_temperature, dim=-1
            )  # (bs * num_draft_tokens, vocab_size)
            # Renormalize under each request's top-k / top-p constraints so the
            # rejection sampler sees the same distribution the target would use.
            target_probs = top_k_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ks, self.draft_token_num, dim=0
                ),
            )  # (bs * num_draft_tokens, vocab_size)
            target_probs = top_p_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ps, self.draft_token_num, dim=0
                ),
            )
            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)

            # This is currently not used
            draft_probs = torch.empty_like(target_probs)

            # coins for rejection sampling
            coins = torch.rand_like(candidates, dtype=torch.float32, device=device)
            # coins for final sampling
            coins_for_final_sampling = torch.rand(
                (bs,), dtype=torch.float32, device=device
            )

            tree_speculative_sampling_target_only(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                uniform_samples=coins,
                uniform_samples_for_final_sampling=coins_for_final_sampling,
                target_probs=target_probs,
                draft_probs=draft_probs,
                threshold_single=get_global_server_args().speculative_accept_threshold_single,
                threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
                deterministic=True,
            )

        if SIMULATE_ACC_LEN > 0:
            # Do simulation: overwrite the real accept results with a simulated
            # acceptance length (debug/benchmark knob).
            accept_index = generate_simulated_accept_index(
                accept_index=accept_index,
                predict=predict,  # mutable
                accept_length=accept_length,  # mutable
                simulate_acc_len=SIMULATE_ACC_LEN,
                bs=bs,
                spec_steps=self.spec_steps,
            )

        # Include the bonus token
        accept_length.add_(1)
        return predict, accept_length, accept_index
339
+
340
+
341
@torch.compile(dynamic=True)
def select_top_k_tokens_tmp(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Pick the top-k branches to expand at draft step ``i``.

    Returns ``(input_ids, hidden_states, scores, tree_info)``, where
    ``tree_info`` carries (branch scores, token indices, tree-position
    indices) used to rebuild the speculative tree later.
    """
    # FIXME(lsyin): remove this duplicate code
    device = hidden_states.device

    if i == 0:
        # First step after extend: each request fans out into `topk` branches
        # that all share the request's single hidden state.
        next_input_ids = topk_index.flatten()
        next_hidden = hidden_states.repeat_interleave(topk, dim=0)
        next_scores = topk_p  # shape: (b, topk)

        positions = torch.arange(-1, topk, dtype=torch.long, device=device)
        tree_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            positions.unsqueeze(0).repeat(topk_p.shape[0], 1),  # (b, topk + 1)
        )
        return next_input_ids, next_hidden, next_scores, tree_info

    # Later decode steps: joint score of a child = parent path score times the
    # child's probability; keep the globally best `topk` of the topk*topk grid.
    joint_scores = scores.unsqueeze(2) * topk_p.reshape(-1, topk, topk)
    best_p, best_idx = fast_topk(
        joint_scores.flatten(start_dim=1), topk, dim=-1
    )  # (b, topk)
    next_scores = best_p  # shape: (b, topk)

    flat_index = topk_index.reshape(-1, topk**2)
    next_input_ids = torch.gather(flat_index, index=best_idx, dim=1).flatten()

    # Map each surviving branch back to its parent branch's hidden state.
    parent_base = torch.arange(
        0, hidden_states.shape[0], step=topk, device=device
    ).repeat_interleave(topk)
    next_hidden = hidden_states[best_idx.flatten() // topk + parent_base, :]

    tree_info = (
        joint_scores,  # shape: (b, topk, topk)
        flat_index,  # shape: (b, topk * topk)
        best_idx + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
    )
    return next_input_ids, next_hidden, next_scores, tree_info
389
+
390
+
391
@triton.jit
def fill_new_verified_id(
    verified_id,
    accept_lens,
    new_verified_id,
    num_draft_tokens: tl.constexpr,
):
    # For each request, copy its last accepted token from the flat
    # `verified_id` buffer (stride `num_draft_tokens` per request) into
    # `new_verified_id[pid]`.  One program per request.
    # NOTE: we cannot fuse any in-place operations of `accept_lens` inside this kernel
    # because this kernel reads accept_lens
    pid = tl.program_id(axis=0)
    accept_length = tl.load(accept_lens + pid)

    # Index of the last accepted token within this request's slice.
    verified_id_idx = num_draft_tokens * pid + accept_length - 1
    verified_id_data = tl.load(verified_id + verified_id_idx)
    tl.store(new_verified_id + pid, verified_id_data)
406
+
407
+
408
@triton.jit
def fill_accepted_out_cache_loc(
    accept_index,
    out_cache_loc,
    accepted_out_cache_loc,
    size_upper: tl.constexpr,
):
    # Compact `out_cache_loc` down to the accepted positions: for slot `pid`,
    # if accept_index[pid] != -1, its cache location is written to the next
    # free slot of `accepted_out_cache_loc`.  One program per candidate slot;
    # `size_upper` is a power-of-2 upper bound on the number of slots.
    pid = tl.program_id(axis=0)
    offset = tl.arange(0, size_upper)

    # Destination = number of accepted entries strictly before `pid`
    # (prefix count of accept_index[j] != -1 for j < pid).
    masks = (tl.load(accept_index + offset, offset < pid, other=-1) != -1).to(tl.int64)
    dst = tl.sum(masks)
    src = tl.load(accept_index + pid)
    if src > -1:
        value = tl.load(out_cache_loc + src)
        tl.store(accepted_out_cache_loc + dst, value)
424
+
425
+
426
@triton.jit
def assign_extend_cache_locs(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    # Gather each request's cache locations for tokens in
    # [start_offset[pid], end_offset[pid]) from its `req_to_token` row into
    # the flat `out_cache_loc` buffer.  One program per request;
    # `bs_upper` is a power-of-2 upper bound on the batch size used for the
    # prefix-sum below.
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    # This request's offset in out_cache_loc = total tokens copied by all
    # preceding requests (masked prefix sum over the batch).
    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    out_offset = tl.sum(end - start, axis=0)

    out_cache_ptr = out_cache_loc + out_offset

    load_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    save_offset = tl.arange(0, BLOCK_SIZE)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = load_offset < kv_end
        data = tl.load(token_pool + load_offset, mask=mask)
        tl.store(out_cache_ptr + save_offset, data, mask=mask)
        load_offset += BLOCK_SIZE
        save_offset += BLOCK_SIZE