sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -15,20 +15,15 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import logging
18
- import threading
19
- from typing import TYPE_CHECKING, Optional, Tuple, Union
18
+ from abc import ABC, abstractmethod
19
+ from typing import TYPE_CHECKING, Optional
20
20
 
21
21
  import torch
22
22
 
23
23
  from sglang.srt.configs.model_config import ModelConfig
24
24
  from sglang.srt.distributed import get_pp_group, get_world_group
25
- from sglang.srt.hf_transformers_utils import (
26
- get_processor,
27
- get_tokenizer,
28
- get_tokenizer_from_processor,
29
- )
30
- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
31
25
  from sglang.srt.managers.io_struct import (
26
+ DestroyWeightsUpdateGroupReqInput,
32
27
  GetWeightsByNameReqInput,
33
28
  InitWeightsSendGroupForRemoteInstanceReqInput,
34
29
  InitWeightsUpdateGroupReqInput,
@@ -37,16 +32,23 @@ from sglang.srt.managers.io_struct import (
37
32
  UnloadLoRAAdapterReqInput,
38
33
  UpdateWeightFromDiskReqInput,
39
34
  UpdateWeightsFromDistributedReqInput,
35
+ UpdateWeightsFromIPCReqInput,
40
36
  UpdateWeightsFromTensorReqInput,
41
37
  )
42
- from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
38
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch
39
+ from sglang.srt.managers.scheduler import GenerationBatchResult
43
40
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
44
41
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
45
42
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
46
43
  from sglang.srt.model_executor.model_runner import ModelRunner
47
- from sglang.srt.patch_torch import monkey_patch_torch_reductions
48
44
  from sglang.srt.server_args import ServerArgs
49
45
  from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
46
+ from sglang.srt.utils.hf_transformers_utils import (
47
+ get_processor,
48
+ get_tokenizer,
49
+ get_tokenizer_from_processor,
50
+ )
51
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
50
52
 
51
53
  if TYPE_CHECKING:
52
54
  from sglang.srt.managers.cache_controller import LayerDoneCounter
@@ -54,7 +56,145 @@ if TYPE_CHECKING:
54
56
  logger = logging.getLogger(__name__)
55
57
 
56
58
 
57
- class TpModelWorker:
59
class BaseTpWorker(ABC):
    """Abstract base for tensor-parallel workers.

    Subclasses provide ``model_runner`` and ``forward_batch_generation``.
    Every other member defined here forwards to the object returned by
    ``model_runner``.
    """

    @abstractmethod
    def forward_batch_generation(self, forward_batch: ForwardBatch):
        """Run a generation forward pass; must be implemented by subclasses."""
        pass

    @property
    @abstractmethod
    def model_runner(self) -> ModelRunner:
        """The runner that the delegating members of this class call into."""
        pass

    @property
    def sliding_window_size(self) -> Optional[int]:
        """``sliding_window_size`` as reported by the model runner."""
        runner = self.model_runner
        return runner.sliding_window_size

    @property
    def is_hybrid(self) -> bool:
        """True when the runner's ``is_hybrid`` attribute is not ``None``."""
        hybrid_flag = self.model_runner.is_hybrid
        return hybrid_flag is not None

    def get_tokens_per_layer_info(self):
        """Return ``(full_max_total_num_tokens, swa_max_total_num_tokens)``."""
        full_tokens = self.model_runner.full_max_total_num_tokens
        swa_tokens = self.model_runner.swa_max_total_num_tokens
        return (full_tokens, swa_tokens)

    def get_pad_input_ids_func(self):
        """Return the model's ``pad_input_ids`` attribute, or ``None`` if absent."""
        model = self.model_runner.model
        return getattr(model, "pad_input_ids", None)

    def get_tp_group(self):
        """Return the runner's ``tp_group``."""
        return self.model_runner.tp_group

    def get_attention_tp_group(self):
        """Return the runner's ``attention_tp_group``."""
        return self.model_runner.attention_tp_group

    def get_attention_tp_cpu_group(self):
        """Return ``attention_tp_group.cpu_group``, or ``None`` if it has none."""
        attn_group = self.model_runner.attention_tp_group
        return getattr(attn_group, "cpu_group", None)

    def get_memory_pool(self):
        """Return ``(req_to_token_pool, token_to_kv_pool_allocator)``."""
        req_pool = self.model_runner.req_to_token_pool
        kv_allocator = self.model_runner.token_to_kv_pool_allocator
        return (req_pool, kv_allocator)

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """Forward a disk weight-update request; returns ``(success, message)``."""
        ok, msg = self.model_runner.update_weights_from_disk(
            recv_req.model_path, recv_req.load_format
        )
        return ok, msg

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Forward a weights-update-group init request; returns ``(success, message)``."""
        ok, msg = self.model_runner.init_weights_update_group(
            recv_req.master_address,
            recv_req.master_port,
            recv_req.rank_offset,
            recv_req.world_size,
            recv_req.group_name,
            recv_req.backend,
        )
        return ok, msg

    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
        """Forward a weights-update-group destroy request; returns ``(success, message)``."""
        ok, msg = self.model_runner.destroy_weights_update_group(
            recv_req.group_name,
        )
        return ok, msg

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        """Forward a remote-instance send-group init request; returns ``(success, message)``."""
        init_send_group = self.model_runner.init_weights_send_group_for_remote_instance
        ok, msg = init_send_group(
            recv_req.master_address,
            recv_req.ports,
            recv_req.group_rank,
            recv_req.world_size,
            recv_req.group_name,
            recv_req.backend,
        )
        return ok, msg

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        """Forward a send-weights-to-remote request; returns ``(success, message)``."""
        ok, msg = self.model_runner.send_weights_to_remote_instance(
            recv_req.master_address,
            recv_req.ports,
            recv_req.group_name,
        )
        return ok, msg

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
        """Forward a distributed weight-update request; returns ``(success, message)``."""
        ok, msg = self.model_runner.update_weights_from_distributed(
            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
        )
        return ok, msg

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Deserialize this rank's tensors and hand them to the model runner.

        Reads ``self.tp_rank``, which this base class does not define itself;
        the concrete worker has to provide it. Returns ``(success, message)``.
        """
        # The patch is applied before anything is deserialized.
        monkey_patch_torch_reductions()
        # Bind the target first so lookup/evaluation order matches a plain call.
        apply_update = self.model_runner.update_weights_from_tensor
        rank_payload = recv_req.serialized_named_tensors[self.tp_rank]
        ok, msg = apply_update(
            named_tensors=MultiprocessingSerializer.deserialize(rank_payload),
            load_format=recv_req.load_format,
        )
        return ok, msg

    def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
        """Update weights from IPC for checkpoint-engine integration."""
        ok, msg = self.model_runner.update_weights_from_ipc(recv_req)
        return ok, msg

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Return whatever the runner yields for ``(name, truncate_size)``."""
        return self.model_runner.get_weights_by_name(
            recv_req.name, recv_req.truncate_size
        )

    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
        """Load the adapter described by ``recv_req.to_ref()``; returns the runner's result."""
        adapter_ref = recv_req.to_ref()
        return self.model_runner.load_lora_adapter(adapter_ref)

    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
        """Unload the adapter described by ``recv_req.to_ref()``; returns the runner's result."""
        adapter_ref = recv_req.to_ref()
        return self.model_runner.unload_lora_adapter(adapter_ref)

    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
        """Ask the runner's ``lora_manager`` to validate ``lora_ids`` as one batch."""
        manager = self.model_runner.lora_manager
        return manager.validate_lora_batch(lora_ids)

    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
        """Run a forward pass and return the ``embeddings`` of its first output."""
        batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        # The second element of the forward result is not needed here.
        output, _ = self.model_runner.forward(batch)
        return output.embeddings
195
+
196
+
197
+ class TpModelWorker(BaseTpWorker):
58
198
  """A tensor parallel model worker."""
59
199
 
60
200
  def __init__(
@@ -90,10 +230,9 @@ class TpModelWorker:
90
230
  else server_args.speculative_draft_model_revision
91
231
  ),
92
232
  is_draft_model=is_draft_worker,
93
- tp_rank=tp_rank,
94
233
  )
