sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,702 @@
1
+ import contextlib
2
+ import logging
3
+ import time
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ from torch.cuda import Stream as CudaStream
8
+
9
+ from sglang.srt.environ import envs
10
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch
11
+ from sglang.srt.managers.scheduler import GenerationBatchResult
12
+ from sglang.srt.managers.tp_worker import TpModelWorker
13
+ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
14
+ from sglang.srt.server_args import ServerArgs
15
+ from sglang.srt.speculative.base_spec_worker import BaseDraftWorker, BaseSpecWorker
16
+ from sglang.srt.speculative.draft_utils import DraftBackendFactory
17
+ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
18
+ EAGLEDraftCudaGraphRunner,
19
+ )
20
+ from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
21
+ EAGLEDraftExtendCudaGraphRunner,
22
+ )
23
+ from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
24
+ from sglang.srt.speculative.eagle_info_v2 import (
25
+ assign_extend_cache_locs,
26
+ fill_accepted_out_cache_loc,
27
+ fill_new_verified_id,
28
+ select_top_k_tokens_tmp,
29
+ )
30
+ from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient
31
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
32
+ from sglang.srt.speculative.spec_utils import (
33
+ detect_nan,
34
+ draft_tp_context,
35
+ load_token_map,
36
+ )
37
+ from sglang.srt.utils.common import (
38
+ empty_context,
39
+ fast_topk,
40
+ get_available_gpu_memory,
41
+ next_power_of_2,
42
+ )
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
def _get_plan_stream(
    device: str,
) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]:
    """Return an optional side stream plus a context manager that activates it.

    Args:
        device: Device string handed to ``torch.get_device_module`` to pick
            the backend module that owns the stream.

    Returns:
        ``(stream, stream_ctx)`` when ``SGLANG_ENABLE_OVERLAP_PLAN_STREAM`` is
        enabled, otherwise ``(None, contextlib.nullcontext())`` so callers can
        always write ``with stream_ctx:``.
    """
    if not envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
        return None, contextlib.nullcontext()

    # Take both the stream and its context manager from the same device
    # module. Previously the stream came from the device module while the
    # context was always `torch.cuda.stream`, which mismatches on any
    # non-CUDA device. For CUDA devices this is the same call as before.
    device_module = torch.get_device_module(device)
    plan_stream: CudaStream = device_module.Stream()
    plan_stream_ctx = device_module.stream(plan_stream)
    return plan_stream, plan_stream_ctx
56
+
57
+
58
+ class EagleDraftWorker(BaseDraftWorker):
59
def __init__(
    self,
    server_args: ServerArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: int,
    moe_ep_rank: int,
    nccl_port: int,
    target_worker: TpModelWorker,
):
    """Create the draft ``TpModelWorker`` and its attention backends / cuda graphs.

    The constructor stores its arguments, builds the draft worker on top of
    the memory pools returned by ``target_worker.get_memory_pool()``, wires
    the token map and lm_head, and finally initialises the attention
    backends and cuda graphs.
    """
    # Keep every constructor argument on the instance.
    self.target_worker = target_worker
    self.nccl_port = nccl_port
    self.moe_ep_rank = moe_ep_rank
    self.dp_rank = dp_rank
    self.tp_rank = tp_rank
    self.gpu_id = gpu_id
    self.server_args = server_args

    # Frequently used speculative-decoding settings.
    self.device = server_args.device
    self.topk = server_args.speculative_eagle_topk
    self.speculative_num_steps = server_args.speculative_num_steps
    self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
    self.speculative_algorithm = SpeculativeAlgorithm.from_string(
        server_args.speculative_algorithm
    )

    # Class-level constant on EagleDraftInput (shared by all instances).
    EagleDraftInput.ALLOC_LEN_PER_DECODE = max(
        self.speculative_num_steps * self.topk,
        self.speculative_num_draft_tokens,
    )

    # Cuda graphs must not be captured while `TpModelWorker` is being
    # constructed; they are captured afterwards by init_cuda_graphs().
    # Remember the caller's setting so it can be put back below.
    saved_disable_cuda_graph = server_args.disable_cuda_graph
    server_args.disable_cuda_graph = True

    # Share the allocator with the target worker.
    # Draft and target worker own their own KV cache pools.
    memory_pool = target_worker.get_memory_pool()
    self.req_to_token_pool, self.token_to_kv_pool_allocator = memory_pool

    with empty_context():
        # Build the draft worker.
        self.draft_worker = TpModelWorker(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=0,  # FIXME
            dp_rank=dp_rank,
            moe_ep_rank=moe_ep_rank,
            nccl_port=nccl_port,
            is_draft_worker=True,
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )

    # Shorthand for the draft worker's model runner.
    self.draft_runner = self.draft_worker.model_runner

    # NOTE(review): init_token_map() may write
    # server_args.json_model_override_args, but the draft TpModelWorker has
    # already been constructed at this point -- confirm the override is
    # still picked up.
    self.init_token_map()
    self.init_lm_head()

    # Put the caller's cuda graph setting back before the attention
    # backends and cuda graphs are initialised.
    self.draft_runner.server_args.disable_cuda_graph = saved_disable_cuda_graph

    if server_args.enable_dp_attention:
        self.draft_tp_context = draft_tp_context
    else:
        self.draft_tp_context = empty_context

    with self.draft_tp_context(self.draft_runner.tp_group):
        self.init_attention_backend()
        self.init_cuda_graphs()

    self.tree_mask_mode = TreeMaskMode.FULL_MASK

    self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
