sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py

@@ -12,25 +12,24 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import logging
 import threading
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    SendWeightsToRemoteInstanceReqInput,
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
@@ -39,11 +38,23 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardBatchOutput,
+    PPProxyTensors,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
+from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter

 logger = logging.getLogger(__name__)

@@ -78,6 +89,11 @@ class TpModelWorker:
                 if not is_draft_worker
                 else server_args.speculative_draft_model_path
             ),
+            model_revision=(
+                server_args.revision
+                if not is_draft_worker
+                else server_args.speculative_draft_model_revision
+            ),
             is_draft_model=is_draft_worker,
         )

@@ -137,8 +153,8 @@
         assert self.max_running_requests > 0, "max_running_request is zero"
         self.max_queued_requests = server_args.max_queued_requests
         assert (
-            self.max_running_requests > 0
-        ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
+            self.max_queued_requests is None or self.max_queued_requests >= 1
+        ), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -162,10 +178,10 @@

         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
+    def set_hicache_consumer(self, consumer_index: int):
         if self.hicache_layer_transfer_counter is not None:
             self.hicache_layer_transfer_counter.set_consumer(consumer_index)

@@ -221,10 +237,11 @@
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
-        skip_sample: bool = False,
-    ) -> Tuple[
-        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
-    ]:
+        is_verify: bool = False,
+    ) -> ForwardBatchOutput:
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

         pp_proxy_tensors = None
@@ -242,20 +259,31 @@
             if launch_done is not None:
                 launch_done.set()

-            if skip_sample:
-                next_token_ids = None
-            else:
-                next_token_ids = self.model_runner.sample(
+            skip_sample = is_verify or model_worker_batch.is_prefill_only
+            next_token_ids = None
+
+            if not skip_sample:
+                next_token_ids = self.model_runner.sample(logits_output, forward_batch)
+            elif model_worker_batch.return_logprob and not is_verify:
+                # NOTE: Compute logprobs without full sampling
+                self.model_runner.compute_logprobs_only(
                     logits_output, model_worker_batch
                 )

-            return logits_output, next_token_ids, can_run_cuda_graph
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
+            return ForwardBatchOutput(
+                pp_proxy_tensors=pp_proxy_tensors,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -280,6 +308,37 @@
         )
         return success, message

+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.destroy_weights_update_group(
+            recv_req.group_name,
+        )
+        return success, message
+
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        success, message = (
+            self.model_runner.init_weights_send_group_for_remote_instance(
+                recv_req.master_address,
+                recv_req.ports,
+                recv_req.group_rank,
+                recv_req.world_size,
+                recv_req.group_name,
+                recv_req.backend,
+            )
+        )
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        success, message = self.model_runner.send_weights_to_remote_instance(
+            recv_req.master_address,
+            recv_req.ports,
+            recv_req.group_name,
+        )
+        return success, message
+
     def update_weights_from_distributed(
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
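Note that forward_batch_generation now returns a single ForwardBatchOutput container instead of a positional tuple. The dataclass itself lives in sglang/srt/model_executor/forward_batch_info.py, whose body is not shown in this diff; the following is a minimal sketch inferred from the call sites above (field names match the keyword arguments used there; the defaults and docstrings are assumptions, not the actual definition):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardBatchOutput:
    # Dense logits from the generation path; None on pipeline-parallel
    # proxy hops, where only pp_proxy_tensors is populated.
    logits_output: Optional["LogitsProcessorOutput"] = None
    # Sampled token ids; in the overlap worker these may be negative
    # "future" placeholders that are resolved later.
    next_token_ids: Optional[torch.Tensor] = None
    # Intermediate tensors handed to the next pipeline-parallel stage.
    pp_proxy_tensors: Optional["PPProxyTensors"] = None
    # Whether this forward pass was eligible for CUDA-graph replay.
    can_run_cuda_graph: bool = False

A single container lets the verify (is_verify), prefill-only, and pipeline-parallel paths all share one return type instead of overloading tuple positions.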
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -12,42 +12,42 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import dataclasses
 import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import psutil
 import torch

 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    SendWeightsToRemoteInstanceReqInput,
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode, get_compiler_backend
+from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback

-logger = logging.getLogger(__name__)
-
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter

-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def resolve_future_token_ids(input_ids, future_token_ids_map):
-    input_ids[:] = torch.where(
-        input_ids < 0,
-        future_token_ids_map[torch.clamp(-input_ids, min=0)],
-        input_ids,
-    )
+logger = logging.getLogger(__name__)


 class TpModelWorkerClient:
@@ -72,14 +72,10 @@ class TpModelWorkerClient:
         self.gpu_id = gpu_id

         # Init future mappings
-        self.future_token_ids_ct = 0
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
-        )
+        self.future_map = FutureMap(self.max_running_requests, self.device)

         # Launch threads
-        self.input_queue = Queue()
+        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
         self.output_queue = Queue()
         self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
@@ -93,13 +89,9 @@

         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
-        if self.hicache_layer_transfer_counter is not None:
-            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
-
     def get_worker_info(self):
         return self.worker.get_worker_info()

@@ -147,10 +139,10 @@
     @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
-        batch_lists = [None] * 2
+        batch_lists: List = [None] * 2

         while True:
-            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
+            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break

@@ -166,29 +158,35 @@
             copy_done = torch.get_device_module(self.device).Event()

             # Resolve future tokens in the input
-            input_ids = model_worker_batch.input_ids
-            resolve_future_token_ids(input_ids, self.future_token_ids_map)
+            self.future_map.resolve_future(model_worker_batch)

-            # update the consumer index of hicache to the running batch
-            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
             # Run forward
+            forward_batch_output = self.worker.forward_batch_generation(
+                model_worker_batch,
+                model_worker_batch.launch_done,
+            )
+
             logits_output, next_token_ids, can_run_cuda_graph = (
-                self.worker.forward_batch_generation(
-                    model_worker_batch, model_worker_batch.launch_done
-                )
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
             )

             # Update the future token ids map
             bs = len(model_worker_batch.seq_lens)
-            self.future_token_ids_map[
-                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
-            ] = next_token_ids
+            if model_worker_batch.is_prefill_only:
+                # For prefill-only requests, create dummy token IDs on CPU
+                next_token_ids = torch.zeros(bs, dtype=torch.long)
+
+            # store the future indices into future map
+            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)

             # Copy results to the CPU
             if model_worker_batch.return_logprob:
-                logits_output.next_token_logprobs = (
-                    logits_output.next_token_logprobs.to("cpu", non_blocking=True)
-                )
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+                    )
                 if logits_output.input_token_logprobs is not None:
                     logits_output.input_token_logprobs = (
                         logits_output.input_token_logprobs.to("cpu", non_blocking=True)
@@ -197,7 +195,9 @@
                 logits_output.hidden_states = logits_output.hidden_states.to(
                     "cpu", non_blocking=True
                 )
-            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+            # Only copy to CPU if not already on CPU
+            if next_token_ids.device.type != "cpu":
+                next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()

             self.output_queue.put(
@@ -221,16 +221,16 @@
             logits_output.next_token_logprobs = (
                 logits_output.next_token_logprobs.tolist()
             )
-            if logits_output.input_token_logprobs is not None:
-                logits_output.input_token_logprobs = tuple(
-                    logits_output.input_token_logprobs.tolist()
-                )
+        if logits_output.input_token_logprobs is not None:
+            logits_output.input_token_logprobs = tuple(
+                logits_output.input_token_logprobs.tolist()
+            )
         next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids, can_run_cuda_graph

     def forward_batch_generation(
         self, model_worker_batch: ModelWorkerBatch
-    ) -> Tuple[None, torch.Tensor, bool]:
+    ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -245,21 +245,18 @@
         sync_event.record(self.scheduler_stream)

         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
-
-        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + 1),
-            -(self.future_token_ids_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
+        cur_future_map_ct = self.future_map.update_ct(bs)
+        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
+
+        # get this forward batch's future token ids
+        future_next_token_ids = self.future_map.update_next_future(
+            cur_future_map_ct, bs
+        )
+        return ForwardBatchOutput(
+            next_token_ids=future_next_token_ids,
+            can_run_cuda_graph=False,
         )
-        self.future_token_ids_ct = (
-            self.future_token_ids_ct + bs
-        ) % self.future_token_ids_limit
-        return None, future_next_token_ids, False

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
@@ -269,6 +266,24 @@
         success, message = self.worker.init_weights_update_group(recv_req)
         return success, message

+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.worker.destroy_weights_update_group(recv_req)
+        return success, message
+
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        success, message = self.worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        success, message = self.worker.send_weights_to_remote_instance(recv_req)
+        return success, message
+
     def update_weights_from_distributed(
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
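The overlap worker's inline future_token_ids_* bookkeeping moves into a FutureMap helper (sglang/srt/managers/overlap_utils.py, not shown in this diff). Below is a rough sketch of how it plausibly works, reconstructed from the call sites above and from the deleted inline code; the ring-buffer sizing and negative-index encoding follow the old implementation and are assumptions about the new one:

import torch


class FutureMap:
    """Maps negative placeholder ids to real token ids across overlapped batches."""

    def __init__(self, max_running_requests: int, device: str):
        self.future_ct = 0
        # The old inline code used 3x max_running_requests for the rolling
        # counter limit and a 5x buffer so in-flight windows never collide.
        self.future_limit = max_running_requests * 3
        self.token_ids_buf = torch.empty(
            (max_running_requests * 5,), dtype=torch.int64, device=device
        )

    def update_ct(self, bs: int) -> int:
        # Reserve a window of size bs; return the counter before advancing.
        cur = self.future_ct
        self.future_ct = (cur + bs) % self.future_limit
        return cur

    def update_next_future(self, ct: int, bs: int) -> torch.Tensor:
        # Futures are encoded as negative indices so they are distinguishable
        # from real token ids in the next batch's input_ids.
        return torch.arange(
            -(ct + 1),
            -(ct + 1 + bs),
            -1,
            dtype=torch.int64,
            device=self.token_ids_buf.device,
        )

    def store_to_map(self, ct: int, bs: int, next_token_ids: torch.Tensor):
        # Publish the sampled token ids into the reserved window.
        self.token_ids_buf[ct + 1 : ct + bs + 1] = next_token_ids

    def resolve_future(self, model_worker_batch):
        # Replace negative placeholders in place with the token ids produced
        # by the previous overlapped batch, as the deleted torch.compile'd
        # resolve_future_token_ids did.
        input_ids = model_worker_batch.input_ids
        input_ids[:] = torch.where(
            input_ids < 0,
            self.token_ids_buf[torch.clamp(-input_ids, min=0)],
            input_ids,
        )

The negative placeholders let the scheduler enqueue the next batch before sampling of the current one finishes; resolve_future swaps them for real token ids on the forward thread.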
sglang/srt/managers/utils.py

@@ -2,11 +2,10 @@ from __future__ import annotations

 import logging
 import multiprocessing as mp
-from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

 if TYPE_CHECKING:
@@ -97,46 +96,3 @@ def get_logprob_from_pp_outputs(
     ]

     return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
-
-
-class DPBalanceMeta:
-    """
-    This class will be use in scheduler and dp controller
-    """
-
-    def __init__(self, num_workers: int):
-        self.num_workers = num_workers
-        self._manager = mp.Manager()
-        self.mutex = self._manager.Lock()
-
-        init_local_tokens = [0] * self.num_workers
-        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
-
-        self.shared_state = self._manager.Namespace()
-        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
-        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
-
-    def destructor(self):
-        # we must destructor this class manually
-        self._manager.shutdown()
-
-    def get_shared_onfly(self) -> List[Dict[int, int]]:
-        return [dict(d) for d in self.shared_state.onfly_info]
-
-    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
-        self.shared_state.onfly_info = data
-
-    def get_shared_local_tokens(self) -> List[int]:
-        return list(self.shared_state.local_tokens)
-
-    def set_shared_local_tokens(self, data: List[int]):
-        self.shared_state.local_tokens = data
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        del state["_manager"]
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        self._manager = None
sglang/srt/mem_cache/allocator.py

@@ -27,7 +27,7 @@ import triton
 import triton.language as tl

 from sglang.srt.mem_cache.memory_pool import SWAKVPool
-from sglang.srt.utils import get_bool_env_var, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache
@@ -294,7 +294,6 @@ def alloc_extend_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
     max_num_extend_tokens: tl.constexpr,
@@ -323,13 +322,6 @@
     sum_num_new_pages = tl.sum(num_new_pages)
     new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
-            tl.int64
-        )
-        tl.store(ret_values, merged_value)
-
     # Part 1: fill the old partial page
     last_loc = tl.load(last_loc_ptr + pid)
     num_part1 = (
@@ -381,7 +373,6 @@ def alloc_decode_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
 ):
@@ -404,10 +395,6 @@
     sum_num_new_pages = tl.sum(num_new_pages)
     new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        tl.store(ret_values, sum_num_new_pages)
-
     if num_page_start_loc_self == 0:
         last_loc = tl.load(last_loc_ptr + pid)
         tl.store(out_indices + pid, last_loc + 1)
@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
         self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()

@@ -468,7 +454,9 @@
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
     ):
@@ -497,7 +485,6 @@
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
             self.seen_max_num_extend_tokens_next_power_of_2,
@@ -506,8 +493,11 @@
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        merged_value = self.ret_values.item()
-        num_new_pages = merged_value >> 32
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            prefix_lens=prefix_lens_cpu,
+        )
         if num_new_pages > len(self.free_pages):
             return None

@@ -517,6 +507,7 @@
     def alloc_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
     ):
         if self.debug_mode:
@@ -534,7 +525,6 @@
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
         )
@@ -542,7 +532,11 @@
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        num_new_pages = self.ret_values.item()
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            decode=True,
+        )
         if num_new_pages > len(self.free_pages):
             return None
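Both allocators now derive the new-page count on the CPU via get_num_new_pages (imported from sglang.srt.utils) instead of reading back a scalar that the Triton kernels stored in ret_values, which removes a device-to-host synchronization from the allocation hot path. The helper's actual implementation is not part of this diff; the sketch below is a plausible reconstruction of the page arithmetic it must perform, mirroring what the deleted ret_values path computed, with an assumed signature matching the call sites above:

from typing import Optional

import torch


def get_num_new_pages(
    seq_lens: torch.Tensor,
    page_size: int,
    prefix_lens: Optional[torch.Tensor] = None,
    decode: bool = False,
) -> int:
    if decode:
        # Decode appends exactly one token per request, so the "before"
        # state is seq_len - 1.
        prefix_lens = seq_lens - 1
    # Pages occupied after vs. before the allocation, per request.
    pages_after = (seq_lens + page_size - 1) // page_size
    pages_before = (prefix_lens + page_size - 1) // page_size
    return int((pages_after - pages_before).sum().item())

For decode this means a new page is charged only when the appended token starts a fresh page, e.g. with page_size=16 a request growing from 16 to 17 tokens needs one new page, while 15 to 16 needs none.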