95
234
 
96
- self.model_runner = ModelRunner(
235
+ self._model_runner = ModelRunner(
97
236
  model_config=self.model_config,
98
237
  mem_fraction_static=server_args.mem_fraction_static,
99
238
  gpu_id=gpu_id,
@@ -149,8 +288,8 @@ class TpModelWorker:
149
288
  assert self.max_running_requests > 0, "max_running_request is zero"
150
289
  self.max_queued_requests = server_args.max_queued_requests
151
290
  assert (
152
- self.max_queued_requests > 0
153
- ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
291
+ self.max_queued_requests is None or self.max_queued_requests >= 1
292
+ ), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
154
293
  self.max_req_len = min(
155
294
  self.model_config.context_len - 1,
156
295
  self.max_total_num_tokens - 1,
@@ -169,11 +308,13 @@ class TpModelWorker:
169
308
  )[0]
170
309
  set_random_seed(self.random_seed)
171
310
 
172
- # A reference make this class has the same member as TpModelWorkerClient
173
- self.worker = self
174
-
311
+ self.enable_overlap = not server_args.disable_overlap_schedule
175
312
  self.hicache_layer_transfer_counter = None
176
313
 
314
+ @property
315
+ def model_runner(self) -> ModelRunner:
316
+ return self._model_runner
317
+
177
318
  def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
178
319
  self.hicache_layer_transfer_counter = counter
179
320
 
@@ -191,56 +332,29 @@ class TpModelWorker:
191
332
  self.max_req_input_len,
192
333
  self.random_seed,
193
334
  self.device,
194
- global_server_args_dict,
195
335
  self.model_runner.req_to_token_pool.size,
196
336
  self.model_runner.req_to_token_pool.max_context_len,
197
337
  self.model_runner.token_to_kv_pool.size,
198
338
  )
199
339
 
200
- @property
201
- def sliding_window_size(self) -> Optional[int]:
202
- return self.model_runner.sliding_window_size
203
-
204
- @property
205
- def is_hybrid(self) -> bool:
206
- return self.model_runner.is_hybrid is not None
207
-
208
- def get_tokens_per_layer_info(self):
209
- return (
210
- self.model_runner.full_max_total_num_tokens,
211
- self.model_runner.swa_max_total_num_tokens,
212
- )
213
-
214
- def get_pad_input_ids_func(self):
215
- return getattr(self.model_runner.model, "pad_input_ids", None)
216
-
217
- def get_tp_group(self):
218
- return self.model_runner.tp_group
219
-
220
- def get_attention_tp_group(self):
221
- return self.model_runner.attention_tp_group
222
-
223
- def get_attention_tp_cpu_group(self):
224
- return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
225
-
226
- def get_memory_pool(self):
227
- return (
228
- self.model_runner.req_to_token_pool,
229
- self.model_runner.token_to_kv_pool_allocator,
230
- )
231
-
232
340
  def forward_batch_generation(
233
341
  self,
234
342
  model_worker_batch: ModelWorkerBatch,
235
- launch_done: Optional[threading.Event] = None,
236
- skip_sample: bool = False,
237
- ) -> Tuple[
238
- Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
239
- ]:
240
- # update the consumer index of hicache to the running batch
241
- self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
242
-
243
- forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
343
+ forward_batch: Optional[ForwardBatch] = None,
344
+ is_verify: bool = False,
345
+ skip_attn_backend_init=False,
346
+ ) -> GenerationBatchResult:
347
+ # FIXME(lsyin): maybe remove skip_attn_backend_init in forward_batch_generation,
348
+ # which requires preparing replay to always be in this function
349
+
350
+ if model_worker_batch is not None:
351
+ # update the consumer index of hicache to the running batch
352
+ self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
353
+
354
+ forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
355
+ else:
356
+ # FIXME(lsyin): unify the interface of forward_batch
357
+ assert forward_batch is not None
244
358
 
