sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396) hide show
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,428 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import logging
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ from dataclasses import dataclass
13
+
14
+ import torch.nn.functional as F
15
+
16
+ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
17
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
18
+ from sglang.srt.layers.sampler import apply_custom_logit_processor
19
+ from sglang.srt.managers.schedule_batch import (
20
+ ScheduleBatch,
21
+ get_last_loc,
22
+ global_server_args_dict,
23
+ )
24
+ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
25
+ from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
26
+ from sglang.srt.speculative.spec_utils import (
27
+ TREE_SPEC_KERNEL_AVAILABLE,
28
+ assign_req_to_token_pool,
29
+ get_src_tgt_cache_loc,
30
+ get_target_cache_loc,
31
+ )
32
+ from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
33
+
34
+ if is_cuda():
35
+ from sgl_kernel import (
36
+ top_k_renorm_prob,
37
+ top_p_renorm_prob,
38
+ tree_speculative_sampling_target_only,
39
+ verify_tree_greedy,
40
+ )
41
+ elif is_hip():
42
+ from sgl_kernel import verify_tree_greedy
43
+
44
+
45
+ @dataclass
46
+ class NgramVerifyInput(SpecInput):
47
def __init__(
    self,
    draft_token: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    draft_token_num: int,
):
    """Record the drafted tokens and the tree description used to verify them.

    Args:
        draft_token: Drafted token ids; later reshaped to
            ``(batch_size, draft_token_num)`` by the verify methods.
        tree_mask: Attention mask for the draft tree. It is stored under the
            attribute name ``custom_mask`` and also determines ``self.device``.
        positions: Position tensor, stored as-is.
        retrive_index: Passed through to the tree-verify kernel.
        retrive_next_token: Passed through to the tree-verify kernel.
        retrive_next_sibling: Passed through to the tree-verify kernel.
        draft_token_num: Number of draft tokens per request.
    """
    super().__init__(SpecInputType.NGRAM_VERIFY)

    # Draft tokens and their per-request count.
    self.draft_token_num = draft_token_num
    self.draft_token = draft_token
    self.positions = positions

    # The tree mask is exposed as `custom_mask`; its device becomes the
    # device on which this object allocates new tensors.
    self.custom_mask = tree_mask
    self.device = tree_mask.device

    # Tree-structure tensors forwarded unchanged to the verify kernels.
    self.retrive_index = retrive_index
    self.retrive_next_token = retrive_next_token
    self.retrive_next_sibling = retrive_next_sibling
66
+
67
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
    """Return ``draft_token_num`` twice, as a 2-tuple."""
    num_draft = self.draft_token_num
    return (num_draft, num_draft)
69
+
70
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
    """Allocate KV-cache slots for the draft tokens and register them per request.

    Does nothing when the batch is in idle forward mode. Otherwise it sets
    ``batch.input_ids`` to the draft tokens, allocates ``batch.out_cache_loc``
    (token-granular when ``page_size == 1``, paged otherwise) and launches
    ``assign_req_to_token_pool`` over the range
    ``[seq_lens, seq_lens + draft_token_num)`` of every request.

    Args:
        batch: The schedule batch to prepare; mutated in place.
        page_size: KV-cache page size; ``1`` selects the non-paged allocator.
    """
    if batch.forward_mode.is_idle():
        return

    batch.input_ids = self.draft_token
    req_to_token = batch.req_to_token_pool.req_to_token

    if page_size != 1:
        # TODO(lsyin): add prefix lens cpu here to support page size > 1
        prefix_lens = batch.seq_lens
        prefix_lens_cpu = batch.seq_lens_cpu
        end_offset = prefix_lens + self.draft_token_num
        end_offset_cpu = prefix_lens_cpu + self.draft_token_num
        last_loc = get_last_loc(
            req_to_token,
            batch.req_pool_indices,
            prefix_lens,
        )
        batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
            prefix_lens,
            prefix_lens_cpu,
            end_offset,
            end_offset_cpu,
            last_loc,
            len(batch.input_ids),
        )
        # Kept on the instance; only the paged path defines it.
        self.last_loc = last_loc
    else:
        batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
        end_offset = batch.seq_lens + self.draft_token_num

    # One kernel program per request writes the new slots into req_to_token.
    num_reqs = batch.batch_size()
    assign_req_to_token_pool[(num_reqs,)](
        batch.req_pool_indices,
        req_to_token,
        batch.seq_lens,
        end_offset,
        batch.out_cache_loc,
        req_to_token.shape[1],
        triton.next_power_of_2(num_reqs),
    )
