sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -12,38 +12,44 @@
12
12
  # limitations under the License.
13
13
  # ==============================================================================
14
14
  """A tensor parallel worker."""
15
+ from __future__ import annotations
15
16
 
16
17
  import logging
17
- import threading
18
- from typing import Optional, Tuple, Union
18
+ from typing import TYPE_CHECKING, Optional
19
19
 
20
20
  import torch
21
21
 
22
22
  from sglang.srt.configs.model_config import ModelConfig
23
23
  from sglang.srt.distributed import get_pp_group, get_world_group
24
- from sglang.srt.hf_transformers_utils import (
25
- get_processor,
26
- get_tokenizer,
27
- get_tokenizer_from_processor,
28
- )
29
- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
30
24
  from sglang.srt.managers.io_struct import (
25
+ DestroyWeightsUpdateGroupReqInput,
31
26
  GetWeightsByNameReqInput,
27
+ InitWeightsSendGroupForRemoteInstanceReqInput,
32
28
  InitWeightsUpdateGroupReqInput,
33
29
  LoadLoRAAdapterReqInput,
30
+ SendWeightsToRemoteInstanceReqInput,
34
31
  UnloadLoRAAdapterReqInput,
35
32
  UpdateWeightFromDiskReqInput,
36
33
  UpdateWeightsFromDistributedReqInput,
37
34
  UpdateWeightsFromTensorReqInput,
38
35
  )
39
36
  from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
37
+ from sglang.srt.managers.scheduler import GenerationBatchResult
40
38
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
41
39
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
42
40
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
43
41
  from sglang.srt.model_executor.model_runner import ModelRunner
44
- from sglang.srt.patch_torch import monkey_patch_torch_reductions
45
42
  from sglang.srt.server_args import ServerArgs
46
43
  from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
44
+ from sglang.srt.utils.hf_transformers_utils import (
45
+ get_processor,
46
+ get_tokenizer,
47
+ get_tokenizer_from_processor,
48
+ )
49
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
50
+
51
+ if TYPE_CHECKING:
52
+ from sglang.srt.managers.cache_controller import LayerDoneCounter
47
53
 
48
54
  logger = logging.getLogger(__name__)
49
55
 
@@ -78,6 +84,11 @@ class TpModelWorker:
78
84
  if not is_draft_worker
79
85
  else server_args.speculative_draft_model_path
80
86
  ),
87
+ model_revision=(
88
+ server_args.revision
89
+ if not is_draft_worker
90
+ else server_args.speculative_draft_model_revision
91
+ ),
81
92
  is_draft_model=is_draft_worker,
82
93
  )
83
94
 
@@ -137,8 +148,8 @@ class TpModelWorker:
137
148
  assert self.max_running_requests > 0, "max_running_request is zero"
138
149
  self.max_queued_requests = server_args.max_queued_requests
139
150
  assert (
140
- self.max_running_requests > 0
141
- ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
151
+ self.max_queued_requests is None or self.max_queued_requests >= 1
152
+ ), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
142
153
  self.max_req_len = min(
143
154
  self.model_config.context_len - 1,
144
155
  self.max_total_num_tokens - 1,
@@ -162,10 +173,10 @@ class TpModelWorker:
162
173
 
163
174
  self.hicache_layer_transfer_counter = None
164
175
 
165
- def register_hicache_layer_transfer_counter(self, counter):
176
+ def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
166
177
  self.hicache_layer_transfer_counter = counter
167
178
 
168
- def set_hicache_consumer(self, consumer_index):
179
+ def set_hicache_consumer(self, consumer_index: int):
169
180
  if self.hicache_layer_transfer_counter is not None:
170
181
  self.hicache_layer_transfer_counter.set_consumer(consumer_index)
171
182
 
@@ -220,11 +231,11 @@ class TpModelWorker:
220
231
  def forward_batch_generation(
221
232
  self,
222
233
  model_worker_batch: ModelWorkerBatch,
223
- launch_done: Optional[threading.Event] = None,
224
- skip_sample: bool = False,
225
- ) -> Tuple[
226
- Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
227
- ]:
234
+ is_verify: bool = False,
235
+ ) -> GenerationBatchResult:
236
+ # update the consumer index of hicache to the running batch
237
+ self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
238
+
228
239
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
229
240
 
230
241
  pp_proxy_tensors = None