245
359
  pp_proxy_tensors = None
246
360
  if not self.pp_group.is_first_rank:
@@ -252,115 +366,62 @@ class TpModelWorker:
252
366
 
253
367
  if self.pp_group.is_last_rank:
254
368
  logits_output, can_run_cuda_graph = self.model_runner.forward(
255
- forward_batch, pp_proxy_tensors=pp_proxy_tensors
369
+ forward_batch,
370
+ pp_proxy_tensors=pp_proxy_tensors,
371
+ skip_attn_backend_init=skip_attn_backend_init,
372
+ )
373
+ batch_result = GenerationBatchResult(
374
+ logits_output=logits_output,
375
+ can_run_cuda_graph=can_run_cuda_graph,
256
376
  )
257
- if launch_done is not None:
258
- launch_done.set()
259
377
 
260
- if skip_sample:
261
- next_token_ids = None
262
- # For prefill-only requests, we still need to compute logprobs even when sampling is skipped
378
+ if is_verify:
379
+ # Skip sampling and return logits for target forward
380
+ return batch_result
381
+
382
+ if (
383
+ self.enable_overlap
384
+ and model_worker_batch.sampling_info.grammars is not None
385
+ ):
386
+
387
+ def sample_batch_func():
388
+ batch_result.next_token_ids = self.model_runner.sample(
389
+ logits_output, forward_batch
390
+ )
391
+ return batch_result
392
+
393
+ batch_result.delay_sample_func = sample_batch_func
394
+ return batch_result
395
+
396
+ if model_worker_batch.is_prefill_only:
397
+ # For prefill-only requests, create dummy token IDs on CPU
398
+ # The size should match the batch size (number of sequences), not total tokens
399
+ batch_result.next_token_ids = torch.zeros(
400
+ len(model_worker_batch.seq_lens),
401
+ dtype=torch.long,
402
+ device=model_worker_batch.input_ids.device,
403
+ )
263
404
  if (
264
- model_worker_batch.is_prefill_only
265
- and model_worker_batch.return_logprob
405
+ model_worker_batch.return_logprob
406
+ and logits_output.next_token_logits is not None
266
407
  ):