110
+
111
def generate_attn_arg_prefill(
    self,
    req_pool_indices: torch.Tensor,
    paged_kernel_lens: torch.Tensor,
    paged_kernel_lens_sum: int,
    req_to_token: torch.Tensor,
):
    """Build the index tensors handed to the attention backend for verification.

    Args:
        req_pool_indices: One entry per request; its length is the batch size.
        paged_kernel_lens: Per-request KV lengths before the draft tokens are
            added; ``draft_token_num`` is added to each entry here.
        paged_kernel_lens_sum: Not used by this implementation.
        req_to_token: Request-to-token table read by the index kernel.

    Returns:
        A 4-tuple ``(kv_indices, cum_kv_seq_len, qo_indptr, custom_mask)``.
        ``qo_indptr`` is also stored on ``self``.
    """
    num_reqs = len(req_pool_indices)
    int32_on_device = {"dtype": torch.int32, "device": self.device}

    # Every request attends over its existing KV entries plus its draft tokens.
    kv_lens_with_draft = paged_kernel_lens + self.draft_token_num

    # Prefix sums with a leading zero: entry i is the KV offset of request i.
    cum_kv_seq_len = torch.zeros((num_reqs + 1,), **int32_on_device)
    cum_kv_seq_len[1:] = torch.cumsum(kv_lens_with_draft, dim=0)

    # Each request contributes exactly draft_token_num query positions.
    request_offsets = torch.arange(0, num_reqs + 1, **int32_on_device)
    self.qo_indptr = request_offsets * self.draft_token_num

    kv_indices = torch.empty(cum_kv_seq_len[-1], **int32_on_device)

    # One kernel program per request fills kv_indices from req_to_token.
    create_flashinfer_kv_indices_triton[(num_reqs,)](
        req_to_token,
        req_pool_indices,
        kv_lens_with_draft,
        cum_kv_seq_len,
        None,
        kv_indices,
        req_to_token.size(1),
    )
    return kv_indices, cum_kv_seq_len, self.qo_indptr, self.custom_mask
144
+
145
def _fill_requests(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
):
    """Append accepted tokens to each request and compact the verify outputs.

    Reads ``self.accept_index`` (one row per request, ``-1`` marking unused
    entries) and ``self.predict``. For every request the accepted tokens are
    appended to ``req.output_ids`` in order; if the request finishes part-way
    through, the remaining entries of its row are reset to ``-1``.

    Afterwards ``self.accept_index`` is flattened to the accepted entries only
    and used to select rows of ``logits_output.next_token_logits``,
    ``logits_output.hidden_states`` (when present) and ``self.predict``
    (stored as ``self.verified_id``). ``self.accept_length`` is recomputed
    only when at least one request finished.

    Args:
        batch: Batch whose ``reqs`` receive the accepted tokens.
        logits_output: Verify-pass outputs; mutated in place.

    Raises:
        ValueError: Re-raised from ``req.grammar.accept_token`` after logging
            the current accept/predict state.
    """
    accept_index_cpu = self.accept_index.tolist()
    predict_cpu = self.predict.tolist()
    has_finished = False

    # Walk every accepted token and check whether the request finishes once
    # the token is appended. This must happen BEFORE KV cache slots are freed.
    for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
        for j, idx in enumerate(accept_index_row):
            if idx == -1:
                break
            token_id = predict_cpu[idx]
            req.output_ids.append(token_id)
            req.check_finished()
            if req.finished():
                has_finished = True
                # Drop every token after the finishing one for this request.
                self.accept_index[i, j + 1 :] = -1
                break
            if req.grammar is not None:
                try:
                    req.grammar.accept_token(token_id)
                except ValueError:
                    logger.info(
                        f"{i=}, {req=}\n"
                        f"{self.accept_index=}\n"
                        f"{self.predict=}\n"
                    )
                    raise
        req.spec_verify_ct += 1

    if has_finished:
        # Rows were truncated above, so the accepted lengths must be recounted.
        self.accept_length = (self.accept_index != -1).sum(dim=1) - 1
    self.accept_index = self.accept_index[self.accept_index != -1]

    logits_output.next_token_logits = logits_output.next_token_logits[
        self.accept_index
    ]
    # Compare against None explicitly: the truth value of a multi-element
    # tensor is ambiguous and raises, and a single-element tensor would be
    # judged by its value rather than its presence.
    if logits_output.hidden_states is not None:
        logits_output.hidden_states = logits_output.hidden_states[self.accept_index]
    self.verified_id = self.predict[self.accept_index]
