sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482):
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,433 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import logging
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+
10
+ from sglang.srt.server_args import get_global_server_args
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ from dataclasses import dataclass
15
+
16
+ import torch.nn.functional as F
17
+
18
+ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
19
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
20
+ from sglang.srt.layers.sampler import apply_custom_logit_processor
21
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
22
+ from sglang.srt.mem_cache.common import (
23
+ alloc_paged_token_slots_extend,
24
+ alloc_token_slots,
25
+ get_last_loc,
26
+ )
27
+ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
28
+ from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
29
+ from sglang.srt.speculative.spec_utils import (
30
+ TREE_SPEC_KERNEL_AVAILABLE,
31
+ assign_req_to_token_pool,
32
+ get_src_tgt_cache_loc,
33
+ get_target_cache_loc,
34
+ )
35
+ from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
36
+
37
+ if is_cuda():
38
+ from sgl_kernel import (
39
+ top_k_renorm_prob,
40
+ top_p_renorm_prob,
41
+ tree_speculative_sampling_target_only,
42
+ verify_tree_greedy,
43
+ )
44
+ elif is_hip():
45
+ from sgl_kernel import verify_tree_greedy
46
+
47
+
48
+ @dataclass
49
+ class NgramVerifyInput(SpecInput):
50
+ def __init__(
51
+ self,
52
+ draft_token: torch.Tensor,
53
+ tree_mask: torch.Tensor,
54
+ positions: torch.Tensor,
55
+ retrive_index: torch.Tensor,
56
+ retrive_next_token: torch.Tensor,
57
+ retrive_next_sibling: torch.Tensor,
58
+ draft_token_num: int,
59
+ ):
60
+ super().__init__(SpecInputType.NGRAM_VERIFY)
61
+ self.draft_token = draft_token
62
+ self.custom_mask = tree_mask
63
+ self.positions = positions
64
+ self.retrive_index = retrive_index
65
+ self.retrive_next_token = retrive_next_token
66
+ self.retrive_next_sibling = retrive_next_sibling
67
+ self.draft_token_num = draft_token_num
68
+ self.device = self.custom_mask.device
69
+
70
+ def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
71
+ return self.draft_token_num, self.draft_token_num
72
+
73
    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
        """Allocate KV-cache slots for the draft tokens and record them in the
        request-to-token pool so the target model can score the draft tree.

        Mutates ``batch`` in place: replaces ``batch.input_ids`` with the draft
        tokens and fills ``batch.out_cache_loc`` with the newly allocated slots.
        No-op for an idle (empty) forward batch.
        """
        if batch.forward_mode.is_idle():
            return

        # The target forward pass runs over the draft tokens themselves.
        batch.input_ids = self.draft_token

        if page_size == 1:
            # Token-granular allocation: any free slots will do.
            batch.out_cache_loc = alloc_token_slots(
                batch.tree_cache,
                len(batch.input_ids),
            )
            end_offset = batch.seq_lens + self.draft_token_num
        else:
            # TODO(lsyin): add prefix lens cpu here to support page size > 1
            prefix_lens = batch.seq_lens
            prefix_lens_cpu = batch.seq_lens_cpu
            end_offset = prefix_lens + self.draft_token_num
            end_offset_cpu = prefix_lens_cpu + self.draft_token_num
            # Last occupied slot per request, required by the paged allocator
            # so the extension stays page-contiguous.
            last_loc = get_last_loc(
                batch.req_to_token_pool.req_to_token,
                batch.req_pool_indices,
                prefix_lens,
            )
            batch.out_cache_loc = alloc_paged_token_slots_extend(
                batch.tree_cache,
                prefix_lens,
                prefix_lens_cpu,
                end_offset,
                end_offset_cpu,
                last_loc,
                len(batch.input_ids),
            )
            # NOTE(review): self.last_loc is only set on this (page_size > 1)
            # branch — presumably consumers read it only in that case; confirm.
            self.last_loc = last_loc

        # Write the new cache locations into the request token map,
        # one Triton program instance per request.
        bs = batch.batch_size()
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            end_offset,
            batch.out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            triton.next_power_of_2(bs),
        )