267
- # Compute logprobs without full sampling
408
+ # NOTE: Compute logprobs without full sampling
268
409
  self.model_runner.compute_logprobs_only(
269
410
  logits_output, model_worker_batch
270
411
  )
271
412
  else:
272
- next_token_ids = self.model_runner.sample(
273
- logits_output, model_worker_batch
413
+ batch_result.next_token_ids = self.model_runner.sample(
414
+ logits_output, forward_batch
274
415
  )
275
416
 
276
- return logits_output, next_token_ids, can_run_cuda_graph
417
+ return batch_result
277
418
  else:
278
419
  pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
279
420
  forward_batch,
280
421
  pp_proxy_tensors=pp_proxy_tensors,
422
+ skip_attn_backend_init=skip_attn_backend_init,
281
423
  )
282
- return pp_proxy_tensors.tensors, None, can_run_cuda_graph
283
-
284
- def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
285
- forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
286
- logits_output, _ = self.model_runner.forward(forward_batch)
287
- embeddings = logits_output.embeddings
288
- return embeddings
289
-
290
- def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
291
- success, message = self.model_runner.update_weights_from_disk(
292
- recv_req.model_path, recv_req.load_format
293
- )
294
- return success, message
295
-
296
- def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
297
- success, message = self.model_runner.init_weights_update_group(
298
- recv_req.master_address,
299
- recv_req.master_port,
300
- recv_req.rank_offset,
301
- recv_req.world_size,
302
- recv_req.group_name,
303
- recv_req.backend,
304
- )
305
- return success, message
306
-
307
- def init_weights_send_group_for_remote_instance(
308
- self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
309
- ):
310
- success, message = (
311
- self.model_runner.init_weights_send_group_for_remote_instance(
312
- recv_req.master_address,
313
- recv_req.ports,
314
- recv_req.group_rank,
315
- recv_req.world_size,
316
- recv_req.group_name,
317
- recv_req.backend,
424
+ return GenerationBatchResult(
425
+ pp_hidden_states_proxy_tensors=pp_proxy_tensors,
426
+ can_run_cuda_graph=can_run_cuda_graph,
318
427
  )
319
- )
320
- return success, message
321
-
322
- def send_weights_to_remote_instance(
323
- self, recv_req: SendWeightsToRemoteInstanceReqInput
324
- ):
325
- success, message = self.model_runner.send_weights_to_remote_instance(
326
- recv_req.master_address,
327
- recv_req.ports,
328
- recv_req.group_name,
329
- )
330
- return success, message
331
-
332
- def update_weights_from_distributed(
333
- self, recv_req: UpdateWeightsFromDistributedReqInput
334
- ):
335
- success, message = self.model_runner.update_weights_from_distributed(
336
- recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
337
- )
338
- return success, message
339
-
340
- def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
341
-
342
- monkey_patch_torch_reductions()
343
- success, message = self.model_runner.update_weights_from_tensor(
344
- named_tensors=MultiprocessingSerializer.deserialize(
345
- recv_req.serialized_named_tensors[self.tp_rank]
346
- ),
347
- load_format=recv_req.load_format,
348
- )
349
- return success, message
350
-
351
- def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
352
- parameter = self.model_runner.get_weights_by_name(
353
- recv_req.name, recv_req.truncate_size
354
- )
355
- return parameter
356
-
357
- def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
358
- result = self.model_runner.load_lora_adapter(recv_req.to_ref())
359
- return result
360
-
361
- def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
362
- result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
363
- return result
364
-
365
- def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
366
- return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
@@ -1,20 +1,95 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import dataclasses
3
4
  import logging
4
- import multiprocessing as mp
5
- from http import HTTPStatus
6
- from typing import TYPE_CHECKING, Dict, List, Optional
5
+ from typing import TYPE_CHECKING, List, Optional
6
+
7
+ import torch
7
8
 
8
9
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
9
- from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
10
+ from sglang.srt.managers.overlap_utils import FutureIndices
11
+ from sglang.srt.managers.schedule_batch import Req
10
12
  from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
11
13
 
12
14
  if TYPE_CHECKING:
13
15
  from sglang.srt.managers.scheduler import GenerationBatchResult
16
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
17
+
14
18
 
15
19
  logger = logging.getLogger(__name__)
16
20
 
17
21
 
22
@dataclasses.dataclass
class GenerationBatchResult:
    """Result of one generation forward pass.

    Every field defaults to ``None`` (``can_run_cuda_graph`` to ``False``) so a
    producer can fill in only what its code path computes.
    """

    logits_output: Optional[LogitsProcessorOutput] = None
    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
    next_token_ids: Optional[torch.Tensor] = None
    num_accepted_tokens: Optional[int] = None
    can_run_cuda_graph: bool = False

    # For output processing
    extend_input_len_per_req: Optional[List[int]] = None
    extend_logprob_start_len_per_req: Optional[List[int]] = None

    # For overlap scheduling
    copy_done: Optional[torch.cuda.Event] = None
    delay_sample_func: Optional[callable] = None
    future_indices: Optional[FutureIndices] = None

    # FIXME(lsyin): maybe move to a better place?
    # sync path: forward stream -> output processor
    accept_lens: Optional[torch.Tensor] = None
    allocate_lens: Optional[torch.Tensor] = None

    # relay path: forward stream -> next step forward
    next_draft_input: Optional[EagleDraftInput] = None

    def copy_to_cpu(self, return_logprob: bool = False):
        """Copy tensors to CPU in overlap scheduling.

        Only the tensors which are needed for processing results are copied,
        e.g., next_token_ids, logits outputs. Fields that are ``None`` are
        skipped, so a result that carries no logits or no sampled tokens can
        still be copied.

        Args:
            return_logprob: When true, ``next_token_logits`` and
                ``input_token_logprobs`` of ``logits_output`` are copied too.

        Note:
            ``copy_done`` must be set before calling; it is recorded
            unconditionally once the copies have been issued.
        """
        logits = self.logits_output
        # logits_output is Optional; guard it like the other optional tensors.
        if logits is not None:
            if return_logprob:
                if logits.next_token_logits is not None:
                    logits.next_token_logits = logits.next_token_logits.to(
                        "cpu", non_blocking=True
                    )
                if logits.input_token_logprobs is not None:
                    logits.input_token_logprobs = logits.input_token_logprobs.to(
                        "cpu", non_blocking=True
                    )
            if logits.hidden_states is not None:
                logits.hidden_states = logits.hidden_states.to(
                    "cpu", non_blocking=True
                )

        # next_token_ids stays None when no sampling happened for this result.
        if self.next_token_ids is not None:
            self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)

        if self.accept_lens is not None:
            self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)

        if self.allocate_lens is not None:
            self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)

        self.copy_done.record()

    @classmethod
    def from_pp_proxy(
        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
    ):
        """Build a result from pipeline-parallel proxy outputs.

        ``next_token_ids`` is required in ``next_pp_outputs``; the two
        ``extend_*_per_req`` entries are optional and default to ``None``.
        """
        # TODO(lsyin): refactor PP and avoid using dict
        proxy_dict = next_pp_outputs.tensors
        return cls(
            logits_output=logits_output,
            pp_hidden_states_proxy_tensors=None,
            next_token_ids=next_pp_outputs["next_token_ids"],
            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
            extend_logprob_start_len_per_req=proxy_dict.get(
                "extend_logprob_start_len_per_req", None
            ),
            can_run_cuda_graph=can_run_cuda_graph,
        )