135
+
136
def init_token_map(self):
    """Populate ``self.hot_token_id`` from the configured speculative token map.

    For EAGLE3 a user-supplied token map is ignored (with a warning) and
    ``hot_token_id`` is left as ``None``. Otherwise, when a token map is
    configured, it is loaded and ``server_args.json_model_override_args``
    is set to carry its length as ``hot_vocab_size``.
    """
    token_map_configured = self.server_args.speculative_token_map is not None

    if self.speculative_algorithm.is_eagle3():
        if token_map_configured:
            logger.warning(
                "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
            )
        self.hot_token_id = None
        return

    if not token_map_configured:
        self.hot_token_id = None
        return

    # Load the hot token ids and tell the model config how many there are.
    self.hot_token_id = load_token_map(self.server_args.speculative_token_map)
    self.server_args.json_model_override_args = (
        f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
    )
151
+
152
def init_lm_head(self):
    """Hand the target model's embedding (and possibly lm_head) to the draft model.

    EAGLE3 draft models receive only the embedding unless they set a truthy
    ``load_lm_head_from_target``; their own ``hot_token_id`` (if any) is
    moved to the embedding's device and stored on ``self``. For every other
    algorithm both embedding and head are shared, with the head first
    cloned and indexed by ``self.hot_token_id`` when one is set.
    """
    embed, head = self.target_worker.model_runner.model.get_embed_and_head()
    draft_model = self.draft_runner.model

    if self.speculative_algorithm.is_eagle3():
        # most cases EAGLE3 models don't share lm_head
        # but some models (e.g. nvidia/gpt-oss-120b-Eagle3) shares
        if getattr(draft_model, "load_lm_head_from_target", False):
            draft_model.set_embed_and_head(embed, head)
        else:
            draft_model.set_embed(embed)

        # Pick up the hot token ids shipped with the draft model.
        model_hot_ids = draft_model.hot_token_id
        if model_hot_ids is not None:
            self.hot_token_id = model_hot_ids.to(embed.device)
        return

    if self.hot_token_id is not None:
        # Work on a clone so the target's head is left untouched, then keep
        # only the entries selected by the hot token ids.
        head = head.clone()
        self.hot_token_id = self.hot_token_id.to(head.device)
        head.data = head.data[self.hot_token_id]

    # Share the embedding and lm_head
    draft_model.set_embed_and_head(embed, head)