@@ -239,23 +250,51 @@ class TpModelWorker:
239
250
  logits_output, can_run_cuda_graph = self.model_runner.forward(
240
251
  forward_batch, pp_proxy_tensors=pp_proxy_tensors
241
252
  )
242
- if launch_done is not None:
243
- launch_done.set()
253
+ batch_result = GenerationBatchResult(
254
+ logits_output=logits_output,
255
+ can_run_cuda_graph=can_run_cuda_graph,
256
+ )
244
257
 
245
- if skip_sample:
246
- next_token_ids = None
258
+ if is_verify:
259
+ # Skip sampling and return logits for target forward
260
+ return batch_result
261
+
262
+ if model_worker_batch.delay_sample_launch:
263
+ batch_result.delay_sample_launch = True
264
+ batch_result.forward_batch = forward_batch
265
+ return batch_result
266
+
267
+ if model_worker_batch.is_prefill_only:
268
+ # For prefill-only requests, create dummy token IDs on CPU
269
+ # The size should match the batch size (number of sequences), not total tokens
270
+ batch_result.next_token_ids = torch.zeros(
271
+ len(model_worker_batch.seq_lens),
272
+ dtype=torch.long,
273
+ device=model_worker_batch.input_ids.device,
274
+ )
275
+ if (
276
+ model_worker_batch.return_logprob
277
+ and logits_output.next_token_logits is not None
278
+ ):
279
+ # NOTE: Compute logprobs without full sampling
280
+ self.model_runner.compute_logprobs_only(
281
+ logits_output, model_worker_batch
282
+ )
247
283
  else:
248
- next_token_ids = self.model_runner.sample(
249
- logits_output, model_worker_batch
284
+ batch_result.next_token_ids = self.model_runner.sample(
285
+ logits_output, forward_batch
250
286
  )
251
287
 
252
- return logits_output, next_token_ids, can_run_cuda_graph
288
+ return batch_result
253
289
  else:
254
290
  pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
255
291
  forward_batch,
256
292
  pp_proxy_tensors=pp_proxy_tensors,
257
293
  )
258
- return pp_proxy_tensors.tensors, None, can_run_cuda_graph
294
+ return GenerationBatchResult(
295
+ pp_hidden_states_proxy_tensors=pp_proxy_tensors,
296
+ can_run_cuda_graph=can_run_cuda_graph,
297
+ )
259
298
 
260
299
  def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
261
300
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -280,6 +319,37 @@ class TpModelWorker:
280
319
  )
281
320
  return success, message
282
321
 
322
+ def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
323
+ success, message = self.model_runner.destroy_weights_update_group(
324
+ recv_req.group_name,
325
+ )
326
+ return success, message
327
+
328
+ def init_weights_send_group_for_remote_instance(
329
+ self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
330
+ ):
331
+ success, message = (
332
+ self.model_runner.init_weights_send_group_for_remote_instance(
333
+ recv_req.master_address,
334
+ recv_req.ports,
335
+ recv_req.group_rank,
336
+ recv_req.world_size,
337
+ recv_req.group_name,
338
+ recv_req.backend,
339
+ )
340
+ )
341
+ return success, message
342
+
343
+ def send_weights_to_remote_instance(
344
+ self, recv_req: SendWeightsToRemoteInstanceReqInput
345
+ ):
346
+ success, message = self.model_runner.send_weights_to_remote_instance(
347
+ recv_req.master_address,
348
+ recv_req.ports,
349
+ recv_req.group_name,
350
+ )
351
+ return success, message
352
+
283
353
  def update_weights_from_distributed(
284
354
  self, recv_req: UpdateWeightsFromDistributedReqInput
285
355
  ):
@@ -2,11 +2,10 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import multiprocessing as mp
5
- from http import HTTPStatus
6
5
  from typing import TYPE_CHECKING, Dict, List, Optional
7
6
 
8
7
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
9
- from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
8
+ from sglang.srt.managers.schedule_batch import Req
10
9
  from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
11
10
 
12
11
  if TYPE_CHECKING:
@@ -97,46 +96,3 @@ def get_logprob_from_pp_outputs(
97
96
  ]
98
97
 
99
98
  return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