117
+
118
    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        # NOTE(review): paged_kernel_lens_sum is unused in this method;
        # presumably kept for signature parity with other SpecInput types.
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
        """Build flashinfer-style prefill attention metadata for verification.

        Returns ``(kv_indices, cum_kv_seq_len, qo_indptr, custom_mask)``:
        flattened KV-cache indices, the per-request cumulative KV lengths,
        the query offsets, and the tree attention mask.
        """
        bs = len(req_pool_indices)

        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)

        # Each request's KV length grows by the draft tokens being verified.
        paged_kernel_lens = paged_kernel_lens + self.draft_token_num
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        # Every request contributes exactly draft_token_num query tokens.
        self.qo_indptr = (
            torch.arange(0, bs + 1, dtype=torch.int32, device=self.device)
            * self.draft_token_num
        )

        # Total KV entries across the batch; filled by the kernel below.
        kv_indices = torch.empty(
            cum_kv_seq_len[-1], dtype=torch.int32, device=self.device
        )

        create_flashinfer_kv_indices_triton[(bs,)](
            req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            cum_kv_seq_len,
            None,
            kv_indices,
            req_to_token.size(1),
        )
        return kv_indices, cum_kv_seq_len, self.qo_indptr, self.custom_mask
151
+
152
+ def _fill_requests(
153
+ self,
154
+ batch: ScheduleBatch,
155
+ logits_output: torch.Tensor,
156
+ ):
157
+ accept_index_cpu = self.accept_index.tolist()
158
+ predict_cpu = self.predict.tolist()
159
+ has_finished = False
160
+
161
+ # Iterate every accepted token and check if req has finished after append the token
162
+ # should be checked BEFORE free kv cache slots
163
+ for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
164
+ for j, idx in enumerate(accept_index_row):
165
+ if idx == -1:
166
+ break
167
+ id = predict_cpu[idx]
168
+ req.output_ids.append(id)
169
+ req.check_finished()
170
+ if req.finished():
171
+ has_finished = True
172
+ # set all tokens after finished token to -1 and break
173
+ self.accept_index[i, j + 1 :] = -1
174
+ break
175
+ else:
176
+ if req.grammar is not None:
177
+ try:
178
+ req.grammar.accept_token(id)
179
+ except ValueError as e:
180
+ logger.info(
181
+ f"{i=}, {req=}\n"
182
+ f"{self.accept_index=}\n"
183
+ f"{self.predict=}\n"
184
+ )
185
+ raise e
186
+ req.spec_verify_ct += 1
187
+ if has_finished:
188
+ self.accept_length = (self.accept_index != -1).sum(dim=1) - 1
189
+ self.accept_index = self.accept_index[self.accept_index != -1]
190
+
191
+ logits_output.next_token_logits = logits_output.next_token_logits[
192
+ self.accept_index
193
+ ]
194
+ if logits_output.hidden_states:
195
+ logits_output.hidden_states = logits_output.hidden_states[self.accept_index]
196
+ self.verified_id = self.predict[self.accept_index]
197
+
198
    def _free_cache(self, batch: ScheduleBatch, page_size: int):
        """Free the KV-cache slots of rejected draft tokens.

        For ``page_size == 1`` slots are freed individually; otherwise accepted
        tokens are first compacted to the front of each request's page span so
        only the page-aligned tail is released. Updates ``batch.out_cache_loc``
        and the request-to-token pool accordingly.
        """
        bs = batch.batch_size()
        # Free the KV cache for unaccepted tokens
        if page_size == 1:
            # TODO: boolean array index leads to a device sync. Remove it.
            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
            evict_mask[self.accept_index] = False
            batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
            batch.out_cache_loc = batch.out_cache_loc[self.accept_index]
        else:
            # Shift the accepted tokens to the beginning.
            # Only evict the last part
            src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
                batch.seq_lens,
                batch.out_cache_loc,
                self.accept_index,
                self.accept_length,
                self.draft_token_num,
                page_size,
            )
            to_free_slots = torch.empty(
                (to_free_num_slots.sum().item(),),
                dtype=torch.int64,
                device=to_free_num_slots.device,
            )

            # out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
            # accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
            # tgt_cache_loc: [0 1 , 3 4 , 6 ]
            # to_free_slots: [ 2, 5, 7 8]
            # to_free_slots also needs to be page-aligned without the first partial page
            #
            # split each row of out_cache_loc into two parts.
            # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
            # 2. the second part goes to to_free_slots.
            get_target_cache_loc[(bs,)](
                tgt_cache_loc,
                to_free_slots,
                self.accept_length,
                to_free_num_slots,
                batch.out_cache_loc,
                self.draft_token_num,
                next_power_of_2(self.draft_token_num),
                next_power_of_2(bs),
            )

            # Free the kv cache
            batch.token_to_kv_pool_allocator.free(to_free_slots)

            # Copy the kv cache
            batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
                tgt_cache_loc, src_cache_loc
            )
            batch.out_cache_loc = tgt_cache_loc

        # Re-register the (now compacted) accepted slots: each request keeps
        # accept_length + 1 tokens (accepted drafts plus the bonus token).
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.seq_lens + self.accept_length + 1,
            batch.out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            triton.next_power_of_2(bs),
        )