190
+
191
def _free_cache(self, batch: ScheduleBatch, page_size: int):
    """Release KV-cache slots of rejected draft tokens and re-register the kept ones.

    Relies on ``self.accept_index`` and ``self.accept_length`` being set.
    NOTE(review): the page_size == 1 path uses ``accept_index`` as a flat index
    into ``out_cache_loc``, so it presumably expects the flattened form
    produced by ``_fill_requests`` -- confirm against the caller.

    Args:
        batch: Batch whose ``out_cache_loc`` and request-to-token table are
            updated in place.
        page_size: KV-cache page size; ``1`` selects the mask-based path,
            anything else the page-aligned path.
    """
    bs = batch.batch_size()
    # Free the KV cache for unaccepted tokens
    if page_size == 1:
        # TODO: boolean array index leads to a device sync. Remove it.
        # Start with every draft slot marked for eviction, then clear the
        # mark for accepted positions.
        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
        evict_mask[self.accept_index] = False
        batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
        # Keep only the slots that belong to accepted tokens.
        batch.out_cache_loc = batch.out_cache_loc[self.accept_index]
    else:
        # Shift the accepted tokens to the beginning.
        # Only evict the last part
        src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
            batch.seq_lens,
            batch.out_cache_loc,
            self.accept_index,
            self.accept_length,
            self.draft_token_num,
            page_size,
        )
        # Output buffer sized to the total number of slots to release;
        # `.item()` forces a host read of the sum.
        to_free_slots = torch.empty(
            (to_free_num_slots.sum().item(),),
            dtype=torch.int64,
            device=to_free_num_slots.device,
        )

        # out_cache_loc: [0  1  2,  3  4  5,  6  7  8]
        # accept_index:  [0 -1  2,  3  4 -1,  6 -1 -1]
        # tgt_cache_loc: [0  1   ,  3  4   ,  6      ]
        # to_free_slots: [      2,        5,     7  8]
        # to_free_slots also needs to be page-aligned without the first partial page
        #
        # split each row of out_cache_loc into two parts.
        # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
        # 2. the second part goes to to_free_slots.
        get_target_cache_loc[(bs,)](
            tgt_cache_loc,
            to_free_slots,
            self.accept_length,
            to_free_num_slots,
            batch.out_cache_loc,
            self.draft_token_num,
            next_power_of_2(self.draft_token_num),
            next_power_of_2(bs),
        )

        # Free the kv cache
        batch.token_to_kv_pool_allocator.free(to_free_slots)

        # Copy the kv cache
        batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
            tgt_cache_loc, src_cache_loc
        )
        batch.out_cache_loc = tgt_cache_loc

    # Re-register the kept slots for positions
    # [seq_lens, seq_lens + accept_length + 1) of every request.
    assign_req_to_token_pool[(bs,)](
        batch.req_pool_indices,
        batch.req_to_token_pool.req_to_token,
        batch.seq_lens,
        batch.seq_lens + self.accept_length + 1,
        batch.out_cache_loc,
        batch.req_to_token_pool.req_to_token.shape[1],
        triton.next_power_of_2(bs),
    )