100
-
101
-
102
- class DPBalanceMeta:
103
- """
104
- This class will be use in scheduler and dp controller
105
- """
106
-
107
- def __init__(self, num_workers: int):
108
- self.num_workers = num_workers
109
- self._manager = mp.Manager()
110
- self.mutex = self._manager.Lock()
111
-
112
- init_local_tokens = [0] * self.num_workers
113
- init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
114
-
115
- self.shared_state = self._manager.Namespace()
116
- self.shared_state.local_tokens = self._manager.list(init_local_tokens)
117
- self.shared_state.onfly_info = self._manager.list(init_onfly_info)
118
-
119
- def destructor(self):
120
- # we must destructor this class manually
121
- self._manager.shutdown()
122
-
123
- def get_shared_onfly(self) -> List[Dict[int, int]]:
124
- return [dict(d) for d in self.shared_state.onfly_info]
125
-
126
- def set_shared_onfly_info(self, data: List[Dict[int, int]]):
127
- self.shared_state.onfly_info = data
128
-
129
- def get_shared_local_tokens(self) -> List[int]:
130
- return list(self.shared_state.local_tokens)
131
-
132
- def set_shared_local_tokens(self, data: List[int]):
133
- self.shared_state.local_tokens = data
134
-
135
- def __getstate__(self):
136
- state = self.__dict__.copy()
137
- del state["_manager"]
138
- return state
139
-
140
- def __setstate__(self, state):
141
- self.__dict__.update(state)
142
- self._manager = None
@@ -27,7 +27,7 @@ import triton
27
27
  import triton.language as tl
28
28
 
29
29
  from sglang.srt.mem_cache.memory_pool import SWAKVPool
30
- from sglang.srt.utils import get_bool_env_var, next_power_of_2
30
+ from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
31
31
 
32
32
  if TYPE_CHECKING:
33
33
  from sglang.srt.mem_cache.memory_pool import KVCache
@@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
274
274
  self.full_to_swa_index_mapping[free_index] = 0
275
275
 
276
276
  def backup_state(self):
277
- raise NotImplementedError
277
+ return [
278
+ self.full_attn_allocator.backup_state(),
279
+ self.swa_attn_allocator.backup_state(),
280
+ ]
278
281
 
279
282
  def restore_state(self, state):
280
- raise NotImplementedError
283
+ assert len(state) == 2
284
+ self.full_attn_allocator.restore_state(state[0])
285
+ self.swa_attn_allocator.restore_state(state[1])
281
286
 
282
287
  def clear(self):
283
288
  self.swa_attn_allocator.clear()
