sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,138 @@
1
+ import math
2
+ from enum import IntEnum
3
+ from typing import List, Optional
4
+
5
+ import torch
6
+
7
+ from sglang.srt.utils import is_cuda, is_hip
8
+
9
+ if is_cuda() or is_hip():
10
+ from sgl_kernel import (
11
+ build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
12
+ )
13
+
14
+
15
def organize_draft_results(
    score_list: List[torch.Tensor],
    token_list: List[torch.Tensor],
    parents_list: List[torch.Tensor],
    num_draft_token: int,
):
    """Pick the best-scoring draft candidates across all draft steps.

    The per-step scores and tokens are concatenated along dim 1, the
    ``num_draft_token - 1`` highest-scoring candidates of every row are
    selected, and their flat indices are sorted in ascending order before the
    matching tokens are gathered.

    Args:
        score_list: Per-step score tensors. They are concatenated on dim 1 and
            flattened from dim 1 onward, giving one row of scores per batch
            entry.
        token_list: Per-step token tensors, concatenated on dim 1. The
            selected score indices are used directly as column indices into
            this tensor, so both must share the same flat layout.
        parents_list: Per-step parent tensors. Every entry except the last is
            concatenated on dim 1. Must contain at least one tensor.
        num_draft_token: Number of tokens per tree; ``num_draft_token - 1``
            candidates are selected here (``build_tree_kernel_efficient``
            prepends the verified token afterwards).

    Returns:
        A tuple ``(parent_list, top_scores_index, draft_tokens)`` where
        ``top_scores_index`` holds the sorted flat indices of the selected
        candidates and ``draft_tokens`` the tokens found at those indices.
    """
    flat_scores = torch.cat(score_list, dim=1).flatten(1)
    ss_token_list = torch.cat(token_list, dim=1)

    top_scores_index = torch.topk(flat_scores, num_draft_token - 1, dim=-1).indices
    # Ascending index order, not score order, is what gets returned.
    top_scores_index = torch.sort(top_scores_index).values
    draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)

    if len(parents_list) > 1:
        parent_list = torch.cat(parents_list[:-1], dim=1)
    else:
        # Single-step case: there are no parents to forward. Keep the dtype of
        # the input parents so both branches return the same kind of tensor
        # (the previous default was float32, unlike the branch above).
        first_parents = parents_list[0]
        parent_list = torch.empty(
            first_parents.shape[0],
            0,
            dtype=first_parents.dtype,
            device=first_parents.device,
        )

    return parent_list, top_scores_index, draft_tokens
35
+
36
+
37
class TreeMaskMode(IntEnum):
    """Layout of the tree mask used by ``build_tree_kernel_efficient``.

    The integer value is passed through to the tree-building kernel unchanged.
    """

    # Boolean mask sized for the existing sequence plus the verify tokens:
    # seq_lens_sum * num_verify_tokens + num_verify_tokens**2 * bs entries.
    FULL_MASK = 0
    # Boolean mask sized for the verify tokens only:
    # num_verify_tokens**2 * bs entries.
    QLEN_ONLY = 1
    # One packed unsigned integer per verify token: num_verify_tokens * bs entries.
    QLEN_ONLY_BITPACKING = 2