263
    def _greedy_verify(
        self,
        batch: ScheduleBatch,
        logits_output: LogitsProcessorOutput,
    ):
        """Verify draft tokens against the target model's argmax predictions.

        Fills ``self.predict``, ``self.accept_index`` and ``self.accept_length``
        in place via the ``verify_tree_greedy`` kernel.
        """
        bs = batch.batch_size()
        # The target model's greedy choice at every draft position.
        target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
        target_predict = target_predict.reshape(bs, self.draft_token_num)

        candidates = self.draft_token.reshape(bs, self.draft_token_num)
        # predict holds one extra entry beyond the draft positions.
        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
        predict_shape[-1] += 1
        self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
        self.accept_index = torch.full(
            (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
        )
        self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)

        verify_tree_greedy(
            predicts=self.predict,  # mutable
            accept_index=self.accept_index,  # mutable
            accept_token_num=self.accept_length,  # mutable
            candidates=candidates,
            retrive_index=self.retrive_index,
            retrive_next_token=self.retrive_next_token,
            retrive_next_sibling=self.retrive_next_sibling,
            target_predict=target_predict,
        )
292
    def _sampling_verify(
        self,
        batch: ScheduleBatch,
        logits_output: LogitsProcessorOutput,
        sampling_info: SamplingBatchInfo,
    ):
        """Verify draft tokens with rejection sampling against target probs.

        Applies temperature, top-k and top-p renormalization to the target
        logits, then runs ``tree_speculative_sampling_target_only``, which
        fills ``self.predict``, ``self.accept_index`` and ``self.accept_length``
        in place.
        """
        bs = batch.batch_size()
        candidates = self.draft_token.reshape(bs, self.draft_token_num)
        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
        predict_shape[-1] += 1
        self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
        self.accept_index = torch.full(
            (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
        )
        self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)
        # apply temperature and get target probs
        expanded_temperature = torch.repeat_interleave(
            sampling_info.temperatures, self.draft_token_num, dim=0
        )  # (bs * draft_token_num, 1)

        target_probs = F.softmax(
            logits_output.next_token_logits / expanded_temperature, dim=-1
        )  # (bs * draft_token_num, vocab_size)

        # NOTE: The test shows that top_p_renorm_prob and top_k_renorm_prob are the key factors
        # contributing to the poor performance of _sampling_verify.
        target_probs = top_k_renorm_prob(
            target_probs,
            torch.repeat_interleave(sampling_info.top_ks, self.draft_token_num, dim=0),
        )  # (bs * draft_token_num, vocab_size)

        if sampling_info.need_top_p_sampling:
            # logger.info("Using top-p sampling in speculative decoding verification.")
            target_probs = top_p_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ps, self.draft_token_num, dim=0
                ),
            )

        target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
        # Draft probabilities are zero: drafts come from the n-gram cache, not
        # from a draft model distribution.
        draft_probs = torch.zeros(
            target_probs.shape, dtype=torch.float32, device=self.device
        )

        # coins for rejection sampling
        coins = torch.rand_like(candidates, dtype=torch.float32, device=self.device)
        # coins for final sampling
        coins_for_final_sampling = torch.rand(
            (bs,), dtype=torch.float32, device=self.device
        )
        tree_speculative_sampling_target_only(
            predicts=self.predict,  # mutable
            accept_index=self.accept_index,  # mutable
            accept_token_num=self.accept_length,  # mutable
            candidates=candidates.to(torch.int64),
            retrive_index=self.retrive_index.to(torch.int64),
            retrive_next_token=self.retrive_next_token.to(torch.int64),
            retrive_next_sibling=self.retrive_next_sibling.to(torch.int64),
            uniform_samples=coins,
            uniform_samples_for_final_sampling=coins_for_final_sampling,
            target_probs=target_probs,
            draft_probs=draft_probs,
            threshold_single=get_global_server_args().speculative_accept_threshold_single,
            threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
            deterministic=True,
        )