@@ -294,7 +299,6 @@ def alloc_extend_kernel(
294
299
  last_loc_ptr,
295
300
  free_page_ptr,
296
301
  out_indices,
297
- ret_values,
298
302
  bs_upper: tl.constexpr,
299
303
  page_size: tl.constexpr,
300
304
  max_num_extend_tokens: tl.constexpr,
@@ -323,13 +327,6 @@ def alloc_extend_kernel(
323
327
  sum_num_new_pages = tl.sum(num_new_pages)
324
328
  new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
325
329
 
326
- # Return value
327
- if pid == tl.num_programs(0) - 1:
328
- merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
329
- tl.int64
330
- )
331
- tl.store(ret_values, merged_value)
332
-
333
330
  # Part 1: fill the old partial page
334
331
  last_loc = tl.load(last_loc_ptr + pid)
335
332
  num_part1 = (
@@ -381,7 +378,6 @@ def alloc_decode_kernel(
381
378
  last_loc_ptr,
382
379
  free_page_ptr,
383
380
  out_indices,
384
- ret_values,
385
381
  bs_upper: tl.constexpr,
386
382
  page_size: tl.constexpr,
387
383
  ):
@@ -404,10 +400,6 @@ def alloc_decode_kernel(
404
400
  sum_num_new_pages = tl.sum(num_new_pages)
405
401
  new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
406
402
 
407
- # Return value
408
- if pid == tl.num_programs(0) - 1:
409
- tl.store(ret_values, sum_num_new_pages)
410
-
411
403
  if num_page_start_loc_self == 0:
412
404
  last_loc = tl.load(last_loc_ptr + pid)
413
405
  tl.store(out_indices + pid, last_loc + 1)
@@ -438,7 +430,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
438
430
  super().__init__(size, page_size, dtype, device, kvcache, need_sort)
439
431
  self.num_pages = size // page_size
440
432
  self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
441
- self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
442
433
  self.seen_max_num_extend_tokens_next_power_of_2 = 1
443
434
  self.clear()
444
435
 
@@ -468,7 +459,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
468
459
  def alloc_extend(
469
460
  self,
470
461
  prefix_lens: torch.Tensor,
462
+ prefix_lens_cpu: torch.Tensor,
471
463
  seq_lens: torch.Tensor,
464
+ seq_lens_cpu: torch.Tensor,
472
465
  last_loc: torch.Tensor,
473
466
  extend_num_tokens: int,
474
467
  ):
@@ -497,7 +490,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
497
490
  last_loc,
498
491
  self.free_pages,
499
492
  out_indices,
500
- self.ret_values,
501
493
  next_power_of_2(bs),
502
494
  self.page_size,
503
495
  self.seen_max_num_extend_tokens_next_power_of_2,
@@ -506,8 +498,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
506
498
  if self.debug_mode:
507
499
  assert len(torch.unique(out_indices)) == len(out_indices)
508
500
 
509
- merged_value = self.ret_values.item()
510
- num_new_pages = merged_value >> 32
501
+ num_new_pages = get_num_new_pages(
502
+ seq_lens=seq_lens_cpu,
503
+ page_size=self.page_size,
504
+ prefix_lens=prefix_lens_cpu,
505
+ )
511
506
  if num_new_pages > len(self.free_pages):
512
507
  return None
513
508
 
@@ -517,6 +512,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
517
512
  def alloc_decode(
518
513
  self,
519
514
  seq_lens: torch.Tensor,
515
+ seq_lens_cpu: torch.Tensor,
520
516
  last_loc: torch.Tensor,
521
517
  ):
522
518
  if self.debug_mode:
@@ -534,7 +530,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
534
530
  last_loc,
535
531
  self.free_pages,
536
532
  out_indices,
537
- self.ret_values,
538
533
  next_power_of_2(bs),
539
534
  self.page_size,
540
535
  )
@@ -542,7 +537,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
542
537
  if self.debug_mode:
543
538
  assert len(torch.unique(out_indices)) == len(out_indices)
544
539
 
545
- num_new_pages = self.ret_values.item()
540
+ num_new_pages = get_num_new_pages(
541
+ seq_lens=seq_lens_cpu,
542
+ page_size=self.page_size,
543
+ decode=True,
544
+ )
546
545
  if num_new_pages > len(self.free_pages):
547
546
  return None
548
547
 
@@ -1,13 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING
4
-
5
3
  import torch
6
4
 
7
5
  from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
8
-
9
- if TYPE_CHECKING:
10
- from sglang.srt.mem_cache.memory_pool import KVCache
6
+ from sglang.srt.utils import get_num_new_pages
11
7
 
12
8
 
13
9
  def alloc_extend_kernel_ascend(
@@ -69,7 +65,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
69
65
  def alloc_extend(
70
66
  self,
71
67
  prefix_lens: torch.Tensor,
68
+ prefix_lens_cpu: torch.Tensor,
72
69
  seq_lens: torch.Tensor,
70
+ seq_lens_cpu: torch.Tensor,
73
71
  last_loc: torch.Tensor,
74
72
  extend_num_tokens: int,
75
73
  ):
@@ -79,42 +77,54 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
79
77
  )
80
78
 
81
79
  num_new_pages = (
82
- (
83
- (seq_lens + self.page_size - 1) // self.page_size
84
- - (prefix_lens + self.page_size - 1) // self.page_size
85
- )
86
- .sum()
87
- .item()
88
- )
89
- if self.need_sort and num_new_pages > len(self.free_pages):
80
+ (seq_lens + self.page_size - 1) // self.page_size
81
+ - (prefix_lens + self.page_size - 1) // self.page_size
82
+ ).sum()
83
+ num_new_pages_item = num_new_pages.item()
84
+ if self.need_sort and num_new_pages_item > len(self.free_pages):
90
85
  self.merge_and_sort_free()
91
86
 
92
- if num_new_pages > len(self.free_pages):
87
+ if num_new_pages_item > len(self.free_pages):
93
88
  return None
94
89
 
95
90
  out_indices = torch.empty(
96
- (extend_num_tokens,), dtype=torch.int32, device=self.device
91
+ (extend_num_tokens,), dtype=torch.int64, device=self.device
97
92
  )
98
93
 
99
- alloc_extend_kernel_ascend(
100
- prefix_lens,
101
- seq_lens,
102
- last_loc,
103
- self.free_pages,
104
- out_indices,
105
- self.page_size,
106
- self.device,
107
- )
94
+ if num_new_pages_item < 200:
95
+ import sgl_kernel_npu
96
+
97
+ torch.ops.npu.alloc_extend(
98
+ prefix_lens,
99
+ seq_lens,
100
+ last_loc,
101
+ self.free_pages,
102
+ self.page_size,
103
+ out_indices,
104
+ num_new_pages,
105
+ )
106
+
107
+ else:
108
+ alloc_extend_kernel_ascend(
109
+ prefix_lens,
110
+ seq_lens,
111
+ last_loc,
112
+ self.free_pages,
113
+ out_indices,
114
+ self.page_size,
115
+ self.device,
116
+ )
108
117
 
109
118
  if self.debug_mode:
110
119
  assert len(torch.unique(out_indices)) == len(out_indices)
111
120
 
112
- self.free_pages = self.free_pages[num_new_pages:]
121
+ self.free_pages = self.free_pages[num_new_pages_item:]
113
122
  return out_indices
114
123
 
115
124
  def alloc_decode(
116
125
  self,
117
126
  seq_lens: torch.Tensor,
127
+ seq_lens_cpu: torch.Tensor,
118
128
  last_loc: torch.Tensor,
119
129
  ):
120
130
  if self.debug_mode:
@@ -122,8 +132,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
122
132
  (last_loc + 2) % self.page_size == seq_lens % self.page_size
123
133
  )
