sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/ngram_utils.py (new file)
@@ -0,0 +1,428 @@
+from __future__ import annotations
+
+import copy
+import logging
+from typing import Optional, Tuple
+
+import torch
+import triton
+
+logger = logging.getLogger(__name__)
+
+from dataclasses import dataclass
+
+import torch.nn.functional as F
+
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import apply_custom_logit_processor
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
+from sglang.srt.speculative.spec_utils import (
+    TREE_SPEC_KERNEL_AVAILABLE,
+    assign_req_to_token_pool,
+    get_src_tgt_cache_loc,
+    get_target_cache_loc,
+)
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+
+if is_cuda():
+    from sgl_kernel import (
+        top_k_renorm_prob,
+        top_p_renorm_prob,
+        tree_speculative_sampling_target_only,
+        verify_tree_greedy,
+    )
+elif is_hip():
+    from sgl_kernel import verify_tree_greedy
+
+
+@dataclass
+class NgramVerifyInput(SpecInput):
+    def __init__(
+        self,
+        draft_token: torch.Tensor,
+        tree_mask: torch.Tensor,
+        positions: torch.Tensor,
+        retrive_index: torch.Tensor,
+        retrive_next_token: torch.Tensor,
+        retrive_next_sibling: torch.Tensor,
+        draft_token_num: int,
+    ):
+        super().__init__(SpecInputType.NGRAM_VERIFY)
+        self.draft_token = draft_token
+        self.custom_mask = tree_mask
+        self.positions = positions
+        self.retrive_index = retrive_index
+        self.retrive_next_token = retrive_next_token
+        self.retrive_next_sibling = retrive_next_sibling
+        self.draft_token_num = draft_token_num
+        self.device = self.custom_mask.device
+
+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        return self.draft_token_num, self.draft_token_num
+
+    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
+        if batch.forward_mode.is_idle():
+            return
+
+        batch.input_ids = self.draft_token
+
+        if page_size == 1:
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            end_offset = batch.seq_lens + self.draft_token_num
+        else:
+            # TODO(lsyin): add prefix lens cpu here to support page size > 1
+            prefix_lens = batch.seq_lens
+            prefix_lens_cpu = batch.seq_lens_cpu
+            end_offset = prefix_lens + self.draft_token_num
+            end_offset_cpu = prefix_lens_cpu + self.draft_token_num
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+                prefix_lens,
+                prefix_lens_cpu,
+                end_offset,
+                end_offset_cpu,
+                last_loc,
+                len(batch.input_ids),
+            )
+            self.last_loc = last_loc
+
+        bs = batch.batch_size()
+        assign_req_to_token_pool[(bs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            batch.seq_lens,
+            end_offset,
+            batch.out_cache_loc,
+            batch.req_to_token_pool.req_to_token.shape[1],
+            triton.next_power_of_2(bs),
+        )
+
+    def generate_attn_arg_prefill(
+        self,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        req_to_token: torch.Tensor,
+    ):
+        bs = len(req_pool_indices)
+
+        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+
+        paged_kernel_lens = paged_kernel_lens + self.draft_token_num
+        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        self.qo_indptr = (
+            torch.arange(0, bs + 1, dtype=torch.int32, device=self.device)
+            * self.draft_token_num
+        )
+
+        kv_indices = torch.empty(
+            cum_kv_seq_len[-1], dtype=torch.int32, device=self.device
+        )
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            cum_kv_seq_len,
+            None,
+            kv_indices,
+            req_to_token.size(1),
+        )
+        return kv_indices, cum_kv_seq_len, self.qo_indptr, self.custom_mask
+
+    def _fill_requests(
+        self,
+        batch: ScheduleBatch,
+        logits_output: LogitsProcessorOutput,
+    ):
+        accept_index_cpu = self.accept_index.tolist()
+        predict_cpu = self.predict.tolist()
+        has_finished = False
+
+        # Iterate over every accepted token and check whether the request has
+        # finished after appending it. This must be checked BEFORE freeing the
+        # KV cache slots.
+        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+            for j, idx in enumerate(accept_index_row):
+                if idx == -1:
+                    break
+                id = predict_cpu[idx]
+                req.output_ids.append(id)
+                req.check_finished()
+                if req.finished():
+                    has_finished = True
+                    # Set all tokens after the finished token to -1 and break.
+                    self.accept_index[i, j + 1 :] = -1
+                    break
+                else:
+                    if req.grammar is not None:
+                        try:
+                            req.grammar.accept_token(id)
+                        except ValueError as e:
+                            logger.info(
+                                f"{i=}, {req=}\n"
+                                f"{self.accept_index=}\n"
+                                f"{self.predict=}\n"
+                            )
+                            raise e
+            req.spec_verify_ct += 1
+        if has_finished:
+            self.accept_length = (self.accept_index != -1).sum(dim=1) - 1
+        self.accept_index = self.accept_index[self.accept_index != -1]
+
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            self.accept_index
+        ]
+        if logits_output.hidden_states is not None:
+            logits_output.hidden_states = logits_output.hidden_states[self.accept_index]
+        self.verified_id = self.predict[self.accept_index]
+
+    def _free_cache(self, batch: ScheduleBatch, page_size: int):
+        bs = batch.batch_size()
+        # Free the KV cache for unaccepted tokens.
+        if page_size == 1:
+            # TODO: boolean array index leads to a device sync. Remove it.
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[self.accept_index] = False
+            batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+            batch.out_cache_loc = batch.out_cache_loc[self.accept_index]
+        else:
+            # Shift the accepted tokens to the beginning.
+            # Only evict the last part.
+            src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
+                batch.seq_lens,
+                batch.out_cache_loc,
+                self.accept_index,
+                self.accept_length,
+                self.draft_token_num,
+                page_size,
+            )
+            to_free_slots = torch.empty(
+                (to_free_num_slots.sum().item(),),
+                dtype=torch.int64,
+                device=to_free_num_slots.device,
+            )
+
+            # out_cache_loc: [0  1 2, 3 4  5, 6  7  8]
+            # accept_index:  [0 -1 2, 3 4 -1, 6 -1 -1]
+            # tgt_cache_loc: [0  1  , 3 4   , 6      ]
+            # to_free_slots: [      2,      5,    7 8]
+            # to_free_slots also needs to be page-aligned, without the first partial page.
+            #
+            # Split each row of out_cache_loc into two parts:
+            # 1. the first part goes to tgt_cache_loc, length = accept_length[i] + 1;
+            # 2. the second part goes to to_free_slots.
+            get_target_cache_loc[(bs,)](
+                tgt_cache_loc,
+                to_free_slots,
+                self.accept_length,
+                to_free_num_slots,
+                batch.out_cache_loc,
+                self.draft_token_num,
+                next_power_of_2(self.draft_token_num),
+                next_power_of_2(bs),
+            )
+
+            # Free the KV cache.
+            batch.token_to_kv_pool_allocator.free(to_free_slots)
+
+            # Copy the KV cache.
+            batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
+                tgt_cache_loc, src_cache_loc
+            )
+            batch.out_cache_loc = tgt_cache_loc
+
+        assign_req_to_token_pool[(bs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            batch.seq_lens,
+            batch.seq_lens + self.accept_length + 1,
+            batch.out_cache_loc,
+            batch.req_to_token_pool.req_to_token.shape[1],
+            triton.next_power_of_2(bs),
+        )
+
+    def _greedy_verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: LogitsProcessorOutput,
+    ):
+        bs = batch.batch_size()
+        target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
+        target_predict = target_predict.reshape(bs, self.draft_token_num)
+
+        candidates = self.draft_token.reshape(bs, self.draft_token_num)
+        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
+        predict_shape[-1] += 1
+        self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
+        self.accept_index = torch.full(
+            (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
+        )
+        self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)
+
+        verify_tree_greedy(
+            predicts=self.predict,  # mutable
+            accept_index=self.accept_index,  # mutable
+            accept_token_num=self.accept_length,  # mutable
+            candidates=candidates,
+            retrive_index=self.retrive_index,
+            retrive_next_token=self.retrive_next_token,
+            retrive_next_sibling=self.retrive_next_sibling,
+            target_predict=target_predict,
+        )
+
+    def _sampling_verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+    ):
+        bs = batch.batch_size()
+        candidates = self.draft_token.reshape(bs, self.draft_token_num)
+        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
+        predict_shape[-1] += 1
+        self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
+        self.accept_index = torch.full(
+            (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
+        )
+        self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)
+        # Apply temperature and get target probs.
+        expanded_temperature = torch.repeat_interleave(
+            sampling_info.temperatures, self.draft_token_num, dim=0
+        )  # (bs * draft_token_num, 1)
+
+        target_probs = F.softmax(
+            logits_output.next_token_logits / expanded_temperature, dim=-1
+        )  # (bs * draft_token_num, vocab_size)
+
+        # NOTE: Testing shows that top_p_renorm_prob and top_k_renorm_prob are the key
+        # contributors to the poor performance of _sampling_verify.
+        target_probs = top_k_renorm_prob(
+            target_probs,
+            torch.repeat_interleave(sampling_info.top_ks, self.draft_token_num, dim=0),
+        )  # (bs * draft_token_num, vocab_size)
+
+        if sampling_info.need_top_p_sampling:
+            # logger.info("Using top-p sampling in speculative decoding verification.")
+            target_probs = top_p_renorm_prob(
+                target_probs,
+                torch.repeat_interleave(
+                    sampling_info.top_ps, self.draft_token_num, dim=0
+                ),
+            )
+
+        target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
+        draft_probs = torch.zeros(
+            target_probs.shape, dtype=torch.float32, device=self.device
+        )
+
+        # Coins for rejection sampling.
+        coins = torch.rand_like(candidates, dtype=torch.float32, device=self.device)
+        # Coins for final sampling.
+        coins_for_final_sampling = torch.rand(
+            (bs,), dtype=torch.float32, device=self.device
+        )
+        tree_speculative_sampling_target_only(
+            predicts=self.predict,  # mutable
+            accept_index=self.accept_index,  # mutable
+            accept_token_num=self.accept_length,  # mutable
+            candidates=candidates.to(torch.int64),
+            retrive_index=self.retrive_index.to(torch.int64),
+            retrive_next_token=self.retrive_next_token.to(torch.int64),
+            retrive_next_sibling=self.retrive_next_sibling.to(torch.int64),
+            uniform_samples=coins,
+            uniform_samples_for_final_sampling=coins_for_final_sampling,
+            target_probs=target_probs,
+            draft_probs=draft_probs,
+            threshold_single=global_server_args_dict[
+                "speculative_accept_threshold_single"
+            ],
+            threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
+            deterministic=True,
+        )
+
+    def verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: LogitsProcessorOutput,
+        page_size: int,
+        vocab_mask: Optional[torch.Tensor] = None,  # For grammar
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int]:
+        bs = self.retrive_index.shape[0]
+        sampling_info = batch.sampling_info
+
+        if bs != len(sampling_info):
+            sampling_info = copy.deepcopy(sampling_info)
+            # NOTE: retrive_index holds the indices of the requests that are kept.
+            sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
+
+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(
+                logits_output.next_token_logits,
+                sampling_info,
+                num_tokens_in_batch=self.draft_token_num,
+            )
+
+        # Apply penalty.
+        if sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            linear_penalty = torch.zeros(
+                (bs, logits_output.next_token_logits.shape[1]),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            sampling_info.apply_logits_bias(linear_penalty)
+            logits_output.next_token_logits.add_(
+                torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
+            )
+
+        # Apply grammar mask.
+        if vocab_mask is not None:
+            assert self.grammar is not None
+            self.grammar.apply_vocab_mask(
+                logits=logits_output.next_token_logits, vocab_mask=vocab_mask
+            )
+
+        # Sample tokens. Greedy sampling is forced where the tree kernel is unavailable.
+        is_all_greedy = sampling_info.is_all_greedy
+        if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
+            logger.warning(
+                "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
+                "Falling back to greedy verification."
+            )
+
+        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
+            self._greedy_verify(batch, logits_output)
+        else:
+            # NOTE: Compared with _greedy_verify, the performance of _sampling_verify
+            # is relatively poor, so greedy verification is used here as well.
+            self._greedy_verify(batch, logits_output)
+            # self._sampling_verify(batch, logits_output, sampling_info)
+
+        self._fill_requests(batch, logits_output)
+        self._free_cache(batch, page_size)
+
+        accept_length_cpu = self.accept_length.cpu()
+        num_accepted_tokens = accept_length_cpu.sum().item()
+
+        batch.seq_lens.add_(self.accept_length + 1)
+        batch.seq_lens_cpu.add_(accept_length_cpu + 1)
+
+        return logits_output, self.verified_id, num_accepted_tokens
+
+    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+        pass
+
+    def merge_batch(self, spec_info: NgramVerifyInput):
+        pass
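The file above packages draft tokens plus a token tree encoded by `retrive_next_token` (first child of each node) and `retrive_next_sibling` (next sibling of each node), and `verify_tree_greedy` walks that tree against the target model's argmax predictions. As a reading aid, here is an illustrative single-request, pure-Python restatement of that greedy walk. The function name, list-based inputs, and root-at-index-0 convention are our assumptions for the sketch; the actual sgl_kernel runs batched on GPU.

```python
def greedy_verify_row(candidates, target_predict, next_token, next_sibling):
    """Sketch of greedy tree verification for one request.

    candidates[i]     -- draft token stored at tree node i
    target_predict[i] -- target model's argmax after the prefix ending at node i
    next_token[i]     -- index of node i's first child, or -1
    next_sibling[i]   -- index of node i's next sibling, or -1
    Returns the accepted node indices; node 0 (the root) is always kept.
    """
    accepted = [0]
    cur = 0
    while True:
        child = next_token[cur]
        # Scan the sibling chain for a child whose draft token matches
        # what the target model actually predicts at `cur`.
        while child != -1 and candidates[child] != target_predict[cur]:
            child = next_sibling[child]
        if child == -1:
            break  # no draft continuation survives verification
        accepted.append(child)
        cur = child
    return accepted


# Chain tree 0 -> 1 -> 2: the target agrees at node 0 but diverges at node 1,
# so only nodes [0, 1] are accepted.
assert greedy_verify_row(
    candidates=[10, 11, 12],
    target_predict=[11, 99, 7],
    next_token=[1, 2, -1],
    next_sibling=[-1, -1, -1],
) == [0, 1]
```

This mirrors the `accept_index` rows in the comments of `_free_cache` above: accepted positions keep their indices, and everything after the first rejection becomes -1.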
sglang/srt/speculative/ngram_worker.py (new file)
@@ -0,0 +1,245 @@
+import logging
+from typing import List, Optional
+
+import numpy as np
+import torch
+from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
+
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
+from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+
+logger = logging.getLogger(__name__)
+
+USE_FULL_MASK = True
+
+
+class NGRAMWorker:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        self.target_worker = target_worker
+        self.model_runner = target_worker.model_runner
+        self.tp_rank = tp_rank
+        self.page_size = server_args.page_size
+        self.draft_token_num: int = server_args.speculative_num_draft_tokens
+        self.branch_length: int = server_args.speculative_ngram_branch_length
+        self.max_match_window_size: int = (
+            server_args.speculative_ngram_max_match_window_size
+        )
+
+        self.max_batch_size = target_worker.max_running_requests
+        self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda"
+
+        self._init_preallocated_tensors()
+
+        self.ngram_cache = NgramCache(
+            min_match_window_size=server_args.speculative_ngram_min_match_window_size,
+            max_match_window_size=server_args.speculative_ngram_max_match_window_size,
+            min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
+            max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
+            capacity=server_args.speculative_ngram_capacity,
+            branch_length=server_args.speculative_ngram_branch_length,
+            draft_token_num=server_args.speculative_num_draft_tokens,
+        )
+
+    def clear_cache_pool(self):
+        self.ngram_cache.reset()
+
+    def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
+        seq2_len = len(seq2)
+        if seq2_len >= n:
+            return seq2[-n:]
+
+        need_from_seq1 = n - seq2_len
+        return seq1[-need_from_seq1:] + seq2
+
+    def _init_preallocated_tensors(self):
+        max_total_drafts = self.max_batch_size * self.draft_token_num
+        max_total_mask_size = (
+            self.max_batch_size * self.draft_token_num * self.draft_token_num
+        )
+
+        self.draft_tokens = torch.empty(
+            (max_total_drafts,), dtype=torch.int64, device=self.device
+        )
+        self.retrieve_indexes = torch.empty(
+            (self.max_batch_size, self.draft_token_num),
+            dtype=torch.int64,
+            device=self.device,
+        )
+        self.retrive_next_token = torch.empty(
+            (self.max_batch_size, self.draft_token_num),
+            dtype=torch.int64,
+            device=self.device,
+        )
+        self.retrive_next_sibling = torch.empty(
+            (self.max_batch_size, self.draft_token_num),
+            dtype=torch.int64,
+            device=self.device,
+        )
+        self.positions = torch.empty(
+            (max_total_drafts,), dtype=torch.int64, device=self.device
+        )
+        self.tree_mask = torch.empty(
+            (max_total_mask_size,), dtype=torch.bool, device=self.device
+        )
+
+        self.draft_tokens_batch = []
+        self.tree_mask_batch = []
+        self.retrieve_indexes_batch = []
+        self.retrive_next_token_batch = []
+        self.retrive_next_sibling_batch = []
+        self.positions_batch = []
+
+        for bs in range(0, self.max_batch_size + 1):
+            self.retrieve_indexes_batch.append(self.retrieve_indexes[:bs, :])
+            self.retrive_next_token_batch.append(self.retrive_next_token[:bs, :])
+            self.retrive_next_sibling_batch.append(self.retrive_next_sibling[:bs, :])
+            self.positions_batch.append(self.positions[: bs * self.draft_token_num])
+            self.draft_tokens_batch.append(
+                self.draft_tokens[: bs * self.draft_token_num]
+            )
+            self.tree_mask_batch.append(
+                self.tree_mask[: bs * self.draft_token_num * self.draft_token_num]
+            )
+
+    def _prepare_draft_tokens(
+        self, batch: ScheduleBatch
+    ) -> tuple[np.ndarray, np.ndarray]:
+        bs = batch.batch_size()
+
+        self.ngram_cache.synchronize()
+        batch_tokens = []
+        for req in batch.reqs:
+            check_token = self._efficient_concat_last_n(
+                req.origin_input_ids, req.output_ids, self.max_match_window_size
+            )
+            batch_tokens.append(check_token)
+        req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
+        total_draft_token_num = len(req_drafts)
+
+        # Check if speculative decoding is needed; here we always enforce a
+        # full draft set per request.
+        assert (
+            total_draft_token_num == bs * self.draft_token_num
+        ), f"{total_draft_token_num=}, {bs=}, {self.draft_token_num=}"
+        return req_drafts, mask
+
+    def _prepare_for_speculative_decoding(self, batch: ScheduleBatch):
+        if batch.forward_mode.is_extend():
+            return
+
+        bs = batch.batch_size()
+
+        retrive_index = self.retrieve_indexes_batch[bs]
+        retrive_next_token = self.retrive_next_token_batch[bs]
+        retrive_next_sibling = self.retrive_next_sibling_batch[bs]
+        positions = self.positions_batch[bs]
+        tree_mask = self.tree_mask_batch[bs]
+        draft_tokens = self.draft_tokens_batch[bs]
+
+        req_drafts, mask = self._prepare_draft_tokens(batch)
+        tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
+        draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)
+
+        reconstruct_indices_from_tree_mask(
+            tree_mask,
+            batch.seq_lens,
+            positions,  # mutable
+            retrive_index,  # mutable
+            retrive_next_token,  # mutable
+            retrive_next_sibling,  # mutable
+            bs,
+            self.draft_token_num,
+        )
+
+        # NOTE: QLEN_MASK is faster than FULL_MASK but requires corresponding changes
+        # in flashinfer. Testing shows about an 8% performance improvement (roughly
+        # proportional to batch size).
+        if USE_FULL_MASK:
+            tree_mask = []
+            mask = mask.reshape(
+                batch.batch_size(), self.draft_token_num, self.draft_token_num
+            )
+            for i, req in enumerate(batch.reqs):
+                seq_len = len(req.origin_input_ids) + len(req.output_ids)
+                req_mask = torch.ones((self.draft_token_num, seq_len - 1)).cuda()
+                req_mask = torch.cat(
+                    (req_mask, torch.from_numpy(mask[i]).cuda()), dim=1
+                ).to(torch.bool)
+                tree_mask.append(req_mask.flatten())
+            tree_mask = torch.cat(tree_mask, dim=0)
+
+        batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
+        batch.forward_mode = ForwardMode.TARGET_VERIFY
+        batch.spec_info = NgramVerifyInput(
+            draft_tokens,
+            tree_mask,
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            self.draft_token_num,
+        )
+        batch.spec_info.prepare_for_verify(batch, self.page_size)
+
+    def _update_ngram_cache(self, batch: ScheduleBatch):
+        batch_tokens = []
+        for req in batch.reqs:
+            # FIXME: Testing showed little difference between inserting the 'extend'
+            # tokens into the cache and not, so we skip the insert for now.
+            # if batch.forward_mode.is_extend():
+            #     put_ids = req.origin_input_ids + req.output_ids
+            # else:
+            put_ids = self._efficient_concat_last_n(
+                req.origin_input_ids, req.output_ids, self.branch_length
+            )
+            batch_tokens.append(put_ids)
+        self.ngram_cache.batch_put(batch_tokens)
+
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
+        self._prepare_for_speculative_decoding(batch)
+        model_worker_batch = batch.get_model_worker_batch()
+        num_accepted_tokens = 0
+
+        if model_worker_batch.forward_mode.is_target_verify():
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch, is_verify=True
+            )
+            logits_output, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.can_run_cuda_graph,
+            )
+            verify_input = model_worker_batch.spec_info
+            logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
+                batch, logits_output, self.page_size
+            )
+            self._update_ngram_cache(batch)
+            batch.forward_mode = ForwardMode.DECODE
+
+        else:
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch
+            )
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
+            )
+
+        return ForwardBatchOutput(
+            logits_output=logits_output,
+            next_token_ids=next_token_ids,
+            num_accepted_tokens=num_accepted_tokens,
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
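One small building block worth calling out: `_efficient_concat_last_n` takes the last `n` tokens across the prompt/output boundary without concatenating the two full lists. Below is a standalone restatement with sanity checks; the free-function form and plain-list inputs are ours, the logic mirrors the method above.

```python
def efficient_concat_last_n(seq1: list, seq2: list, n: int) -> list:
    # Fast path: the whole n-token suffix lies inside seq2.
    if len(seq2) >= n:
        return seq2[-n:]
    # Otherwise the window spans the boundary: take the tail of seq1
    # plus all of seq2.
    need_from_seq1 = n - len(seq2)
    return seq1[-need_from_seq1:] + seq2


assert efficient_concat_last_n([1, 2, 3], [4, 5], 4) == [2, 3, 4, 5]
assert efficient_concat_last_n([1, 2, 3], [4, 5, 6, 7], 3) == [5, 6, 7]
assert efficient_concat_last_n([], [4, 5], 4) == [4, 5]
```

The worker uses this window both to look up drafts (`max_match_window_size`) and to insert freshly decoded tokens back into the ngram cache (`branch_length`).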