255
+
256
def _greedy_verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
):
    """Verify the draft tree against the target model's argmax predictions.

    Allocates fresh ``self.predict``, ``self.accept_index`` and
    ``self.accept_length`` buffers and hands them to ``verify_tree_greedy``,
    which fills them in place.

    Args:
        batch: Batch being verified; only its batch size is read here.
        logits_output: Target-model output; ``next_token_logits`` is read.
    """
    num_reqs = batch.batch_size()
    tree_shape = (num_reqs, self.draft_token_num)
    logits = logits_output.next_token_logits

    # Token the target model would pick at each draft position, one row per request.
    target_predict = logits.argmax(dim=-1).reshape(tree_shape)
    candidates = self.draft_token.reshape(tree_shape)

    # Same leading dims as the logits (vocab axis dropped), with the last
    # remaining dim enlarged by one.
    out_shape = list(logits.shape)[:-1]
    out_shape[-1] += 1

    int32_on_device = dict(dtype=torch.int32, device=self.device)
    self.predict = torch.empty(out_shape, **int32_on_device)
    # Starts at -1 everywhere; the kernel overwrites entries in place.
    self.accept_index = torch.full(tree_shape, -1, **int32_on_device)
    self.accept_length = torch.empty((num_reqs,), **int32_on_device)

    verify_tree_greedy(
        predicts=self.predict,  # mutable
        accept_index=self.accept_index,  # mutable
        accept_token_num=self.accept_length,  # mutable
        candidates=candidates,
        retrive_index=self.retrive_index,
        retrive_next_token=self.retrive_next_token,
        retrive_next_sibling=self.retrive_next_sibling,
        target_predict=target_predict,
    )
284
+
285
def _sampling_verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
):
    """Verify the draft tree with the tree speculative-sampling kernel.

    Builds the target distribution (temperature, top-k, optional top-p),
    draws the uniform samples and calls
    ``tree_speculative_sampling_target_only``, which fills ``self.predict``,
    ``self.accept_index`` and ``self.accept_length`` in place.

    Args:
        batch: Batch being verified; only its batch size is read here.
        logits_output: Target-model output; ``next_token_logits`` is read.
        sampling_info: Source of temperatures, top-k and top-p values.
    """
    num_reqs = batch.batch_size()
    tokens_per_req = self.draft_token_num
    logits = logits_output.next_token_logits

    candidates = self.draft_token.reshape(num_reqs, tokens_per_req)

    # Output buffers; ``predict`` keeps the logits' leading dims with the
    # last one enlarged by one.
    out_shape = list(logits.shape)[:-1]
    out_shape[-1] += 1
    self.predict = torch.empty(out_shape, dtype=torch.int32, device=self.device)
    self.accept_index = torch.full(
        (num_reqs, tokens_per_req), -1, dtype=torch.int32, device=self.device
    )
    self.accept_length = torch.empty(
        (num_reqs,), dtype=torch.int32, device=self.device
    )

    def per_draft_token(values):
        # Repeat each request's value once per draft token along dim 0.
        return torch.repeat_interleave(values, tokens_per_req, dim=0)

    # apply temperature and get target probs
    expanded_temperature = per_draft_token(
        sampling_info.temperatures
    )  # (bs * draft_token_num, 1)
    target_probs = F.softmax(
        logits / expanded_temperature, dim=-1
    )  # (bs * draft_token_num, vocab_size)

    # NOTE: The test shows that top_p_renorm_prob and top_k_renorm_prob are the key factors
    # contributing to the poor performance of _sampling_verify.
    target_probs = top_k_renorm_prob(
        target_probs, per_draft_token(sampling_info.top_ks)
    )  # (bs * draft_token_num, vocab_size)

    if sampling_info.need_top_p_sampling:
        target_probs = top_p_renorm_prob(
            target_probs, per_draft_token(sampling_info.top_ps)
        )

    target_probs = target_probs.reshape(num_reqs, tokens_per_req, -1)
    # All-zero draft distribution with the same shape as the target one.
    draft_probs = torch.zeros(
        target_probs.shape, dtype=torch.float32, device=self.device
    )

    # Keep the draw order: rejection-sampling coins first, final-sampling coins second.
    coins = torch.rand_like(candidates, dtype=torch.float32, device=self.device)
    coins_for_final_sampling = torch.rand(
        (num_reqs,), dtype=torch.float32, device=self.device
    )

    threshold_single = global_server_args_dict["speculative_accept_threshold_single"]
    threshold_acc = global_server_args_dict["speculative_accept_threshold_acc"]

    tree_speculative_sampling_target_only(
        predicts=self.predict,  # mutable
        accept_index=self.accept_index,  # mutable
        accept_token_num=self.accept_length,  # mutable
        candidates=candidates.to(torch.int64),
        retrive_index=self.retrive_index.to(torch.int64),
        retrive_next_token=self.retrive_next_token.to(torch.int64),
        retrive_next_sibling=self.retrive_next_sibling.to(torch.int64),
        uniform_samples=coins,
        uniform_samples_for_final_sampling=coins_for_final_sampling,
        target_probs=target_probs,
        draft_probs=draft_probs,
        threshold_single=threshold_single,
        threshold_acc=threshold_acc,
        deterministic=True,
    )