124
134
 
125
- need_new_pages = (seq_lens % self.page_size == 1).int()
126
- num_new_pages = need_new_pages.sum().item()
135
+ num_new_pages = get_num_new_pages(
136
+ seq_lens=seq_lens_cpu,
137
+ page_size=self.page_size,
138
+ decode=True,
139
+ )
127
140
 
128
141
  if num_new_pages > len(self.free_pages):
129
142
  self.merge_and_sort_free()
@@ -131,6 +144,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
131
144
  if num_new_pages > len(self.free_pages):
132
145
  return None
133
146
 
147
+ need_new_pages = (seq_lens % self.page_size == 1).int()
134
148
  end_new_pages = torch.cumsum(need_new_pages, 0)
135
149
  start_new_pages = end_new_pages - need_new_pages
136
150
  if num_new_pages == 0:
@@ -36,7 +36,7 @@ class BasePrefixCache(ABC):
36
36
  pass
37
37
 
38
38
  @abstractmethod
39
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
39
+ def match_prefix(self, key: Any, **kwargs) -> MatchResult:
40
40
  pass
41
41
 
42
42
  @abstractmethod
@@ -28,6 +28,13 @@ class ChunkCache(BasePrefixCache):
28
28
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
29
29
  self.page_size = page_size
30
30
 
31
+ # NOTE (csy): this is to determine if a cache has prefix matching feature.
32
+ # Chunk cache always return True to indicate no prefix matching.
33
+ # TODO (csy): Using a prefix cache trait to replace this
34
+ @property
35
+ def disable(self):
36
+ return True
37
+
31
38
  def reset(self):
32
39
  pass
33
40
 
@@ -38,7 +45,7 @@ class ChunkCache(BasePrefixCache):
38
45
  last_host_node=None,
39
46
  )
40
47
 
41
- def cache_finished_req(self, req: Req):
48
+ def cache_finished_req(self, req: Req, insert: bool = True):
42
49
  kv_indices = self.req_to_token_pool.req_to_token[
43
50
  req.req_pool_idx,
44
51
  # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
@@ -53,7 +60,7 @@ class ChunkCache(BasePrefixCache):
53
60
  ]
54
61
 
55
62
  # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
56
- req.prefix_indices = kv_indices
63
+ req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
57
64
 
58
65
  def evict(self, num_tokens: int):
59
66
  pass
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING, List, Tuple, Union
5
+
6
+ if TYPE_CHECKING:
7
+ from sglang.srt.mem_cache.radix_cache import TreeNode
8
+
9
+
10
+ class EvictionStrategy(ABC):
11
+ @abstractmethod
12
+ def get_priority(self, node: "TreeNode") -> Union[float, Tuple]:
13
+ pass
14
+
15
+
16
+ class LRUStrategy(EvictionStrategy):
17
+ def get_priority(self, node: "TreeNode") -> float:
18
+ return node.last_access_time
19
+
20
+
21
+ class LFUStrategy(EvictionStrategy):
22
+ def get_priority(self, node: "TreeNode") -> Tuple[int, float]:
23
+ return (node.hit_count, node.last_access_time)