179
+
180
def init_attention_backend(self):
    """Create the draft decode and draft-extend attention backends.

    Both backends come from a ``DraftBackendFactory`` built from the server
    args, the draft runner, ``topk`` and ``speculative_num_steps``. The
    decode backend is also attached to the draft runner.
    """
    self.has_prefill_wrapper_verify = False
    self.draft_extend_attn_backend = None

    factory = DraftBackendFactory(
        self.server_args,
        self.draft_runner,
        self.topk,
        self.speculative_num_steps,
    )

    # Decode attention backend.
    decode_backend = factory.create_decode_backend()
    self.draft_attn_backend = decode_backend

    # Draft extend attention backend (respects speculative_attention_mode setting).
    self.draft_extend_attn_backend = factory.create_draft_extend_backend()

    self.draft_runner.draft_attn_backend = self.draft_attn_backend
    self.tree_mask_mode = TreeMaskMode.FULL_MASK
203
+
204
def init_cuda_graphs(self):
    """Capture cuda graphs.

    Both runner attributes are reset to ``None`` first. Nothing is captured
    when ``server_args.disable_cuda_graph`` is set. Otherwise the draft
    graph is captured only for more than one speculative step, and the
    draft-extend graph only when a draft-extend attention backend exists.
    Available GPU memory and elapsed time are logged around each capture.
    """
    self.cuda_graph_runner = None
    self.cuda_graph_runner_for_draft_extend = None

    if self.server_args.disable_cuda_graph:
        return

    # Draft graph.
    if self.speculative_num_steps > 1:
        start = time.perf_counter()
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={mem_before:.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - start:.2f} s. mem usage={(mem_before - mem_after):.2f} GB. avail mem={mem_after:.2f} GB."
        )

    # Draft-extend graph.
    if self.draft_extend_attn_backend:
        start = time.perf_counter()
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={mem_before:.2f} GB"
        )
        self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
            self
        )
        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - start:.2f} s. mem usage={(mem_before - mem_after):.2f} GB. avail mem={mem_after:.2f} GB."
        )
239
+
240
def draft(self, model_worker_batch: ModelWorkerBatch):
    """Run one draft pass and package the result for target verification.

    Args:
        model_worker_batch: Batch whose ``spec_info`` is the current
            ``EagleDraftInput``.

    Returns:
        An ``EagleVerifyInput`` holding the draft tokens, the tree mask,
        positions and the retrieve-index tensors produced by
        ``build_tree_kernel_efficient``.
    """
    draft_input: EagleDraftInput = model_worker_batch.spec_info
    forward_batch, can_cuda_graph = draft_input.prepare_for_v2_draft(
        self.req_to_token_pool,
        model_worker_batch,
        self.cuda_graph_runner,
        self.draft_runner,
        self.topk,
        self.speculative_num_steps,
    )

    # Run draft
    if can_cuda_graph:
        draft_result = self.cuda_graph_runner.replay(forward_batch)
    else:
        if self.speculative_num_steps > 1:
            # Skip attention backend init for 1-step draft,
            # `draft_forward` only does sample in this case.
            self.draft_attn_backend.init_forward_metadata(forward_batch)
        draft_result = self.draft_forward(forward_batch)
    parent_list, top_scores_index, draft_tokens = draft_result

    # Build tree mask
    # Directly write to cuda graph buffers for verify attn
    target_attn_backend = self.target_worker.model_runner.attn_backend
    tree_mask_buf, position_buf = (
        target_attn_backend.get_verify_buffers_to_fill_after_draft()
    )

    tree_outputs = build_tree_kernel_efficient(
        draft_input.verified_id,
        parent_list,
        top_scores_index,
        draft_tokens,
        model_worker_batch.seq_lens,
        model_worker_batch.seq_lens_sum,
        self.topk,
        self.speculative_num_steps,
        self.speculative_num_draft_tokens,
        self.tree_mask_mode,
        tree_mask_buf,
        position_buf,
    )
    (
        tree_mask,
        position,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    ) = tree_outputs

    return EagleVerifyInput(
        draft_token=draft_tokens,
        custom_mask=tree_mask,
        positions=position,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        retrive_cum_len=None,
        spec_steps=self.speculative_num_steps,
        topk=self.topk,
        draft_token_num=self.speculative_num_draft_tokens,
        capture_hidden_mode=None,
        seq_lens_sum=None,
        seq_lens_cpu=None,
    )