91
+
92
+
18
93
  def validate_input_length(
19
94
  req: Req, max_req_input_len: int, allow_auto_truncate: bool
20
95
  ) -> Optional[str]:
@@ -97,46 +172,3 @@ def get_logprob_from_pp_outputs(
97
172
  ]
98
173
 
99
174
  return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
100
-
101
-
102
- class DPBalanceMeta:
103
- """
104
- This class will be use in scheduler and dp controller
105
- """
106
-
107
- def __init__(self, num_workers: int):
108
- self.num_workers = num_workers
109
- self._manager = mp.Manager()
110
- self.mutex = self._manager.Lock()
111
-
112
- init_local_tokens = [0] * self.num_workers
113
- init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
114
-
115
- self.shared_state = self._manager.Namespace()
116
- self.shared_state.local_tokens = self._manager.list(init_local_tokens)
117
- self.shared_state.onfly_info = self._manager.list(init_onfly_info)
118
-
119
- def destructor(self):
120
- # we must destructor this class manually
121
- self._manager.shutdown()
122
-
123
- def get_shared_onfly(self) -> List[Dict[int, int]]:
124
- return [dict(d) for d in self.shared_state.onfly_info]
125
-
126
- def set_shared_onfly_info(self, data: List[Dict[int, int]]):
127
- self.shared_state.onfly_info = data
128
-
129
- def get_shared_local_tokens(self) -> List[int]:
130
- return list(self.shared_state.local_tokens)
131
-
132
- def set_shared_local_tokens(self, data: List[int]):
133
- self.shared_state.local_tokens = data
134
-
135
- def __getstate__(self):
136
- state = self.__dict__.copy()
137
- del state["_manager"]
138
- return state
139
-
140
- def __setstate__(self, state):
141
- self.__dict__.update(state)
142
- self._manager = None