41
+
42
+
43
def build_tree_kernel_efficient(
    verified_id: torch.Tensor,
    parent_list: torch.Tensor,
    top_scores_index: torch.Tensor,
    draft_tokens: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
    tree_mask_buf: Optional[torch.Tensor] = None,
    position_buf: Optional[torch.Tensor] = None,
):
    """Allocate the draft-tree buffers and run the tree-building kernel.

    The verified token of every request is prepended to its draft tokens, the
    tree mask / position / retrieval buffers are prepared, and everything is
    handed to ``sgl_build_tree_kernel_efficient``. The kernel's return value
    is not used; it is expected to fill the buffers in place.

    Args:
        verified_id: Verified token per request; unsqueezed on dim 1 and
            concatenated in front of ``draft_tokens``.
        parent_list: Parent tensor forwarded to the kernel (presumably the
            first value returned by ``organize_draft_results``).
        top_scores_index: Selected candidate indices forwarded to the kernel.
        draft_tokens: Draft tokens per request, without the verified token.
        seq_lens: Sequence length of every request, without draft tokens.
            Its element count is used as the batch size and its device is
            used for every newly allocated buffer.
        seq_lens_sum: Equal to ``sum(seq_lens)``; only used to size the mask
            in ``FULL_MASK`` mode.
        topk: Forwarded to the kernel.
        spec_steps: Forwarded to the kernel.
        num_verify_tokens: Number of tokens per request to verify; determines
            all buffer sizes.
        tree_mask_mode: Mask layout, see ``TreeMaskMode``.
        tree_mask_buf: Optional preallocated mask. When given it is reset in
            place (``0`` for the bit-packed mode, ``True`` otherwise) and
            used instead of a fresh allocation.
        position_buf: Optional preallocated positions buffer, used as is.

    Returns:
        A tuple ``(tree_mask, positions, retrive_index, retrive_next_token,
        retrive_next_sibling, draft_tokens)``; ``draft_tokens`` is the
        flattened tensor that includes the verified tokens.

    Raises:
        NotImplementedError: If ``tree_mask_mode`` is not a known mode.
        ValueError: In bit-packed mode without ``tree_mask_buf`` when
            ``num_verify_tokens`` is smaller than 1.
        IndexError: In bit-packed mode without ``tree_mask_buf`` when
            ``num_verify_tokens`` needs more bits than the widest supported
            packed dtype (more than 32 tokens).
    """
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

    # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
    bs = seq_lens.numel()
    device = seq_lens.device

    # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
    # where each row indicates the attending pattern of each draft token
    # if tree_mask_mode is QLEN_ONLY_BITPACKING, tree_mask: num_draft_token (flattened, packed)
    if tree_mask_buf is not None:
        tree_mask = tree_mask_buf
        if tree_mask_mode in (TreeMaskMode.QLEN_ONLY, TreeMaskMode.FULL_MASK):
            tree_mask.fill_(True)
        elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
            tree_mask.fill_(0)
        else:
            raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
        tree_mask = torch.full(
            (num_verify_tokens * bs * num_verify_tokens,),
            True,
            dtype=torch.bool,
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
        if num_verify_tokens < 1:
            # Same exception type the former math.log2(0) call produced.
            raise ValueError(
                f"num_verify_tokens must be >= 1 for bit-packed tree masks, "
                f"got {num_verify_tokens}"
            )
        # One bit per verify token, rounded up to whole bytes; the dtype index
        # is ceil(log2(bytes)), computed with integers to avoid float rounding.
        bytes_per_row = (num_verify_tokens + 7) // 8
        packed_dtype_idx = (bytes_per_row - 1).bit_length()
        if packed_dtype_idx >= len(packed_dtypes):
            # Previously surfaced as a bare "list index out of range".
            raise IndexError(
                f"num_verify_tokens={num_verify_tokens} does not fit into the "
                f"widest supported packed tree mask dtype (at most 32 tokens)"
            )
        tree_mask = torch.zeros(
            (num_verify_tokens * bs,),
            dtype=packed_dtypes[packed_dtype_idx],
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
        tree_mask = torch.full(
            (
                seq_lens_sum * num_verify_tokens
                + num_verify_tokens * num_verify_tokens * bs,
            ),
            True,
            dtype=torch.bool,
            device=device,
        )
    else:
        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")

    # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
    # One allocation, split along dim 0 into three (bs, num_verify_tokens) views.
    retrive_buf = torch.full(
        (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf

    # position: where each token belongs to
    # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
    # then, positions = [7, 8, 8, 9]
    if position_buf is not None:
        positions = position_buf
    else:
        positions = torch.empty(
            (bs * num_verify_tokens,), device=device, dtype=torch.long
        )

    sgl_build_tree_kernel_efficient(
        parent_list,
        top_scores_index,
        seq_lens,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        spec_steps,
        num_verify_tokens,
        tree_mask_mode,
    )
    return (
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    )
@@ -1,33 +1,27 @@
1
1
  import logging
2
- import os
3
2
  import time
4
- from contextlib import contextmanager
5
3
  from typing import List, Optional, Tuple
6
4
 
7
5
  import torch
8
- from huggingface_hub import snapshot_download
9
6
 
10
- from sglang.srt.distributed import (
11
- GroupCoordinator,
12
- get_tp_group,
13
- patch_tensor_parallel_group,
14
- )
7
+ from sglang.srt.distributed import get_tp_group
15
8
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
16
9
  from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
17
- from sglang.srt.managers.schedule_batch import (
18
- ScheduleBatch,
10
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
11
+ from sglang.srt.managers.scheduler import GenerationBatchResult
12
+ from sglang.srt.managers.tp_worker import TpModelWorker
13
+ from sglang.srt.mem_cache.common import (
14
+ alloc_paged_token_slots_extend,
15
+ alloc_token_slots,
19
16
  get_last_loc,
20
- global_server_args_dict,
21
17
  )
22
- from sglang.srt.managers.tp_worker import TpModelWorker
23
18
  from sglang.srt.model_executor.forward_batch_info import (
24
19
  CaptureHiddenMode,
25
20
  ForwardBatch,
26
- ForwardBatchOutput,
27
21
  ForwardMode,
28
22
  )
29
23
  from sglang.srt.server_args import ServerArgs
30
- from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
24
+ from sglang.srt.speculative.draft_utils import DraftBackendFactory
31
25
  from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
32
26
  EAGLEDraftCudaGraphRunner,
33
27
  )
@@ -39,35 +33,33 @@ from sglang.srt.speculative.eagle_info import (
39
33
  EagleVerifyInput,
40
34
  EagleVerifyOutput,
41
35
  )
36
+ from sglang.srt.speculative.eagle_utils import (
37
+ build_tree_kernel_efficient,
38
+ organize_draft_results,
39
+ )
42
40
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
43
41
  from sglang.srt.speculative.spec_utils import (
44
42
  assign_draft_cache_locs,
43
+ detect_nan,
44
+ draft_tp_context,
45
45
  fast_topk,
46
46
  generate_token_bitmask,
47
+ load_token_map,
47
48
  select_top_k_tokens,
48
49
  )
49
50
  from sglang.srt.utils import (
50
51
  empty_context,
51
52
  get_available_gpu_memory,
52
53
  get_bool_env_var,
53
- is_blackwell,
54
54
  is_cuda,
55
55
  next_power_of_2,
56
56
  )
57
57
 
58
58
  if is_cuda():
59
- from sgl_kernel import segment_packbits
59
+ from sgl_kernel import segment_packbits # noqa: F401
60
60
 
61
61
  logger = logging.getLogger(__name__)
62
- RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
63
-
64
-
65
- @contextmanager
66
- def draft_tp_context(tp_group: GroupCoordinator):
67
- # Draft model doesn't use dp and has its own tp group.
68
- # We disable mscclpp now because it doesn't support 2 comm groups.
69
- with patch_tensor_parallel_group(tp_group):
70
- yield
62
+ SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
71
63
 
72
64
 
73
65
  class EAGLEWorker(TpModelWorker):
@@ -95,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
95
87
  self.speculative_algorithm = SpeculativeAlgorithm.from_string(
96
88
  server_args.speculative_algorithm
97
89
  )
98
- self.padded_static_len = -1
99
90
 
100
91
  # Override the context length of the draft model to be the same as the target model.
101
92
  server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -187,208 +178,22 @@ class EAGLEWorker(TpModelWorker):
187
178
 
188
179
  def init_attention_backend(self):
189
180
  # Create multi-step attn backends and cuda graph runners
190
-
191
- self.has_prefill_wrapper_verify = False
192
- self.draft_extend_attn_backend = None
181
+ draft_backend_factory = DraftBackendFactory(
182
+ self.server_args,
183
+ self.draft_model_runner,
184
+ self.topk,
185
+ self.speculative_num_steps,
186
+ )
193
187
 
194
188
  # Initialize decode attention backend
195
- self.draft_attn_backend = self._create_decode_backend()
189
+ self.draft_attn_backend = draft_backend_factory.create_decode_backend()
196
190
 
197
191
  # Initialize draft extend attention backend (respects speculative_attention_mode setting)
198
- self.draft_extend_attn_backend = self._create_draft_extend_backend()
199
-
200
- self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
201
-
202
- def _create_backend(
203
- self, backend_name: str, backend_map: dict, error_template: str
204
- ):
205
- backend_type = getattr(self.server_args, backend_name)
206
- if backend_type is None:
207
- backend_type = self.server_args.attention_backend
208
-
209
- if backend_type not in backend_map:
210
- raise ValueError(error_template.format(backend_type=backend_type))
211
-
212
- return backend_map[backend_type]()
213
-
214
- def _create_decode_backend(self):
215
- backend_map = {
216
- "flashinfer": self._create_flashinfer_decode_backend,
217
- "triton": self._create_triton_decode_backend,
218
- "aiter": self._create_aiter_decode_backend,
219
- "fa3": self._create_fa3_decode_backend,
220
- "hybrid_linear_attn": (
221
- self._create_fa3_decode_backend
222
- if not is_blackwell()
223
- else self._create_triton_decode_backend
224
- ),
225
- "flashmla": self._create_flashmla_decode_backend,
226
- "trtllm_mha": self._create_trtllm_mha_decode_backend,
227
- "trtllm_mla": self._create_trtllm_mla_decode_backend,
228
- }
229
-
230
- return self._create_backend(
231
- "decode_attention_backend",
232
- backend_map,
233
- "EAGLE is not supported in decode attention backend {backend_type}",
234
- )
235
-
236
- def _create_draft_extend_backend(self):
237
- backend_map = {
238
- "flashinfer": self._create_flashinfer_prefill_backend,
239
- "triton": self._create_triton_prefill_backend,
240
- "aiter": self._create_aiter_prefill_backend,
241
- "fa3": self._create_fa3_prefill_backend,
242
- "hybrid_linear_attn": (
243
- self._create_fa3_prefill_backend
244
- if not is_blackwell()
245
- else self._create_triton_prefill_backend
246
- ),
247
- "flashmla": self._create_flashmla_prefill_backend,
248
- "trtllm_mha": self._create_trtllm_mha_prefill_backend,
249
- "trtllm_mla": self._create_trtllm_mla_prefill_backend,
250
- }
251
- backend_name = (
252
- "decode_attention_backend"
253
- if self.server_args.speculative_attention_mode == "decode"
254
- else "prefill_attention_backend"
255
- )
256
- return self._create_backend(
257
- backend_name,
258
- backend_map,
259
- "EAGLE is not supported in attention backend {backend_type}",
260
- )
261
-
262
- def _create_flashinfer_decode_backend(self):
263
- if not global_server_args_dict["use_mla_backend"]:
264
- from sglang.srt.layers.attention.flashinfer_backend import (
265
- FlashInferMultiStepDraftBackend,
266
- )
267
-
268
- self.has_prefill_wrapper_verify = True
269
- return FlashInferMultiStepDraftBackend(
270
- self.draft_model_runner, self.topk, self.speculative_num_steps
271
- )
272
- else:
273
- from sglang.srt.layers.attention.flashinfer_mla_backend import (
274
- FlashInferMLAMultiStepDraftBackend,
275
- )
276
-
277
- self.has_prefill_wrapper_verify = True
278
- return FlashInferMLAMultiStepDraftBackend(
279
- self.draft_model_runner, self.topk, self.speculative_num_steps
280
- )
281
-
282
- def _create_triton_decode_backend(self):
283
- from sglang.srt.layers.attention.triton_backend import (
284
- TritonMultiStepDraftBackend,
285
- )
286
-
287
- return TritonMultiStepDraftBackend(
288
- self.draft_model_runner, self.topk, self.speculative_num_steps
289
- )
290
-
291
- def _create_aiter_decode_backend(self):
292
- from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
293
-
294
- return AiterMultiStepDraftBackend(
295
- self.draft_model_runner, self.topk, self.speculative_num_steps
296
- )
297
-
298
- def _create_fa3_decode_backend(self):
299
- from sglang.srt.layers.attention.flashattention_backend import (
300
- FlashAttentionMultiStepBackend,
301
- )
302
-
303
- return FlashAttentionMultiStepBackend(
304
- self.draft_model_runner, self.topk, self.speculative_num_steps
305
- )
306
-
307
- def _create_flashmla_decode_backend(self):
308
- from sglang.srt.layers.attention.flashmla_backend import (
309
- FlashMLAMultiStepDraftBackend,
310
- )
311
-
312
- return FlashMLAMultiStepDraftBackend(
313
- self.draft_model_runner, self.topk, self.speculative_num_steps
314
- )
315
-
316
- def _create_trtllm_mha_decode_backend(self):
317
- from sglang.srt.layers.attention.trtllm_mha_backend import (
318
- TRTLLMHAAttnMultiStepDraftBackend,
319
- )
320
-
321
- self.has_prefill_wrapper_verify = True
322
- return TRTLLMHAAttnMultiStepDraftBackend(
323
- self.draft_model_runner, self.topk, self.speculative_num_steps
324
- )
325
-
326
- def _create_trtllm_mla_decode_backend(self):
327
- if not global_server_args_dict["use_mla_backend"]:
328
- raise ValueError(
329
- "trtllm_mla backend requires MLA model (use_mla_backend=True)."
330
- )
331
-
332
- from sglang.srt.layers.attention.trtllm_mla_backend import (
333
- TRTLLMMLAMultiStepDraftBackend,
334
- )
335
-
336
- self.has_prefill_wrapper_verify = True
337
- return TRTLLMMLAMultiStepDraftBackend(
338
- self.draft_model_runner, self.topk, self.speculative_num_steps
339
- )
340
-
341
- def _create_flashinfer_prefill_backend(self):
342
- if not global_server_args_dict["use_mla_backend"]:
343
- from sglang.srt.layers.attention.flashinfer_backend import (
344
- FlashInferAttnBackend,
345
- )
346
-
347
- return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
348
- else:
349
- from sglang.srt.layers.attention.flashinfer_mla_backend import (
350
- FlashInferMLAAttnBackend,
351
- )
352
-
353
- return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
354
-
355
- def _create_triton_prefill_backend(self):
356
- from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
357
-
358
- return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
359
-
360
- def _create_aiter_prefill_backend(self):
361
- from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
362
-
363
- return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
364
-
365
- def _create_fa3_prefill_backend(self):
366
- from sglang.srt.layers.attention.flashattention_backend import (
367
- FlashAttentionBackend,
192
+ self.draft_extend_attn_backend = (
193
+ draft_backend_factory.create_draft_extend_backend()
368
194
  )
369
195
 
370
- return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
371
-
372
- def _create_trtllm_mha_prefill_backend(self):
373
- from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
374
-
375
- return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
376
-
377
- def _create_trtllm_mla_prefill_backend(self):
378
- if not global_server_args_dict["use_mla_backend"]:
379
- raise ValueError(
380
- "trtllm_mla backend requires MLA model (use_mla_backend=True)."
381
- )
382
-
383
- from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
384
-
385
- return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
386
-
387
- def _create_flashmla_prefill_backend(self):
388
- logger.warning(
389
- "flashmla prefill backend is not yet supported for draft extend."
390
- )
391
- return None
196
+ self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
392
197
 
393
198
  def init_cuda_graphs(self):
394
199
  """Capture cuda graphs."""
@@ -399,16 +204,17 @@ class EAGLEWorker(TpModelWorker):
399
204
  return
400
205
 
401
206
  # Capture draft
402
- tic = time.perf_counter()
403
- before_mem = get_available_gpu_memory(self.device, self.gpu_id)
404
- logger.info(
405
- f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
406
- )
407
- self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
408
- after_mem = get_available_gpu_memory(self.device, self.gpu_id)
409
- logger.info(
410
- f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
411
- )
207
+ if self.speculative_num_steps > 1:
208
+ tic = time.perf_counter()
209
+ before_mem = get_available_gpu_memory(self.device, self.gpu_id)
210
+ logger.info(
211
+ f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
212
+ )
213
+ self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
214
+ after_mem = get_available_gpu_memory(self.device, self.gpu_id)
215
+ logger.info(
216
+ f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
217
+ )
412
218
 
413
219
  # Capture extend
414
220
  if self.draft_extend_attn_backend:
@@ -429,7 +235,7 @@ class EAGLEWorker(TpModelWorker):
429
235
  def draft_model_runner(self):
430
236
  return self.model_runner
431
237
 
432
- def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
238
+ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
433
239
  """Run speculative decoding forward.
434
240
 
435
241
  NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -449,7 +255,7 @@ class EAGLEWorker(TpModelWorker):
449
255
  self.forward_draft_extend(
450
256
  batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
451
257
  )
452
- return ForwardBatchOutput(
258
+ return GenerationBatchResult(
453
259
  logits_output=logits_output,
454
260
  next_token_ids=next_token_ids,
455
261
  num_accepted_tokens=0,
@@ -472,7 +278,7 @@ class EAGLEWorker(TpModelWorker):
472
278
  # decode is not finished
473
279
  self.forward_draft_extend_after_decode(batch)
474
280
 
475
- return ForwardBatchOutput(
281
+ return GenerationBatchResult(
476
282
  logits_output=logits_output,
477
283
  next_token_ids=verify_output.verified_id,
478
284
  num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
@@ -513,12 +319,10 @@ class EAGLEWorker(TpModelWorker):
513
319
  # We need the full hidden states to prefill the KV cache of the draft model.
514
320
  model_worker_batch = batch.get_model_worker_batch()
515
321
  model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
516
- forward_batch_output = self.target_worker.forward_batch_generation(
517
- model_worker_batch
518
- )
322
+ batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
519
323
  logits_output, next_token_ids = (
520
- forward_batch_output.logits_output,
521
- forward_batch_output.next_token_ids,
324
+ batch_result.logits_output,
325
+ batch_result.next_token_ids,
522
326
  )
523
327
  return (
524
328
  logits_output,
@@ -543,8 +347,10 @@ class EAGLEWorker(TpModelWorker):
543
347
  # [ topk 0 ] [ topk 1 ]
544
348
  # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
545
349
  if self.page_size == 1:
546
- out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
547
- num_seqs * self.speculative_num_steps * self.topk, backup_state=True
350
+ out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
351
+ batch.tree_cache,
352
+ num_seqs * self.speculative_num_steps * self.topk,
353
+ backup_state=True,
548
354
  )
549
355
  else:
550
356
  if self.topk == 1:
@@ -603,7 +409,8 @@ class EAGLEWorker(TpModelWorker):
603
409
  extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
604
410
 
605
411
  out_cache_loc, token_to_kv_pool_state_backup = (
606
- batch.alloc_paged_token_slots_extend(
412
+ alloc_paged_token_slots_extend(
413
+ batch.tree_cache,
607
414
  prefix_lens,
608
415
  prefix_lens_cpu,
609
416
  seq_lens,
@@ -675,16 +482,21 @@ class EAGLEWorker(TpModelWorker):
675
482
  forward_batch
676
483
  )
677
484
  if can_cuda_graph:
678
- score_list, token_list, parents_list = self.cuda_graph_runner.replay(
485
+ parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay(
679
486
  forward_batch
680
487
  )
681
488
  else:
682
489
  forward_batch.can_run_dp_cuda_graph = False
683
- if not forward_batch.forward_mode.is_idle():
684
- # Initialize attention backend
490
+ if (
491
+ not forward_batch.forward_mode.is_idle()
492
+ and self.speculative_num_steps > 1
493
+ ):
494
+ # Skip attention backend init for idle mode or 1-step draft
685
495
  self.draft_attn_backend.init_forward_metadata(forward_batch)
686
496
  # Run forward steps
687
- score_list, token_list, parents_list = self.draft_forward(forward_batch)
497
+ parent_list, top_scores_index, draft_tokens = self.draft_forward(
498
+ forward_batch
499
+ )
688
500
 
689
501
  if batch.forward_mode.is_idle():
690
502
  return EagleVerifyInput.create_idle_input(
@@ -702,9 +514,9 @@ class EAGLEWorker(TpModelWorker):
702
514
  draft_tokens,
703
515
  ) = build_tree_kernel_efficient(
704
516
  spec_info.verified_id,
705
- score_list,
706
- token_list,
707
- parents_list,
517
+ parent_list,
518
+ top_scores_index,
519
+ draft_tokens,
708
520
  batch.seq_lens,
709
521
  batch.seq_lens_sum,
710
522
  self.topk,
@@ -786,18 +598,23 @@ class EAGLEWorker(TpModelWorker):
786
598
  logits_output, _ = self.draft_model_runner.forward(
787
599
  forward_batch, skip_attn_backend_init=True
788
600
  )
789
- self._detect_nan_if_needed(logits_output)
601
+ if self.server_args.enable_nan_detection:
602
+ detect_nan(logits_output)
790
603
  probs = torch.softmax(logits_output.next_token_logits, dim=-1)
791
604
  topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
792
605
  if self.hot_token_id is not None:
793
606
  topk_index = self.hot_token_id[topk_index]
794
607
  hidden_states = logits_output.hidden_states
795
608
 
796
- return score_list, token_list, parents_list
609
+ parent_list, top_scores_index, draft_tokens = organize_draft_results(
610
+ score_list, token_list, parents_list, self.speculative_num_draft_tokens
611
+ )
612
+
613
+ return parent_list, top_scores_index, draft_tokens
797
614
 
798
615
  def clear_cache_pool(self):
799
- self.model_runner.req_to_token_pool.clear()
800
- self.model_runner.token_to_kv_pool_allocator.clear()
616
+ # allocator and kv cache pool are shared with target worker
617
+ pass
801
618
 
802
619
  def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
803
620
  spec_info.prepare_for_verify(batch, self.page_size)
@@ -822,12 +639,12 @@ class EAGLEWorker(TpModelWorker):
822
639
  ).cpu()
823
640
 
824
641
  # Forward
825
- forward_batch_output = self.target_worker.forward_batch_generation(
642
+ batch_result = self.target_worker.forward_batch_generation(
826
643
  model_worker_batch, is_verify=True
827
644
  )
828
645
  logits_output, can_run_cuda_graph = (
829
- forward_batch_output.logits_output,
830
- forward_batch_output.can_run_cuda_graph,
646
+ batch_result.logits_output,
647
+ batch_result.can_run_cuda_graph,
831
648
  )
832
649
 
833
650
  vocab_mask = None
@@ -850,7 +667,9 @@ class EAGLEWorker(TpModelWorker):
850
667
  # and will be applied to produce wrong results
851
668
  batch.sampling_info.vocab_mask = None
852
669
 
853
- self._detect_nan_if_needed(logits_output)
670
+ if self.enable_nan_detection:
671
+ detect_nan(logits_output)
672
+
854
673
  spec_info.hidden_states = logits_output.hidden_states
855
674
  res: EagleVerifyOutput = spec_info.verify(
856
675
  batch,
@@ -868,7 +687,7 @@ class EAGLEWorker(TpModelWorker):
868
687
  logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
869
688
 
870
689
  # QQ: can be optimized
871
- if self.target_worker.model_runner.is_hybrid_gdn:
690
+ if self.target_worker.model_runner.hybrid_gdn_config is not None:
872
691
  # res.draft_input.accept_length is on GPU but may be empty for last verify?
873
692
  accepted_length = (
874
693
  torch.tensor(
@@ -911,7 +730,7 @@ class EAGLEWorker(TpModelWorker):
911
730
  # acceptance indices are the indices in a "flattened" batch.
912
731
  # dividing it to num_draft_tokens will yield the actual batch index.
913
732
  temperatures = temperatures[accepted_indices // num_draft_tokens]
914
- if RETURN_ORIGINAL_LOGPROB:
733
+ if SGLANG_RETURN_ORIGINAL_LOGPROB:
915
734
  logprobs = torch.nn.functional.log_softmax(
916
735
  logits_output.next_token_logits, dim=-1
917
736
  )
@@ -1003,7 +822,8 @@ class EAGLEWorker(TpModelWorker):
1003
822
  )
1004
823
  forward_batch.return_logprob = False
1005
824
  logits_output, _ = self.draft_model_runner.forward(forward_batch)
1006
- self._detect_nan_if_needed(logits_output)
825
+ if self.enable_nan_detection:
826
+ detect_nan(logits_output)
1007
827
  assert isinstance(forward_batch.spec_info, EagleDraftInput)
1008
828
  assert forward_batch.spec_info is batch.spec_info
1009
829
  self.capture_for_decode(logits_output, forward_batch.spec_info)
@@ -1098,7 +918,8 @@ class EAGLEWorker(TpModelWorker):
1098
918
  )
1099
919
  self.capture_for_decode(logits_output, forward_batch.spec_info)
1100
920
 
1101
- self._detect_nan_if_needed(logits_output)
921
+ if self.enable_nan_detection:
922
+ detect_nan(logits_output)
1102
923
 
1103
924
  # Restore backup.
1104
925
  # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
@@ -1118,24 +939,6 @@ class EAGLEWorker(TpModelWorker):
1118
939
  draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
1119
940
  draft_input.hidden_states = logits_output.hidden_states
1120
941
 
1121
- def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
1122
- if self.enable_nan_detection:
1123
- logits = logits_output.next_token_logits
1124
- if torch.any(torch.isnan(logits)):
1125
- logger.error("Detected errors during sampling! NaN in the logits.")
1126
- raise ValueError("Detected errors during sampling! NaN in the logits.")
1127
-
1128
-
1129
- def load_token_map(token_map_path: str) -> List[int]:
1130
- if not os.path.exists(token_map_path):
1131
- cache_dir = snapshot_download(
1132
- os.path.dirname(token_map_path),
1133
- ignore_patterns=["*.bin", "*.safetensors"],
1134
- )
1135
- token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
1136
- hot_token_id = torch.load(token_map_path, weights_only=True)
1137
- return torch.tensor(hot_token_id, dtype=torch.int64)
1138
-
1139
942
 
1140
943
  @torch.compile(dynamic=True)
1141
944
  def get_last_loc_large_page_size_top_k_1(