308
+
309
def draft_forward(self, forward_batch: ForwardBatch):
    """Run the multi-step draft loop without cuda graphs.

    Each step selects top-k candidates via ``select_top_k_tokens_tmp`` and
    records their scores, tokens and parents. Every step except the last
    then runs the draft model forward to get the next distribution.

    Args:
        forward_batch: Batch whose ``spec_info`` carries ``topk_p``,
            ``topk_index`` and ``hidden_states`` from the previous pass.
            It is mutated in place on each step (input ids, cache
            locations, positions, attention backend).

    Returns:
        ``(parent_list, top_scores_index, draft_tokens)``.
    """
    # Parse args
    spec_info: EagleDraftInput = forward_batch.spec_info
    topk_p = spec_info.topk_p
    topk_index = spec_info.topk_index
    hidden_states = spec_info.hidden_states
    if self.hot_token_id is not None:
        topk_index = self.hot_token_id[topk_index]

    # Regroup the cache locations so they can be indexed per step:
    # (batch, topk, steps) -> (steps, batch * topk).
    cache_loc_by_step = (
        forward_batch.out_cache_loc.reshape(
            forward_batch.batch_size, self.topk, self.speculative_num_steps
        )
        .permute((2, 0, 1))
        .reshape(self.speculative_num_steps, -1)
    )

    # Per-step results
    step_scores: List[torch.Tensor] = []
    step_tokens: List[torch.Tensor] = []
    step_parents: List[torch.Tensor] = []

    # Forward multiple steps
    scores = None
    for step in range(self.speculative_num_steps):
        input_ids, hidden_states, scores, tree_info = select_top_k_tokens_tmp(
            step, topk_p, topk_index, hidden_states, scores, self.topk
        )
        step_scores.append(tree_info[0])
        step_tokens.append(tree_info[1])
        step_parents.append(tree_info[2])

        # We don't need to run the last forward. we get 1 token from draft
        # prefill and (#spec steps - 1) tokens here
        if step == self.speculative_num_steps - 1:
            break

        # Set inputs
        forward_batch.input_ids = input_ids
        forward_batch.out_cache_loc = cache_loc_by_step[step]
        forward_batch.positions.add_(1)
        forward_batch.attn_backend = self.draft_attn_backend.attn_backends[step]
        spec_info.hidden_states = hidden_states

        # Run forward
        logits_output = self.draft_runner.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )
        if self.server_args.enable_nan_detection:
            detect_nan(logits_output)
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
        if self.hot_token_id is not None:
            topk_index = self.hot_token_id[topk_index]
        hidden_states = logits_output.hidden_states

    # Organize the results
    # b, n, topk; n= 1 + (num_steps-1) * self.topk
    flat_scores = torch.cat(step_scores, dim=1).flatten(1)
    # b, (self.topk + (num_steps-1) * self.topk)
    all_tokens = torch.cat(step_tokens, dim=1)

    best = torch.topk(flat_scores, self.speculative_num_draft_tokens - 1, dim=-1)
    top_scores_index = torch.sort(best.indices).values
    draft_tokens = torch.gather(all_tokens, index=top_scores_index, dim=1)

    if len(step_parents) > 1:
        # The parents recorded on the final step are left out.
        parent_list = torch.cat(step_parents[:-1], dim=1)
    else:
        first_parents = step_parents[0]
        parent_list = torch.empty(
            first_parents.shape[0], 0, device=first_parents.device
        )

    return parent_list, top_scores_index, draft_tokens
387
+
388
def draft_extend(self):
    """Placeholder that performs no work and returns ``None``."""
    return None
