sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,71 +1,65 @@
 import logging
-import os
 import time
-from contextlib import contextmanager
 from typing import List, Optional, Tuple
 
 import torch
-from huggingface_hub import snapshot_download
 
-from sglang.srt.distributed import (
-    GroupCoordinator,
-    get_tp_group,
-    patch_tensor_parallel_group,
-)
+from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.mm_utils import embed_mm_inputs
-from sglang.srt.managers.schedule_batch import (
-    ScheduleBatch,
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.scheduler import GenerationBatchResult
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.mem_cache.common import (
+    alloc_paged_token_slots_extend,
+    alloc_token_slots,
     get_last_loc,
-    global_server_args_dict,
 )
-from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
+from sglang.srt.speculative.draft_utils import DraftBackendFactory
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
 )
 from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
     EAGLEDraftExtendCudaGraphRunner,
 )
-from sglang.srt.speculative.eagle_utils import (
+from sglang.srt.speculative.eagle_info import (
     EagleDraftInput,
     EagleVerifyInput,
     EagleVerifyOutput,
+)
+from sglang.srt.speculative.eagle_utils import (
+    build_tree_kernel_efficient,
+    organize_draft_results,
+)
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_utils import (
     assign_draft_cache_locs,
+    detect_nan,
+    draft_tp_context,
     fast_topk,
     generate_token_bitmask,
+    load_token_map,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
-    is_blackwell,
     is_cuda,
     next_power_of_2,
 )
 
 if is_cuda():
-    from sgl_kernel import segment_packbits
+    from sgl_kernel import segment_packbits  # noqa: F401
 
 logger = logging.getLogger(__name__)
-RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
-
-
-@contextmanager
-def draft_tp_context(tp_group: GroupCoordinator):
-    # Draft model doesn't use dp and has its own tp group.
-    # We disable mscclpp now because it doesn't support 2 comm groups.
-    with patch_tensor_parallel_group(tp_group):
-        yield
+SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
 
 
 class EAGLEWorker(TpModelWorker):
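Note on the import hunk above: the helpers `draft_tp_context`, `detect_nan`, and `load_token_map` now come from `sglang.srt.speculative.spec_utils`, and the logprob switch is read from an `SGLANG_`-prefixed environment variable. A minimal sketch of the user-visible part of that rename; the variable name comes from the hunk, but setting it via `os.environ` and the value `"true"` are illustrative assumptions:

```python
import os

# 0.5.3rc0 read RETURN_ORIGINAL_LOGPROB; 0.5.4 reads the SGLANG_-prefixed name
# (see the get_bool_env_var call above). Any mechanism that sets the variable
# before the server process starts works the same way.
os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = "true"
```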
@@ -93,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
-        self.padded_static_len = -1
 
         # Override the context length of the draft model to be the same as the target model.
         server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -185,201 +178,22 @@ class EAGLEWorker(TpModelWorker):
 
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
-
-        self.has_prefill_wrapper_verify = False
-        self.draft_extend_attn_backend = None
+        draft_backend_factory = DraftBackendFactory(
+            self.server_args,
+            self.draft_model_runner,
+            self.topk,
+            self.speculative_num_steps,
+        )
 
         # Initialize decode attention backend
-        self.draft_attn_backend = self._create_decode_backend()
+        self.draft_attn_backend = draft_backend_factory.create_decode_backend()
 
         # Initialize draft extend attention backend (respects speculative_attention_mode setting)
-        self.draft_extend_attn_backend = self._create_draft_extend_backend()
-
-        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
-
-    def _create_backend(
-        self, backend_name: str, backend_map: dict, error_template: str
-    ):
-        backend_type = getattr(self.server_args, backend_name)
-        if backend_type is None:
-            backend_type = self.server_args.attention_backend
-
-        if backend_type not in backend_map:
-            raise ValueError(error_template.format(backend_type=backend_type))
-
-        return backend_map[backend_type]()
-
-    def _create_decode_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_decode_backend,
-            "triton": self._create_triton_decode_backend,
-            "aiter": self._create_aiter_decode_backend,
-            "fa3": self._create_fa3_decode_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_decode_backend
-                if not is_blackwell()
-                else self._create_triton_decode_backend
-            ),
-            "flashmla": self._create_flashmla_decode_backend,
-            "trtllm_mha": self._create_trtllm_mha_decode_backend,
-            "trtllm_mla": self._create_trtllm_mla_decode_backend,
-        }
-
-        return self._create_backend(
-            "decode_attention_backend",
-            backend_map,
-            "EAGLE is not supported in decode attention backend {backend_type}",
-        )
-
-    def _create_draft_extend_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_prefill_backend,
-            "triton": self._create_triton_prefill_backend,
-            "aiter": self._create_aiter_prefill_backend,
-            "fa3": self._create_fa3_prefill_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_prefill_backend
-                if not is_blackwell()
-                else self._create_triton_prefill_backend
-            ),
-            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
-            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
-        }
-        backend_name = (
-            "decode_attention_backend"
-            if self.server_args.speculative_attention_mode == "decode"
-            else "prefill_attention_backend"
-        )
-        return self._create_backend(
-            backend_name,
-            backend_map,
-            "EAGLE is not supported in attention backend {backend_type}",
-        )
-
-    def _create_flashinfer_decode_backend(self):
-        if not global_server_args_dict["use_mla_backend"]:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMLAMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-
-    def _create_triton_decode_backend(self):
-        from sglang.srt.layers.attention.triton_backend import (
-            TritonMultiStepDraftBackend,
-        )
-
-        return TritonMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_aiter_decode_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
-
-        return AiterMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_fa3_decode_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionMultiStepBackend,
-        )
-
-        return FlashAttentionMultiStepBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
+        self.draft_extend_attn_backend = (
+            draft_backend_factory.create_draft_extend_backend()
         )
 
-    def _create_flashmla_decode_backend(self):
-        from sglang.srt.layers.attention.flashmla_backend import (
-            FlashMLAMultiStepDraftBackend,
-        )
-
-        return FlashMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mha_decode_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import (
-            TRTLLMHAAttnMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMHAAttnMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mla_decode_backend(self):
-        if not global_server_args_dict["use_mla_backend"]:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import (
-            TRTLLMMLAMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_flashinfer_prefill_backend(self):
-        if not global_server_args_dict["use_mla_backend"]:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferAttnBackend,
-            )
-
-            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAAttnBackend,
-            )
-
-            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_triton_prefill_backend(self):
-        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_aiter_prefill_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_fa3_prefill_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionBackend,
-        )
-
-        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mha_prefill_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
-
-        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mla_prefill_backend(self):
-        if not global_server_args_dict["use_mla_backend"]:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -390,16 +204,17 @@ class EAGLEWorker(TpModelWorker):
             return
 
         # Capture draft
-        tic = time.perf_counter()
-        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        logger.info(
-            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
-        )
-        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
-        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
-        )
+        if self.speculative_num_steps > 1:
+            tic = time.perf_counter()
+            before_mem = get_available_gpu_memory(self.device, self.gpu_id)
+            logger.info(
+                f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            )
+            self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+            after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+            logger.info(
+                f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            )
 
         # Capture extend
         if self.draft_extend_attn_backend:
@@ -420,9 +235,7 @@
     def draft_model_runner(self):
         return self.model_runner
 
-    def forward_batch_speculative_generation(
-        self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
         """Run speculative decoding forward.
 
         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -435,14 +248,19 @@
         the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
+            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
+                batch
             )
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
-            return logits_output, next_token_ids, bid, 0, False
+            return GenerationBatchResult(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                num_accepted_tokens=0,
+                can_run_cuda_graph=False,
+            )
         else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
@@ -460,12 +278,11 @@
                 # decode is not finished
                 self.forward_draft_extend_after_decode(batch)
 
-            return (
-                logits_output,
-                verify_output.verified_id,
-                model_worker_batch.bid,
-                sum(verify_output.accept_length_per_req_cpu),
-                can_run_cuda_graph,
+            return GenerationBatchResult(
+                logits_output=logits_output,
+                next_token_ids=verify_output.verified_id,
+                num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph=can_run_cuda_graph,
             )
 
     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
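The hunks above replace the old 5-tuple return of `forward_batch_speculative_generation` with a single `GenerationBatchResult` from the renamed `forward_batch_generation`. A minimal sketch of how a caller adapts; the `run_speculative_step` wrapper and the `worker`/`batch` arguments are assumptions for illustration, only the import paths, method name, and field names come from the diff:

```python
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.scheduler import GenerationBatchResult


def run_speculative_step(worker, batch: ScheduleBatch) -> GenerationBatchResult:
    # 0.5.3rc0 callers unpacked a 5-tuple here (logits_output, next_token_ids,
    # bid, num_accepted_tokens, can_run_cuda_graph); 0.5.4 callers read named
    # fields off one result object, and the batch id is gone.
    result = worker.forward_batch_generation(batch)
    _ = (result.logits_output, result.next_token_ids, result.can_run_cuda_graph)
    return result
```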
@@ -497,19 +314,19 @@
         Returns:
             logits_output: The output of logits. It will contain the full hidden states.
             next_token_ids: Next token ids generated.
-            bid: The model batch ID. Used for overlap schedule.
         """
         # Forward with the target model and get hidden states.
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
-            model_worker_batch
+        batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
+        logits_output, next_token_ids = (
+            batch_result.logits_output,
+            batch_result.next_token_ids,
         )
         return (
             logits_output,
             next_token_ids,
-            model_worker_batch.bid,
             model_worker_batch.seq_lens_cpu,
         )
 
@@ -530,8 +347,10 @@
         # [ topk 0 ] [ topk 1 ]
         # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
-            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
+            out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
+                batch.tree_cache,
+                num_seqs * self.speculative_num_steps * self.topk,
+                backup_state=True,
             )
         else:
             if self.topk == 1:
@@ -541,6 +360,8 @@
                     batch.seq_lens,
                     self.speculative_num_steps,
                 )
+                prefix_lens_cpu = batch.seq_lens_cpu
+                seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
                 extend_num_tokens = num_seqs * self.speculative_num_steps
             else:
                 # In this case, the last partial page needs to be duplicated.
@@ -576,14 +397,24 @@
                     self.topk,
                     self.page_size,
                 )
-
-                # TODO(lmzheng): remove this device sync
-                extend_num_tokens = torch.sum(self.extend_lens).item()
+                prefix_lens_cpu = batch.seq_lens_cpu
+                last_page_lens = prefix_lens_cpu % self.page_size
+                num_new_pages_per_topk = (
+                    last_page_lens + self.speculative_num_steps + self.page_size - 1
+                ) // self.page_size
+                seq_lens_cpu = (
+                    prefix_lens_cpu // self.page_size * self.page_size
+                    + num_new_pages_per_topk * (self.page_size * self.topk)
+                )
+                extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
 
             out_cache_loc, token_to_kv_pool_state_backup = (
-                batch.alloc_paged_token_slots_extend(
+                alloc_paged_token_slots_extend(
+                    batch.tree_cache,
                     prefix_lens,
+                    prefix_lens_cpu,
                     seq_lens,
+                    seq_lens_cpu,
                     last_loc,
                     extend_num_tokens,
                     backup_state=True,
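The hunk above drops the device sync (`torch.sum(self.extend_lens).item()`) and instead derives the allocation size from CPU tensors with plain page arithmetic. A small self-contained check of that arithmetic for a single request, using plain integers; the concrete values of `page_size`, `topk`, `speculative_num_steps`, and `prefix_len` are made up, and the real code applies the same formulas to per-request CPU tensors and sums over the batch:

```python
# Worked example of the seq_lens_cpu / extend_num_tokens computation above.
page_size, topk, speculative_num_steps = 4, 2, 3
prefix_len = 10  # tokens already in the KV cache for this request

last_page_len = prefix_len % page_size                  # 10 % 4 = 2
num_new_pages_per_topk = (
    last_page_len + speculative_num_steps + page_size - 1
) // page_size                                          # (2 + 3 + 3) // 4 = 2
seq_len = (
    prefix_len // page_size * page_size                 # 8: keep the full pages
    + num_new_pages_per_topk * (page_size * topk)       # 2 * 8 = 16 new slots
)                                                       # = 24
extend_num_tokens = seq_len - prefix_len                # 14 slots to allocate
print(seq_len, extend_num_tokens)  # 24 14
```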
@@ -651,16 +482,21 @@
             forward_batch
         )
         if can_cuda_graph:
-            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
+            parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay(
                 forward_batch
             )
         else:
             forward_batch.can_run_dp_cuda_graph = False
-            if not forward_batch.forward_mode.is_idle():
-                # Initialize attention backend
+            if (
+                not forward_batch.forward_mode.is_idle()
+                and self.speculative_num_steps > 1
+            ):
+                # Skip attention backend init for idle mode or 1-step draft
                 self.draft_attn_backend.init_forward_metadata(forward_batch)
             # Run forward steps
-            score_list, token_list, parents_list = self.draft_forward(forward_batch)
+            parent_list, top_scores_index, draft_tokens = self.draft_forward(
+                forward_batch
+            )
 
         if batch.forward_mode.is_idle():
             return EagleVerifyInput.create_idle_input(
@@ -678,9 +514,9 @@
             draft_tokens,
         ) = build_tree_kernel_efficient(
             spec_info.verified_id,
-            score_list,
-            token_list,
-            parents_list,
+            parent_list,
+            top_scores_index,
+            draft_tokens,
             batch.seq_lens,
             batch.seq_lens_sum,
             self.topk,
@@ -762,14 +598,23 @@
             logits_output, _ = self.draft_model_runner.forward(
                 forward_batch, skip_attn_backend_init=True
             )
-            self._detect_nan_if_needed(logits_output)
+            if self.server_args.enable_nan_detection:
+                detect_nan(logits_output)
             probs = torch.softmax(logits_output.next_token_logits, dim=-1)
             topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
             if self.hot_token_id is not None:
                 topk_index = self.hot_token_id[topk_index]
             hidden_states = logits_output.hidden_states
 
-        return score_list, token_list, parents_list
+        parent_list, top_scores_index, draft_tokens = organize_draft_results(
+            score_list, token_list, parents_list, self.speculative_num_draft_tokens
+        )
+
+        return parent_list, top_scores_index, draft_tokens
+
+    def clear_cache_pool(self):
+        # allocator and kv cache pool are shared with target worker
+        pass
 
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
@@ -794,10 +639,12 @@
         ).cpu()
 
         # Forward
-        logits_output, _, can_run_cuda_graph = (
-            self.target_worker.forward_batch_generation(
-                model_worker_batch, skip_sample=True
-            )
+        batch_result = self.target_worker.forward_batch_generation(
+            model_worker_batch, is_verify=True
+        )
+        logits_output, can_run_cuda_graph = (
+            batch_result.logits_output,
+            batch_result.can_run_cuda_graph,
         )
 
         vocab_mask = None
@@ -820,7 +667,9 @@
             # and will be applied to produce wrong results
             batch.sampling_info.vocab_mask = None
 
-        self._detect_nan_if_needed(logits_output)
+        if self.enable_nan_detection:
+            detect_nan(logits_output)
+
         spec_info.hidden_states = logits_output.hidden_states
         res: EagleVerifyOutput = spec_info.verify(
             batch,
@@ -838,7 +687,7 @@
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
         # QQ: can be optimized
-        if self.target_worker.model_runner.is_hybrid_gdn:
+        if self.target_worker.model_runner.hybrid_gdn_config is not None:
             # res.draft_input.accept_length is on GPU but may be empty for last verify?
             accepted_length = (
                 torch.tensor(
@@ -881,7 +730,7 @@
             # acceptance indices are the indices in a "flattened" batch.
             # dividing it to num_draft_tokens will yield the actual batch index.
             temperatures = temperatures[accepted_indices // num_draft_tokens]
-            if RETURN_ORIGINAL_LOGPROB:
+            if SGLANG_RETURN_ORIGINAL_LOGPROB:
                 logprobs = torch.nn.functional.log_softmax(
                     logits_output.next_token_logits, dim=-1
                 )
@@ -973,7 +822,8 @@
         )
         forward_batch.return_logprob = False
         logits_output, _ = self.draft_model_runner.forward(forward_batch)
-        self._detect_nan_if_needed(logits_output)
+        if self.enable_nan_detection:
+            detect_nan(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
@@ -997,6 +847,7 @@
         assert isinstance(batch.spec_info, EagleDraftInput)
         # Backup fields that will be modified in-place
        seq_lens_backup = batch.seq_lens.clone()
+        seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
@@ -1067,7 +918,8 @@
             )
             self.capture_for_decode(logits_output, forward_batch.spec_info)
 
-        self._detect_nan_if_needed(logits_output)
+        if self.enable_nan_detection:
+            detect_nan(logits_output)
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
@@ -1075,6 +927,7 @@
             ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
         )
         batch.seq_lens = seq_lens_backup
+        batch.seq_lens_cpu = seq_lens_cpu_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
         batch.return_logprob = return_logprob_backup
@@ -1086,24 +939,6 @@
         draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
         draft_input.hidden_states = logits_output.hidden_states
 
-    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
-        if self.enable_nan_detection:
-            logits = logits_output.next_token_logits
-            if torch.any(torch.isnan(logits)):
-                logger.error("Detected errors during sampling! NaN in the logits.")
-                raise ValueError("Detected errors during sampling! NaN in the logits.")
-
-
-def load_token_map(token_map_path: str) -> List[int]:
-    if not os.path.exists(token_map_path):
-        cache_dir = snapshot_download(
-            os.path.dirname(token_map_path),
-            ignore_patterns=["*.bin", "*.safetensors"],
-        )
-        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
-    hot_token_id = torch.load(token_map_path, weights_only=True)
-    return torch.tensor(hot_token_id, dtype=torch.int64)
-
 
 @torch.compile(dynamic=True)
 def get_last_loc_large_page_size_top_k_1(