360
    def verify(
        self,
        batch: ScheduleBatch,
        logits_output: LogitsProcessorOutput,
        page_size: int,
        vocab_mask: Optional[torch.Tensor] = None,  # For grammar
    ) -> torch.Tensor:
        """Run the full verification pipeline after the target forward pass.

        Applies sampling adjustments (custom logit processors, penalties,
        grammar masks), verifies the draft tree, appends accepted tokens to
        the requests, frees rejected KV slots and advances sequence lengths.

        Returns:
            ``(logits_output, verified_id, num_accepted_tokens)`` -- logits
            filtered to accepted positions, the accepted token ids, and the
            total number of accepted tokens in the batch.
        """
        bs = self.retrive_index.shape[0]
        sampling_info = batch.sampling_info

        if bs != len(sampling_info):
            sampling_info = copy.deepcopy(sampling_info)
            # NOTE: retrive_index are the indices of the requests that are kept.
            sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)

        # Apply the custom logit processors if registered in the sampling info.
        if sampling_info.has_custom_logit_processor:
            apply_custom_logit_processor(
                logits_output.next_token_logits,
                sampling_info,
                num_tokens_in_batch=self.draft_token_num,
            )

        # Apply penalty
        if sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            linear_penalty = torch.zeros(
                (bs, logits_output.next_token_logits.shape[1]),
                dtype=torch.float32,
                device=self.device,
            )
            sampling_info.apply_logits_bias(linear_penalty)
            # Broadcast the per-request penalty to every draft position.
            logits_output.next_token_logits.add_(
                torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
            )

        # Apply grammar mask
        if vocab_mask is not None:
            assert self.grammar is not None
            self.grammar.apply_vocab_mask(
                logits=logits_output.next_token_logits, vocab_mask=vocab_mask
            )

        # Sample tokens. Force greedy sampling on AMD
        is_all_greedy = sampling_info.is_all_greedy
        if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
            logger.warning(
                "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
                "Falling back to greedy verification."
            )

        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
            self._greedy_verify(batch, logits_output)
        else:
            # NOTE: Compared with greedy_verify, the performance of _sampling_verify is relatively poor.
            # Greedy verification is deliberately used even for sampling
            # requests; the sampling path is kept below for reference.
            self._greedy_verify(batch, logits_output)
            # self._sampling_verify(batch, logits_output, sampling_info)

        self._fill_requests(batch, logits_output)
        self._free_cache(batch, page_size)

        accept_length_cpu = self.accept_length.cpu()
        num_accepted_tokens = accept_length_cpu.sum().item()

        # Advance both device and host copies of the sequence lengths by
        # accept_length + 1 (accepted drafts plus the bonus token).
        batch.seq_lens.add_(self.accept_length + 1)
        batch.seq_lens_cpu.add_(accept_length_cpu + 1)

        return logits_output, self.verified_id, num_accepted_tokens
429
    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
        # No-op: this verify input carries no per-request state that must be
        # narrowed when requests are filtered out of the batch.
        pass
432
    def merge_batch(self, spec_info: NgramVerifyInput):
        # No-op: nothing to merge; draft state is rebuilt per verification step.
        pass
@@ -0,0 +1,246 @@
1
+ import logging
2
+ from typing import List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
7
+
8
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
9
+ from sglang.srt.managers.scheduler import GenerationBatchResult
10
+ from sglang.srt.managers.tp_worker import TpModelWorker
11
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
12
+ from sglang.srt.server_args import ServerArgs
13
+ from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
14
+ from sglang.srt.speculative.ngram_info import NgramVerifyInput
15
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
16
+
17
logger = logging.getLogger(__name__)