390
+
391
def _draft_extend_for_prefill(
    self,
    batch: ModelWorkerBatch,
    target_hidden_states: torch.Tensor,
    next_token_ids: torch.Tensor,
):
    """
    Run draft model extend to correctly fill the KV cache.

    ``batch.input_ids`` is rewritten in place: each request's segment is
    shifted left by one token and the request's next token id is appended.
    ``batch.spec_info`` is replaced with the freshly built draft input.

    Args:
        batch: The batch to run.
        target_hidden_states: Hidden states from the target model forward
        next_token_ids: Next token ids generated from the target forward.

    Returns:
        The ``EagleDraftInput`` for the next draft step, with ``topk_p``,
        ``topk_index`` and ``hidden_states`` taken from the draft forward.
    """
    # Construct input_ids
    offset = 0
    for req_idx, seg_len in enumerate(batch.extend_seq_lens):
        seg_end = offset + seg_len
        segment = batch.input_ids[offset:seg_end]
        shifted = torch.cat((segment[1:], next_token_ids[req_idx].reshape(1)))
        batch.input_ids[offset:seg_end] = shifted
        offset = seg_end

    # Construct spec_info
    next_draft_input = EagleDraftInput(
        hidden_states=target_hidden_states,
        verified_id=next_token_ids,
        new_seq_lens=batch.seq_lens,
        allocate_lens=batch.seq_lens,
    )
    batch.spec_info = next_draft_input

    # Run forward
    forward_batch = ForwardBatch.init_new(batch, self.draft_runner)
    logits_output, _ = self.draft_runner.forward(forward_batch)

    # Update spec_info for the next draft step
    probs = torch.softmax(logits_output.next_token_logits, dim=-1)
    top_p, top_index = fast_topk(probs, self.topk, dim=-1)
    next_draft_input.topk_p = top_p
    next_draft_input.topk_index = top_index
    next_draft_input.hidden_states = logits_output.hidden_states
    return next_draft_input
434
+
435
def _draft_extend_for_decode(
    self, batch: ModelWorkerBatch, batch_result: GenerationBatchResult
):
    """Run the draft-extend pass after a verify step.

    The draft model is run over the verified batch. For each request, the
    row at its last accepted position is picked from the logits and hidden
    states. The resulting top-k probabilities, indices and hidden states
    are written onto ``batch_result.next_draft_input``.

    Args:
        batch: The batch that was just verified.
        batch_result: Result of the verify step; supplies the target hidden
            states, ``next_token_ids``, ``accept_lens`` and the
            ``next_draft_input`` that is updated in place.
    """
    # Batch 2: Draft extend
    draft_input = EagleDraftInput(
        hidden_states=batch_result.logits_output.hidden_states,
    )

    # Flat index of each request's last accepted token, given that every
    # request occupies `speculative_num_draft_tokens` consecutive rows.
    row_starts = (
        torch.arange(len(batch.seq_lens), device=self.device)
        * self.speculative_num_draft_tokens
    )
    select_index = row_starts + batch_result.accept_lens - 1

    # Prepare for draft extend in a separate stream
    with self.plan_stream_ctx:
        forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
            batch,
            batch_result.next_token_ids,
            self.speculative_num_draft_tokens,
            self.draft_runner,
        )

    if self.plan_stream:
        torch.cuda.current_stream().wait_stream(self.plan_stream)

    # Run draft extend batch in the main compute stream
    draft_logits_output = self.draft_runner.model.forward(
        forward_batch.input_ids, forward_batch.positions, forward_batch
    )

    # Reorganize the spec info for the next batch
    selected_logits = draft_logits_output.next_token_logits[select_index]
    draft_logits_output.next_token_logits = selected_logits
    selected_hidden = draft_logits_output.hidden_states[select_index]
    draft_logits_output.hidden_states = selected_hidden

    probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
    ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
    ret_hidden_states = draft_logits_output.hidden_states

    # Construct the return values
    next_draft_input = batch_result.next_draft_input
    next_draft_input.topk_p = ret_topk_p
    next_draft_input.topk_index = ret_topk_index
    next_draft_input.hidden_states = ret_hidden_states