354
+
355
def verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
    page_size: int,
    vocab_mask: Optional[torch.Tensor] = None,  # For grammar
) -> tuple[LogitsProcessorOutput, torch.Tensor, int]:
    """Verify the drafted tokens and commit the accepted ones to the batch.

    The logits in ``logits_output.next_token_logits`` are modified in place
    (custom logit processors, penalties, grammar mask) before verification.
    Afterwards ``_fill_requests`` and ``_free_cache`` are run and
    ``batch.seq_lens`` / ``batch.seq_lens_cpu`` are advanced by
    ``accept_length + 1`` per request.

    Args:
        batch: The batch under verification; mutated in place.
        logits_output: Target-model output for the draft tokens.
        page_size: Forwarded to ``_free_cache``.
        vocab_mask: Optional grammar mask; when given, ``self.grammar`` must
            be set.

    Returns:
        ``(logits_output, self.verified_id, num_accepted_tokens)`` where
        ``num_accepted_tokens`` is the sum of ``self.accept_length`` over the
        batch.
        NOTE(review): ``verified_id`` is assigned outside this method
        (presumably in ``_fill_requests``) and is assumed to be a tensor;
        confirm against that code.
    """
    bs = self.retrive_index.shape[0]
    sampling_info = batch.sampling_info

    if bs != len(sampling_info):
        # Work on a copy so the batch's own sampling info stays untouched.
        sampling_info = copy.deepcopy(sampling_info)
        # NOTE: retrive_index are the indices of the requests that are kept.
        sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)

    # Apply the custom logit processors if registered in the sampling info.
    if sampling_info.has_custom_logit_processor:
        apply_custom_logit_processor(
            logits_output.next_token_logits,
            sampling_info,
            num_tokens_in_batch=self.draft_token_num,
        )

    # Apply penalty
    if sampling_info.penalizer_orchestrator.is_required:
        # This is a relaxed version of penalties for speculative decoding.
        # One penalty row per request, then repeated for each draft token.
        linear_penalty = torch.zeros(
            (bs, logits_output.next_token_logits.shape[1]),
            dtype=torch.float32,
            device=self.device,
        )
        sampling_info.apply_logits_bias(linear_penalty)
        logits_output.next_token_logits.add_(
            torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
        )

    # Apply grammar mask
    if vocab_mask is not None:
        assert self.grammar is not None
        self.grammar.apply_vocab_mask(
            logits=logits_output.next_token_logits, vocab_mask=vocab_mask
        )

    # Sample tokens. Force greedy sampling on AMD
    if (not sampling_info.is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
        logger.warning(
            "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
            "Falling back to greedy verification."
        )

    # NOTE: Compared with greedy_verify, the performance of _sampling_verify is
    # relatively poor, so greedy verification is used for every batch, including
    # non-greedy requests on builds where the sampling kernel is available.
    # To re-enable sampling, call
    # ``self._sampling_verify(batch, logits_output, sampling_info)`` when the
    # batch is not all-greedy and TREE_SPEC_KERNEL_AVAILABLE is true.
    self._greedy_verify(batch, logits_output)

    self._fill_requests(batch, logits_output)
    self._free_cache(batch, page_size)

    accept_length_cpu = self.accept_length.cpu()
    num_accepted_tokens = accept_length_cpu.sum().item()

    # Each request grows by its accepted draft tokens plus one extra token.
    batch.seq_lens.add_(self.accept_length + 1)
    batch.seq_lens_cpu.add_(accept_length_cpu + 1)

    return logits_output, self.verified_id, num_accepted_tokens
423
+
424
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
    """Do nothing; both arguments are accepted and ignored."""
    return None
426
+
427
def merge_batch(self, spec_info: NgramVerifyInput):
    """Do nothing; ``spec_info`` is accepted and ignored."""
    return None
@@ -0,0 +1,246 @@
1
+ import logging
2
+ from typing import List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
7
+
8
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
9
+ from sglang.srt.managers.scheduler import GenerationBatchResult
10
+ from sglang.srt.managers.tp_worker import TpModelWorker
11
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
12
+ from sglang.srt.server_args import ServerArgs
13
+ from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
14
+ from sglang.srt.speculative.ngram_info import NgramVerifyInput
15
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
16
+
17
# Module-level logger named after this module.
+ logger = logging.getLogger(__name__)
18
+
19
# When True, NGRAMWorker._prepare_for_speculative_decoding expands each
# request's draft-token mask with an all-ones block covering seq_len - 1
# earlier positions before handing it to NgramVerifyInput. When False, the
# preallocated draft-only mask is passed through unchanged.
+ USE_FULL_MASK = True
20
+
21
+
22
class NGRAMWorker:
    """Speculative-decoding worker that drafts tokens from an n-gram cache.

    Draft tokens and their tree mask come from ``NgramCache``; the wrapped
    ``target_worker`` runs the forward pass and ``NgramVerifyInput.verify``
    decides which draft tokens are accepted.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        """Read the speculative settings, allocate buffers and build the cache.

        ``dp_rank``, ``moe_ep_rank`` and ``nccl_port`` are accepted but not
        used by this constructor.
        """
        self.target_worker = target_worker
        self.model_runner = target_worker.model_runner
        self.tp_rank = tp_rank
        self.page_size = server_args.page_size
        self.draft_token_num: int = server_args.speculative_num_draft_tokens
        self.branch_length: int = server_args.speculative_ngram_branch_length
        self.max_match_window_size: int = (
            server_args.speculative_ngram_max_match_window_size
        )

        self.max_batch_size = target_worker.max_running_requests
        self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda"

        self._init_preallocated_tensors()

        self.ngram_cache = NgramCache(
            min_match_window_size=server_args.speculative_ngram_min_match_window_size,
            max_match_window_size=server_args.speculative_ngram_max_match_window_size,
            min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
            max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
            capacity=server_args.speculative_ngram_capacity,
            branch_length=server_args.speculative_ngram_branch_length,
            draft_token_num=server_args.speculative_num_draft_tokens,
        )

    def clear_cache_pool(self):
        """Reset the n-gram cache."""
        self.ngram_cache.reset()

    def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
        """Return the last ``n`` items of ``seq1 + seq2`` without building it.

        Returns an empty list when ``n <= 0``. If the two sequences together
        hold fewer than ``n`` items, all of them are returned.
        """
        if n <= 0:
            # ``seq[-0:]`` is the whole sequence, so slicing cannot express
            # "the last zero items"; handle it explicitly.
            return []

        seq2_len = len(seq2)
        if seq2_len >= n:
            return seq2[-n:]

        need_from_seq1 = n - seq2_len
        return seq1[-need_from_seq1:] + seq2

    def _init_preallocated_tensors(self):
        """Allocate max-size device buffers and per-batch-size views of them.

        Each ``*_batch`` list is indexed by batch size ``bs`` (0 through
        ``max_batch_size``) and holds a slice of the matching buffer sized
        for ``bs`` requests.
        """
        tokens_per_req = self.draft_token_num
        max_total_drafts = self.max_batch_size * tokens_per_req
        max_total_mask_size = self.max_batch_size * tokens_per_req * tokens_per_req
        per_req_shape = (self.max_batch_size, tokens_per_req)

        self.draft_tokens = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.retrieve_indexes = torch.empty(
            per_req_shape, dtype=torch.int64, device=self.device
        )
        self.retrive_next_token = torch.empty(
            per_req_shape, dtype=torch.int64, device=self.device
        )
        self.retrive_next_sibling = torch.empty(
            per_req_shape, dtype=torch.int64, device=self.device
        )
        self.positions = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.tree_mask = torch.empty(
            (max_total_mask_size,), dtype=torch.bool, device=self.device
        )

        self.draft_tokens_batch = []
        self.tree_mask_batch = []
        self.retrieve_indexes_batch = []
        self.retrive_next_token_batch = []
        self.retrive_next_sibling_batch = []
        self.positions_batch = []

        for bs in range(self.max_batch_size + 1):
            num_drafts = bs * tokens_per_req
            self.retrieve_indexes_batch.append(self.retrieve_indexes[:bs, :])
            self.retrive_next_token_batch.append(self.retrive_next_token[:bs, :])
            self.retrive_next_sibling_batch.append(self.retrive_next_sibling[:bs, :])
            self.positions_batch.append(self.positions[:num_drafts])
            self.draft_tokens_batch.append(self.draft_tokens[:num_drafts])
            self.tree_mask_batch.append(self.tree_mask[: num_drafts * tokens_per_req])

    def _prepare_draft_tokens(
        self, batch: ScheduleBatch
    ) -> tuple[np.ndarray, np.ndarray]:
        """Query the n-gram cache for draft tokens and their tree mask.

        For each request the last ``max_match_window_size`` tokens of
        prompt + output are used as the lookup key.

        Returns:
            ``(req_drafts, mask)`` as returned by ``NgramCache.batch_get``.

        Raises:
            AssertionError: if the cache does not return exactly
                ``bs * draft_token_num`` draft tokens.
        """
        bs = batch.batch_size()

        self.ngram_cache.synchronize()
        batch_tokens = [
            self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.max_match_window_size
            )
            for req in batch.reqs
        ]
        req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
        total_draft_token_num = len(req_drafts)

        # Check if speculative decoding is needed; here we always enforce it
        assert (
            total_draft_token_num == bs * self.draft_token_num
        ), f"{total_draft_token_num=}, {bs=}, {self.draft_token_num=}"
        return req_drafts, mask

    def _build_full_tree_mask(
        self, batch: ScheduleBatch, mask: np.ndarray
    ) -> torch.Tensor:
        """Build the flattened full attention mask for every request.

        For each request, an all-True block of shape
        ``(draft_token_num, seq_len - 1)`` is concatenated in front of that
        request's ``(draft_token_num, draft_token_num)`` draft mask; the
        per-request results are flattened and concatenated.
        """
        tokens_per_req = self.draft_token_num
        per_req_mask = mask.reshape(batch.batch_size(), tokens_per_req, tokens_per_req)

        flat_masks = []
        for i, req in enumerate(batch.reqs):
            seq_len = len(req.origin_input_ids) + len(req.output_ids)
            # Allocate as bool directly on the worker's device instead of
            # creating float32 ones on the CPU and copying them over.
            prefix_mask = torch.ones(
                (tokens_per_req, seq_len - 1), dtype=torch.bool, device=self.device
            )
            draft_mask = torch.from_numpy(per_req_mask[i]).to(
                device=self.device, dtype=torch.bool
            )
            flat_masks.append(torch.cat((prefix_mask, draft_mask), dim=1).flatten())
        return torch.cat(flat_masks, dim=0)

    def _prepare_for_speculative_decoding(self, batch: ScheduleBatch):
        """Turn a non-extend batch into a TARGET_VERIFY batch with draft tokens.

        Does nothing when ``batch.forward_mode.is_extend()``. Otherwise fills
        the preallocated buffers for this batch size, sets
        ``batch.spec_algorithm``, ``batch.forward_mode`` and
        ``batch.spec_info``, and calls ``spec_info.prepare_for_verify``.
        """
        if batch.forward_mode.is_extend():
            return

        bs = batch.batch_size()

        # Views into the preallocated buffers, sized for this batch.
        retrive_index = self.retrieve_indexes_batch[bs]
        retrive_next_token = self.retrive_next_token_batch[bs]
        retrive_next_sibling = self.retrive_next_sibling_batch[bs]
        positions = self.positions_batch[bs]
        tree_mask = self.tree_mask_batch[bs]
        draft_tokens = self.draft_tokens_batch[bs]

        req_drafts, mask = self._prepare_draft_tokens(batch)
        tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
        draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)

        reconstruct_indices_from_tree_mask(
            tree_mask,
            batch.seq_lens,
            positions,  # mutable
            retrive_index,  # mutable
            retrive_next_token,  # mutable
            retrive_next_sibling,  # mutable
            bs,
            self.draft_token_num,
        )

        # NOTE: QLEN_MASK is faster than FULL_MASK, but requires corresponding changes in flashinfer.
        # Testing shows about 8% performance improvement (the effect is roughly proportional to batch size).
        if USE_FULL_MASK:
            tree_mask = self._build_full_tree_mask(batch, mask)

        batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = NgramVerifyInput(
            draft_tokens,
            tree_mask,
            positions,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            self.draft_token_num,
        )
        batch.spec_info.prepare_for_verify(batch, self.page_size)

    def _update_ngram_cache(self, batch: ScheduleBatch):
        """Insert the last ``branch_length`` tokens of every request into the cache."""
        # FIXME: Whether to insert 'extend' into the cache or not, after testing,
        # there is not much difference, so we will not insert it for now.
        # (The alternative would put ``req.origin_input_ids + req.output_ids``
        # for extend batches.)
        batch_tokens = [
            self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.branch_length
            )
            for req in batch.reqs
        ]
        self.ngram_cache.batch_put(batch_tokens)

    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
        """Run one generation step, with n-gram verification when applicable.

        If the prepared batch is in target-verify mode, the target worker is
        run with ``is_verify=True``, the draft tokens are verified, the n-gram
        cache is updated and ``batch.forward_mode`` is reset to DECODE.
        Otherwise the target worker is run normally and
        ``num_accepted_tokens`` is 0.
        """
        self._prepare_for_speculative_decoding(batch)
        model_worker_batch = batch.get_model_worker_batch()
        num_accepted_tokens = 0

        if model_worker_batch.forward_mode.is_target_verify():
            batch_result = self.target_worker.forward_batch_generation(
                model_worker_batch, is_verify=True
            )
            logits_output = batch_result.logits_output
            can_run_cuda_graph = batch_result.can_run_cuda_graph

            verify_input = model_worker_batch.spec_info
            logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
                batch, logits_output, self.page_size
            )
            self._update_ngram_cache(batch)
            batch.forward_mode = ForwardMode.DECODE
        else:
            batch_result = self.target_worker.forward_batch_generation(
                model_worker_batch
            )
            logits_output = batch_result.logits_output
            next_token_ids = batch_result.next_token_ids
            can_run_cuda_graph = batch_result.can_run_cuda_graph

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            num_accepted_tokens=num_accepted_tokens,
            can_run_cuda_graph=can_run_cuda_graph,
        )