# Use the full (prefix + tree) attention mask during verification. A QLEN
# (tree-only) mask is faster but requires corresponding flashinfer changes.
USE_FULL_MASK = True
+
21
+
22
class NGRAMWorker:
    """Speculative-decoding worker that drafts tokens from an n-gram cache.

    Instead of a draft model, candidate continuations are looked up in an
    n-gram cache populated from previously generated token sequences. The
    target model verifies the proposed token tree in one TARGET_VERIFY pass.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        self.target_worker = target_worker
        self.model_runner = target_worker.model_runner
        self.tp_rank = tp_rank
        self.page_size = server_args.page_size
        self.draft_token_num: int = server_args.speculative_num_draft_tokens
        self.branch_length: int = server_args.speculative_ngram_branch_length
        self.max_match_window_size: int = (
            server_args.speculative_ngram_max_match_window_size
        )

        self.max_batch_size = target_worker.max_running_requests
        self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda"

        self._init_preallocated_tensors()

        self.ngram_cache = NgramCache(
            min_match_window_size=server_args.speculative_ngram_min_match_window_size,
            max_match_window_size=server_args.speculative_ngram_max_match_window_size,
            min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
            max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
            capacity=server_args.speculative_ngram_capacity,
            branch_length=server_args.speculative_ngram_branch_length,
            draft_token_num=server_args.speculative_num_draft_tokens,
        )

    def clear_cache_pool(self):
        """Drop all cached n-grams."""
        self.ngram_cache.reset()

    def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
        """Return the last ``n`` tokens of ``seq1 + seq2`` without building
        the full concatenation."""
        seq2_len = len(seq2)
        if seq2_len >= n:
            return seq2[-n:]

        need_from_seq1 = n - seq2_len
        return seq1[-need_from_seq1:] + seq2

    def _init_preallocated_tensors(self):
        """Preallocate draft/mask/index tensors for the maximum batch size and
        store per-batch-size views to avoid per-step allocation."""
        max_total_drafts = self.max_batch_size * self.draft_token_num
        max_total_mask_size = (
            self.max_batch_size * self.draft_token_num * self.draft_token_num
        )

        self.draft_tokens = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.retrieve_indexes = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.retrive_next_token = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.retrive_next_sibling = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.positions = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.tree_mask = torch.empty(
            (max_total_mask_size,), dtype=torch.bool, device=self.device
        )

        self.draft_tokens_batch = []
        self.tree_mask_batch = []
        self.retrieve_indexes_batch = []
        self.retrive_next_token_batch = []
        self.retrive_next_sibling_batch = []
        self.positions_batch = []

        # Index bs -> view sized for a batch of bs requests (bs == 0 included).
        for bs in range(self.max_batch_size + 1):
            self.retrieve_indexes_batch.append(self.retrieve_indexes[:bs, :])
            self.retrive_next_token_batch.append(self.retrive_next_token[:bs, :])
            self.retrive_next_sibling_batch.append(self.retrive_next_sibling[:bs, :])
            self.positions_batch.append(self.positions[: bs * self.draft_token_num])
            self.draft_tokens_batch.append(
                self.draft_tokens[: bs * self.draft_token_num]
            )
            self.tree_mask_batch.append(
                self.tree_mask[: bs * self.draft_token_num * self.draft_token_num]
            )

    def _prepare_draft_tokens(
        self, batch: ScheduleBatch
    ) -> tuple[np.ndarray, np.ndarray]:
        """Look up draft tokens and the tree mask for every request.

        Returns:
            ``(req_drafts, mask)`` numpy arrays covering the whole batch;
            exactly ``draft_token_num`` drafts are produced per request.
        """
        bs = batch.batch_size()

        self.ngram_cache.synchronize()
        batch_tokens = []
        for req in batch.reqs:
            check_token = self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.max_match_window_size
            )
            batch_tokens.append(check_token)
        req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
        total_draft_token_num = len(req_drafts)

        # Check if speculative decoding is needed; here we always enforce it
        assert (
            total_draft_token_num == bs * self.draft_token_num
        ), f"{total_draft_token_num=}, {bs=}, {self.draft_token_num=}"
        return req_drafts, mask

    def _prepare_for_speculative_decoding(self, batch: ScheduleBatch):
        """Turn a decode batch into a TARGET_VERIFY batch with n-gram drafts.

        Fills the preallocated tensors, reconstructs tree indices from the
        mask, attaches an ``NgramVerifyInput`` to the batch and allocates KV
        slots for the draft tokens. Extend batches are left untouched.
        """
        if batch.forward_mode.is_extend():
            return

        bs = batch.batch_size()

        retrive_index = self.retrieve_indexes_batch[bs]
        retrive_next_token = self.retrive_next_token_batch[bs]
        retrive_next_sibling = self.retrive_next_sibling_batch[bs]
        positions = self.positions_batch[bs]
        tree_mask = self.tree_mask_batch[bs]
        draft_tokens = self.draft_tokens_batch[bs]

        req_drafts, mask = self._prepare_draft_tokens(batch)
        tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
        draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)

        reconstruct_indices_from_tree_mask(
            tree_mask,
            batch.seq_lens,
            positions,  # mutable
            retrive_index,  # mutable
            retrive_next_token,  # mutable
            retrive_next_sibling,  # mutable
            bs,
            self.draft_token_num,
        )

        # NOTE: QLEN_MASK is faster than FULL_MASK, but requires corresponding changes in flashinfer.
        # Testing shows about 8% performance improvement (the effect is roughly proportional to batch size).
        if USE_FULL_MASK:
            # Rebind tree_mask to the full per-request (prefix + tree) mask.
            tree_mask = []
            mask = mask.reshape(
                batch.batch_size(), self.draft_token_num, self.draft_token_num
            )
            for i, req in enumerate(batch.reqs):
                seq_len = len(req.origin_input_ids) + len(req.output_ids)
                # BUGFIX: place on self.device instead of .cuda(), which would
                # land on the default CUDA device and ignore gpu_id.
                req_mask = torch.ones(
                    (self.draft_token_num, seq_len - 1), device=self.device
                )
                req_mask = torch.cat(
                    (req_mask, torch.from_numpy(mask[i]).to(self.device)), dim=1
                ).to(torch.bool)
                tree_mask.append(req_mask.flatten())
            tree_mask = torch.cat(tree_mask, dim=0)

        batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = NgramVerifyInput(
            draft_tokens,
            tree_mask,
            positions,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            self.draft_token_num,
        )
        batch.spec_info.prepare_for_verify(batch, self.page_size)

    def _update_ngram_cache(self, batch: ScheduleBatch):
        """Insert each request's recent tokens into the n-gram cache."""
        batch_tokens = []
        for req in batch.reqs:
            # FIXME: Whether to insert 'extend' into the cache or not, after testing,
            # there is not much difference, so we will not insert it for now.
            # if batch.forward_mode.is_extend():
            #     put_ids = req.origin_input_ids + req.output_ids
            # else:
            put_ids = self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.branch_length
            )
            batch_tokens.append(put_ids)
        self.ngram_cache.batch_put(batch_tokens)

    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
        """Run one generation step with optional n-gram speculation.

        Decode batches are converted to TARGET_VERIFY and verified; extend
        batches fall through to a plain target forward pass.
        """
        self._prepare_for_speculative_decoding(batch)
        model_worker_batch = batch.get_model_worker_batch()
        num_accepted_tokens = 0

        if model_worker_batch.forward_mode.is_target_verify():
            batch_result = self.target_worker.forward_batch_generation(
                model_worker_batch, is_verify=True
            )
            logits_output, can_run_cuda_graph = (
                batch_result.logits_output,
                batch_result.can_run_cuda_graph,
            )
            verify_input = model_worker_batch.spec_info
            logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
                batch, logits_output, self.page_size
            )
            self._update_ngram_cache(batch)
            # Restore the mode so the scheduler sees a normal decode step.
            batch.forward_mode = ForwardMode.DECODE

        else:
            batch_result = self.target_worker.forward_batch_generation(
                model_worker_batch
            )
            logits_output, next_token_ids, can_run_cuda_graph = (
                batch_result.logits_output,
                batch_result.next_token_ids,
                batch_result.can_run_cuda_graph,
            )

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            num_accepted_tokens=num_accepted_tokens,
            can_run_cuda_graph=can_run_cuda_graph,
        )