488
+
489
+
490
+ class EAGLEWorkerV2(BaseSpecWorker):
491
def __init__(
    self,
    server_args: ServerArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    moe_ep_rank: int,
    nccl_port: int,
    target_worker: TpModelWorker,
):
    """Set up the speculative worker around an existing target worker.

    Copies the speculative settings from ``server_args``, takes the memory
    pools from the target worker, and builds the ``EagleDraftWorker``.

    Note: ``server_args.context_length`` is overwritten in place with the
    target model's context length before the draft worker is created.
    """
    # Parse arguments
    self.server_args = server_args
    self.gpu_id = gpu_id
    self.device = server_args.device
    self.page_size = server_args.page_size
    self.topk = server_args.speculative_eagle_topk
    self.speculative_num_steps = server_args.speculative_num_steps
    self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
    self.enable_nan_detection = server_args.enable_nan_detection
    self._target_worker = target_worker
    self.speculative_algorithm = SpeculativeAlgorithm.from_string(
        server_args.speculative_algorithm
    )

    memory_pools = target_worker.get_memory_pool()
    self.req_to_token_pool, self.token_to_kv_pool_allocator = memory_pools

    # Override the context length of the draft model to be the same as the target model.
    target_context_len = target_worker.model_runner.model_config.context_len
    server_args.context_length = target_context_len

    self._draft_worker = EagleDraftWorker(
        server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker
    )

    # Some dummy tensors
    self.num_new_pages_per_topk = torch.empty(
        (), dtype=torch.int64, device=self.device
    )
    self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)

    plan_stream_pair = _get_plan_stream(self.device)
    self.plan_stream, self.plan_stream_ctx = plan_stream_pair
533
+
534
@property
def target_worker(self):
    """The target worker stored in ``_target_worker``."""
    worker = self._target_worker
    return worker
537
+
538
@property
def draft_worker(self):
    """The draft worker stored in ``_draft_worker``."""
    worker = self._draft_worker
    return worker
541
+
542
def clear_cache_pool(self):
    """Do nothing.

    The allocator and kv cache pool are shared with the target worker and
    are cleared in the scheduler.
    """
    return None
545
+
546
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
    """Run one speculative generation step for the batch.

    Non-decode batches run the target prefill followed by the draft
    prefill. Decode batches run draft, then target verify, then draft
    extend.

    Args:
        model_worker_batch: The batch to run. Its ``spec_info`` and
            ``capture_hidden_mode`` are modified in place.

    Returns:
        The batch output. On the prefill path its ``next_draft_input`` is
        set from the draft prefill; on the decode path it is the result of
        ``verify``.
    """
    if not model_worker_batch.forward_mode.is_decode():
        # Target prefill
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        batch_output = self.target_worker.forward_batch_generation(
            model_worker_batch
        )

        # Draft prefill
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
        next_draft_input = self.draft_worker._draft_extend_for_prefill(
            model_worker_batch,
            batch_output.logits_output.hidden_states,
            batch_output.next_token_ids,
        )
        batch_output.next_draft_input = next_draft_input
        return batch_output

    # Decode: draft -> verify -> draft extend
    draft_input: EagleDraftInput = model_worker_batch.spec_info
    assert draft_input.is_draft_input()

    verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
    assert verify_input.is_verify_input()
    model_worker_batch.spec_info = verify_input

    batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
    self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
    return batch_output
571
+
572
def verify(
    self,
    batch: ModelWorkerBatch,
    cur_allocate_lens: torch.Tensor,
):
    """Verify the draft tokens with the target model and sample.

    Args:
        batch: Batch whose ``spec_info`` is the ``EagleVerifyInput`` to check.
        cur_allocate_lens: Passed through unchanged to the next draft input
            and to the returned result.

    Returns:
        A ``GenerationBatchResult`` with the target logits, the predicted
        token ids, the accept lengths and the ``EagleDraftInput`` for the
        next step.
    """
    # Since batch.seq_lens is allocated in another stream, we need
    # record_stream() to prevent pytorch gc and reuse the gpu memory
    # while forward_stream is still running.
    batch.seq_lens.record_stream(torch.cuda.current_stream())

    # Parse args
    verify_input: EagleVerifyInput = batch.spec_info
    bs = len(batch.seq_lens)

    # Batch 1: Target verify
    # Prepare for target verify in a separate stream
    with self.plan_stream_ctx:
        prepared = verify_input.prepare_for_v2_verify(
            self.req_to_token_pool,
            batch,
            self.target_worker,
        )
        verify_forward_batch, can_run_cuda_graph = prepared

    # Correct some buffers due to the overlap plan
    if self.plan_stream:
        torch.cuda.current_stream().wait_stream(self.plan_stream)

        # Some values such as custom_mask and position depend on the output of draft,
        # so the previous plan step used the wrong values. Here, we need to run the related
        # computation again to update them to the correct values.
        target_runner = self.target_worker.model_runner
        if can_run_cuda_graph:
            graph_bs = target_runner.graph_runner.bs
        else:
            graph_bs = None
        target_runner.attn_backend.update_verify_buffers_to_fill_after_draft(
            verify_input,
            graph_bs,
        )

    # Run target verify batch in the main compute stream
    forward_batch_output = self.target_worker.forward_batch_generation(
        model_worker_batch=None,
        forward_batch=verify_forward_batch,
        is_verify=True,
        skip_attn_backend_init=True,
    )
    logits_output = forward_batch_output.logits_output

    # Sample
    if self.enable_nan_detection:
        detect_nan(logits_output)
    predict, accept_length, accept_index = verify_input.sample(batch, logits_output)
    new_seq_lens = batch.seq_lens + accept_length
    verify_done = torch.cuda.Event()
    verify_done.record()

    all_verified_id = predict[accept_index]
    verified_id = torch.empty_like(accept_length, dtype=torch.int32)
    # Kernel launched with one program per request; fills `verified_id`.
    fill_new_verified_id[(bs,)](
        all_verified_id,
        accept_length,
        verified_id,
        self.speculative_num_draft_tokens,
    )

    # Construct the next draft input
    next_draft_input = EagleDraftInput(
        verified_id=verified_id,
        new_seq_lens=new_seq_lens,
        allocate_lens=cur_allocate_lens,
        verify_done=verify_done,
    )

    return GenerationBatchResult(
        logits_output=logits_output,
        next_token_ids=predict,
        can_run_cuda_graph=can_run_cuda_graph,
        next_draft_input=next_draft_input,
        accept_lens=accept_length,
        allocate_lens=cur_allocate_lens,
    )
659
+
660
def move_accepted_tokens_to_target_kvcache(
    self,
    batch: ModelWorkerBatch,
    accept_index: torch.Tensor,
    accept_length: torch.Tensor,
):
    """
    Move accepted tokens to the target KV cache.

    Two int64 index buffers of ``bs * speculative_num_draft_tokens``
    elements are filled by the ``assign_extend_cache_locs`` and
    ``fill_accepted_out_cache_loc`` kernels and then handed to the KV
    cache's ``move_kv_cache``.

    Args:
        batch: The batch to run.
        accept_index: The index of the accepted tokens.
        accept_length: The length of the accepted tokens.
    """
    bs = len(batch.seq_lens)
    size = bs * self.speculative_num_draft_tokens
    req_to_token = self.req_to_token_pool.req_to_token

    # Zero-initialised destination / source location buffers.
    tgt_cache_loc = torch.zeros(size, dtype=torch.int64, device=self.device)
    accepted_out_cache_loc = torch.zeros(
        size, dtype=torch.int64, device=self.device
    )

    # One program per request.
    assign_extend_cache_locs[(bs,)](
        batch.req_pool_indices,
        req_to_token,
        batch.seq_lens,
        batch.seq_lens + accept_length,
        tgt_cache_loc,
        req_to_token.shape[1],
        next_power_of_2(bs),
    )
    # One program per draft-token slot.
    fill_accepted_out_cache_loc[(size,)](
        accept_index,
        batch.out_cache_loc,
        accepted_out_cache_loc,
        next_power_of_2(size),
    )

    kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
    kv_cache.move_kv_cache(tgt_cache_loc, accepted_out_cache_loc)