sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,217 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from typing import TYPE_CHECKING
5
+
6
+ from sglang.srt.disaggregation.utils import DisaggregationMode
7
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
8
+ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
9
+ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
10
+
11
+ if TYPE_CHECKING:
12
+ from sglang.srt.managers.scheduler import Scheduler
13
+
14
+
15
class SchedulerRuntimeCheckerMixin:
    """Idle-time self checks mixed into ``Scheduler``.

    Every method expects ``self`` to be a ``Scheduler`` that provides the token
    and request pools, the tree cache and the metrics state read below.
    """

    def _check_hybrid_memory(self: Scheduler):
        """Check the full/SWA token pools of a hybrid model for leaks.

        Returns:
            ``(memory_leak, token_msg)``: ``memory_leak`` is True when either
            the full or the SWA pool still reports used tokens; ``token_msg``
            lists the pool sizes for the error message.
        """
        (
            full_num_used,
            swa_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        ) = self._get_swa_token_info()
        memory_leak = full_num_used != 0 or swa_num_used != 0
        token_msg = (
            f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_mamba_memory(self: Scheduler):
        """Check the full-attention and mamba pools used with ``MambaRadixCache``.

        Returns:
            ``(memory_leak, token_msg)``: ``memory_leak`` is True when the used
            count of either pool differs from the size the tree cache reports
            as protected for that pool.
        """
        (
            full_num_used,
            mamba_num_used,
            _,
            _,
            full_available_size,
            full_evictable_size,
            mamba_available_size,
            mamba_evictable_size,
        ) = self._get_mamba_token_info()
        memory_leak = (
            full_num_used != self.tree_cache.full_protected_size()
            or mamba_num_used != self.tree_cache.mamba_protected_size()
        )
        token_msg = (
            f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
            f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
        )
        return memory_leak, token_msg

    def _check_radix_cache_memory(self: Scheduler):
        """Check the token pool backed by a regular radix cache for leaks.

        Returns:
            ``(memory_leak, token_msg)``: ``memory_leak`` is True unless every
            token slot not protected by the tree cache is either available or
            evictable.
        """
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        # The same invariant is used whether or not hierarchical caching is on:
        # available + evictable + protected must cover the whole pool.
        memory_leak = (available_size + evictable_size) != (
            self.max_total_num_tokens - protected_size
        )
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
        return memory_leak, token_msg

    def _check_runtime_mem_leak(self: Scheduler):
        """Verify token accounting while requests may still be in flight.

        Tokens held by requests of the last batch (and, after an extend step,
        by the still-running decode batch) that are not yet counted by the
        prefix are added to the available/evictable/protected totals; the sum
        must equal ``max_total_num_tokens``. Does nothing if there is no last
        batch.

        Raises:
            AssertionError: if the accounting does not add up. Raised
                explicitly (not via ``assert``) so the check also runs under
                ``python -O``.
        """
        current_batch: ScheduleBatch = self.last_batch

        if current_batch is None:
            return

        def _prefix_len(req) -> int:
            # Length of the request's cached prefix; 0 when none is recorded.
            return len(req.prefix_indices) if req.prefix_indices is not None else 0

        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()

        extend_size = 0
        is_decode = current_batch.forward_mode.is_decode()
        for req in current_batch.reqs:
            if is_decode:
                if req.finished():
                    unreleased_len = 1
                else:
                    seq_len = len(req.origin_input_ids) + len(req.output_ids)
                    unreleased_len = seq_len - _prefix_len(req)
            else:
                fill_len = len(req.fill_ids) if req.fill_ids is not None else 0
                unreleased_len = fill_len - _prefix_len(req)
            extend_size += unreleased_len

        running_batch = self.running_batch
        if (
            current_batch.forward_mode.is_extend()
            and running_batch is not None
            and not running_batch.is_empty()
            and running_batch.forward_mode.is_decode()
        ):
            for req in running_batch.reqs:
                if req.finished():
                    # Finished requests contribute nothing here.
                    continue
                seq_len = len(req.origin_input_ids) + len(req.output_ids)
                # NOTE(review): the -1 presumably excludes the newest output
                # token, which has no KV slot yet — confirm.
                extend_size += seq_len - _prefix_len(req) - 1

        total_tokens = available_size + evictable_size + protected_size + extend_size

        if total_tokens != self.max_total_num_tokens:
            raise AssertionError(
                f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
            )

    def _check_req_pool(self: Scheduler):
        """Check that every request slot has been returned to the free list.

        In DECODE disaggregation mode, ``pre_alloc_size`` is added to the
        expected total.

        Raises:
            ValueError: if the number of free slots differs from the expected
                total; the message reports both numbers that were compared.
        """
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        available_size = len(self.req_to_token_pool.free_slots)
        if available_size != req_total_size:
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={available_size}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

    def check_memory(self: Scheduler):
        """Verify that no token-pool or request-pool memory leaked while idle.

        Picks the token-pool check matching the cache layout (hybrid SWA,
        hybrid GDN with a ``MambaRadixCache``, or a regular radix cache), then
        checks the request pool. When metrics are enabled and ``time.perf_counter()``
        is more than 30 past ``metrics_collector.last_log_time``, also refreshes
        the idle-time stats, logs them and publishes KV events.

        Raises:
            ValueError: if a token-pool or request-pool leak is detected.
        """
        if self.is_hybrid:
            memory_leak, token_msg = self._check_hybrid_memory()
        elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
            memory_leak, token_msg = self._check_mamba_memory()
        else:
            memory_leak, token_msg = self._check_radix_cache_memory()

        if memory_leak:
            msg = f"token_to_kv_pool_allocator memory leak detected! {token_msg}"
            raise ValueError(msg)

        self._check_req_pool()

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            elif self.is_hybrid_gdn:
                (
                    num_used,
                    _,
                    token_usage,
                    _,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_mamba_token_info()
            else:
                num_used, token_usage, _, _ = self._get_token_info()
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            # Nothing is generated while idle.
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                self.stats.num_prefill_prealloc_queue_reqs = len(
                    self.disagg_prefill_bootstrap_queue.queue
                )
                self.stats.num_prefill_inflight_queue_reqs = len(
                    self.disagg_prefill_inflight_queue
                )
            if self.disaggregation_mode == DisaggregationMode.DECODE:
                self.stats.num_decode_prealloc_queue_reqs = len(
                    self.disagg_decode_prealloc_queue.queue
                )
                self.stats.num_decode_transfer_queue_reqs = len(
                    self.disagg_decode_transfer_queue.queue
                )
            self.metrics_collector.log_stats(self.stats)
            self._publish_kv_events()

    def check_tree_cache(self: Scheduler):
        """Run the tree cache's ``sanity_check`` for SWA / mamba radix caches."""
        if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
            self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
        ):
            self.tree_cache.sanity_check()

    def self_check_during_idle(self: Scheduler):
        """Run the idle-time checks, reset ``new_token_ratio`` and maybe sleep.

        Calls ``check_memory`` and ``check_tree_cache``, restores
        ``new_token_ratio`` to ``init_new_token_ratio``, then calls
        ``maybe_sleep_on_idle``.
        """
        self.check_memory()
        self.check_tree_cache()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()
@@ -1,10 +1,14 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
- from typing import Tuple
4
+ from typing import TYPE_CHECKING, Tuple
3
5
 
4
6
  import torch
5
7
 
6
8
  from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
7
9
  from sglang.srt.managers.io_struct import (
10
+ DestroyWeightsUpdateGroupReqInput,
11
+ DestroyWeightsUpdateGroupReqOutput,
8
12
  GetWeightsByNameReqInput,
9
13
  GetWeightsByNameReqOutput,
10
14
  InitWeightsUpdateGroupReqInput,
@@ -17,10 +21,15 @@ from sglang.srt.managers.io_struct import (
17
21
  UpdateWeightFromDiskReqOutput,
18
22
  UpdateWeightsFromDistributedReqInput,
19
23
  UpdateWeightsFromDistributedReqOutput,
24
+ UpdateWeightsFromIPCReqInput,
25
+ UpdateWeightsFromIPCReqOutput,
20
26
  UpdateWeightsFromTensorReqInput,
21
27
  UpdateWeightsFromTensorReqOutput,
22
28
  )
23
29
 
30
+ if TYPE_CHECKING:
31
+ from sglang.srt.managers.scheduler import Scheduler
32
+
24
33
  logger = logging.getLogger(__name__)
25
34
 
26
35
 
@@ -41,6 +50,11 @@ class SchedulerUpdateWeightsMixin:
41
50
  success, message = self.tp_worker.init_weights_update_group(recv_req)
42
51
  return InitWeightsUpdateGroupReqOutput(success, message)
43
52
 
53
+ def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
54
+ """Destroy the online model parameter update group."""
55
+ success, message = self.tp_worker.destroy_weights_update_group(recv_req)
56
+ return DestroyWeightsUpdateGroupReqOutput(success, message)
57
+
44
58
  def update_weights_from_distributed(
45
59
  self,
46
60
  recv_req: UpdateWeightsFromDistributedReqInput,
@@ -68,11 +82,25 @@ class SchedulerUpdateWeightsMixin:
68
82
  torch.distributed.barrier(group=self.tp_cpu_group)
69
83
  return UpdateWeightsFromTensorReqOutput(success, message)
70
84
 
85
+ def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
86
+ """Update the online model parameter from IPC for checkpoint-engine integration."""
87
+ success, message = self.tp_worker.update_weights_from_ipc(recv_req)
88
+ if success:
89
+ if recv_req.flush_cache:
90
+ flush_cache_success = self.flush_cache()
91
+ assert flush_cache_success, "Cache flush failed after updating weights"
92
+ else:
93
+ logger.error(message)
94
+ torch.distributed.barrier(group=self.tp_cpu_group)
95
+ return UpdateWeightsFromIPCReqOutput(success, message)
96
+
71
97
  def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
72
98
  parameter = self.tp_worker.get_weights_by_name(recv_req)
73
99
  return GetWeightsByNameReqOutput(parameter)
74
100
 
75
- def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
101
+ def release_memory_occupation(
102
+ self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
103
+ ):
76
104
  tags = recv_req.tags
77
105
 
78
106
  if tags is None or len(tags) == 0:
@@ -87,14 +115,16 @@ class SchedulerUpdateWeightsMixin:
87
115
 
88
116
  if GPU_MEMORY_TYPE_WEIGHTS in tags:
89
117
  self.stashed_model_static_state = _export_static_state(
90
- self.tp_worker.worker.model_runner.model
118
+ self.tp_worker.model_runner.model
91
119
  )
92
120
  torch.distributed.barrier(self.tp_cpu_group)
93
121
  self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
94
122
 
95
123
  return ReleaseMemoryOccupationReqOutput()
96
124
 
97
- def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
125
+ def resume_memory_occupation(
126
+ self: Scheduler, recv_req: ResumeMemoryOccupationReqInput
127
+ ):
98
128
  tags = recv_req.tags
99
129
 
100
130
  if tags is None or len(tags) == 0:
@@ -107,7 +137,7 @@ class SchedulerUpdateWeightsMixin:
107
137
  self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
108
138
  torch.distributed.barrier(self.tp_cpu_group)
109
139
  _import_static_state(
110
- self.tp_worker.worker.model_runner.model,
140
+ self.tp_worker.model_runner.model,
111
141
  self.stashed_model_static_state,
112
142
  )
113
143
  del self.stashed_model_static_state
@@ -117,24 +147,20 @@ class SchedulerUpdateWeightsMixin:
117
147
 
118
148
  return ResumeMemoryOccupationReqOutput()
119
149
 
120
- def save_remote_model(self, params):
150
+ def save_remote_model(self: Scheduler, params):
121
151
  url = params["url"]
122
152
 
123
- worker = self.tp_worker.worker
124
- worker.model_runner.save_remote_model(url)
153
+ self.tp_worker.model_runner.save_remote_model(url)
125
154
 
126
155
  if self.draft_worker is not None:
127
156
  draft_url = params.get("draft_url", None)
128
157
  assert (
129
158
  draft_url is not None
130
159
  ), "draft_url must be provided when draft model is enabled"
131
- draft_worker = self.draft_worker.worker
132
- draft_worker.model_runner.save_remote_model(draft_url)
133
-
134
- def save_sharded_model(self, params):
135
- worker = self.tp_worker.worker
160
+ self.draft_worker.model_runner.save_remote_model(draft_url)
136
161
 
137
- worker.model_runner.save_sharded_model(
162
+ def save_sharded_model(self: Scheduler, params):
163
+ self.tp_worker.model_runner.save_sharded_model(
138
164
  path=params["path"],
139
165
  pattern=params["pattern"],
140
166
  max_size=params["max_size"],
@@ -3,8 +3,8 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  import copy
5
5
  import logging
6
- import os
7
6
  import time
7
+ import uuid
8
8
  from collections import deque
9
9
  from typing import (
10
10
  TYPE_CHECKING,
@@ -24,8 +24,12 @@ import zmq
24
24
  from sglang.srt.managers.io_struct import (
25
25
  ClearHiCacheReqInput,
26
26
  ClearHiCacheReqOutput,
27
+ CloseSessionReqInput,
28
+ DestroyWeightsUpdateGroupReqInput,
29
+ DestroyWeightsUpdateGroupReqOutput,
27
30
  ExpertDistributionReq,
28
31
  ExpertDistributionReqOutput,
32
+ ExpertDistributionReqType,
29
33
  FlushCacheReqInput,
30
34
  FlushCacheReqOutput,
31
35
  GetInternalStateReq,
@@ -40,8 +44,8 @@ from sglang.srt.managers.io_struct import (
40
44
  InitWeightsUpdateGroupReqOutput,
41
45
  LoadLoRAAdapterReqInput,
42
46
  LoadLoRAAdapterReqOutput,
43
- LoRAUpdateResult,
44
- MultiTokenizerWrapper,
47
+ LoRAUpdateOutput,
48
+ OpenSessionReqInput,
45
49
  ProfileReq,
46
50
  ProfileReqOutput,
47
51
  ProfileReqType,
@@ -59,6 +63,8 @@ from sglang.srt.managers.io_struct import (
59
63
  UnloadLoRAAdapterReqOutput,
60
64
  UpdateWeightsFromDistributedReqInput,
61
65
  UpdateWeightsFromDistributedReqOutput,
66
+ UpdateWeightsFromIPCReqInput,
67
+ UpdateWeightsFromIPCReqOutput,
62
68
  UpdateWeightsFromTensorReqInput,
63
69
  UpdateWeightsFromTensorReqOutput,
64
70
  )
@@ -77,8 +83,6 @@ logger = logging.getLogger(__name__)
77
83
  class _Communicator(Generic[T]):
78
84
  """Note: The communicator now only run up to 1 in-flight request at any time."""
79
85
 
80
- enable_multi_tokenizer = False
81
-
82
86
  def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
83
87
  self._sender = sender
84
88
  self._fan_out = fan_out
@@ -98,8 +102,6 @@ class _Communicator(Generic[T]):
98
102
  assert self._result_values is None
99
103
 
100
104
  if obj:
101
- if _Communicator.enable_multi_tokenizer:
102
- obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
103
105
  self._sender.send_pyobj(obj)
104
106
 
105
107
  self._result_event = asyncio.Event()
@@ -120,8 +122,6 @@ class _Communicator(Generic[T]):
120
122
  self._result_event = asyncio.Event()
121
123
 
122
124
  if obj:
123
- if _Communicator.enable_multi_tokenizer:
124
- obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
125
125
  self._sender.send_pyobj(obj)
126
126
 
127
127
  await self._result_event.wait()
@@ -140,6 +140,13 @@ class _Communicator(Generic[T]):
140
140
  if len(self._result_values) == self._fan_out:
141
141
  self._result_event.set()
142
142
 
143
+ @staticmethod
144
+ def merge_results(results):
145
+ all_success = all([r.success for r in results])
146
+ all_message = [r.message for r in results]
147
+ all_message = " | ".join(all_message)
148
+ return all_success, all_message
149
+
143
150
 
144
151
  class TokenizerCommunicatorMixin:
145
152
  """Mixin class for TokenizerManager to handle communication with the scheduler."""
@@ -149,6 +156,9 @@ class TokenizerCommunicatorMixin:
149
156
  self.init_weights_update_group_communicator = _Communicator(
150
157
  self.send_to_scheduler, server_args.dp_size
151
158
  )
159
+ self.destroy_weights_update_group_communicator = _Communicator(
160
+ self.send_to_scheduler, server_args.dp_size
161
+ )
152
162
  self.update_weights_from_distributed_communicator = _Communicator(
153
163
  self.send_to_scheduler, server_args.dp_size
154
164
  )
@@ -161,6 +171,9 @@ class TokenizerCommunicatorMixin:
161
171
  self.update_weights_from_tensor_communicator = _Communicator(
162
172
  self.send_to_scheduler, server_args.dp_size
163
173
  )
174
+ self.update_weights_from_ipc_communicator = _Communicator(
175
+ self.send_to_scheduler, server_args.dp_size
176
+ )
164
177
  self.get_weights_by_name_communicator = _Communicator(
165
178
  self.send_to_scheduler, server_args.dp_size
166
179
  )
@@ -207,6 +220,10 @@ class TokenizerCommunicatorMixin:
207
220
  InitWeightsUpdateGroupReqOutput,
208
221
  self.init_weights_update_group_communicator.handle_recv,
209
222
  ),
223
+ (
224
+ DestroyWeightsUpdateGroupReqOutput,
225
+ self.destroy_weights_update_group_communicator.handle_recv,
226
+ ),
210
227
  (
211
228
  UpdateWeightsFromDistributedReqOutput,
212
229
  self.update_weights_from_distributed_communicator.handle_recv,
@@ -223,6 +240,10 @@ class TokenizerCommunicatorMixin:
223
240
  UpdateWeightsFromTensorReqOutput,
224
241
  self.update_weights_from_tensor_communicator.handle_recv,
225
242
  ),
243
+ (
244
+ UpdateWeightsFromIPCReqOutput,
245
+ self.update_weights_from_ipc_communicator.handle_recv,
246
+ ),
226
247
  (
227
248
  GetWeightsByNameReqOutput,
228
249
  self.get_weights_by_name_communicator.handle_recv,
@@ -264,7 +285,7 @@ class TokenizerCommunicatorMixin:
264
285
  self.expert_distribution_communicator.handle_recv,
265
286
  ),
266
287
  (
267
- LoRAUpdateResult,
288
+ LoRAUpdateOutput,
268
289
  self.update_lora_adapter_communicator.handle_recv,
269
290
  ),
270
291
  (
@@ -293,6 +314,7 @@ class TokenizerCommunicatorMixin:
293
314
  with_stack: Optional[bool] = None,
294
315
  record_shapes: Optional[bool] = None,
295
316
  profile_by_stage: bool = False,
317
+ merge_profiles: bool = False,
296
318
  ):
297
319
  self.auto_create_handle_loop()
298
320
  env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
@@ -307,6 +329,7 @@ class TokenizerCommunicatorMixin:
307
329
  record_shapes=record_shapes,
308
330
  profile_by_stage=profile_by_stage,
309
331
  profile_id=str(time.time()),
332
+ merge_profiles=merge_profiles,
310
333
  )
311
334
  return await self._execute_profile(req)
312
335
 
@@ -323,15 +346,18 @@ class TokenizerCommunicatorMixin:
323
346
 
324
347
  async def start_expert_distribution_record(self: TokenizerManager):
325
348
  self.auto_create_handle_loop()
326
- await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
349
+ req = ExpertDistributionReq(action=ExpertDistributionReqType.START_RECORD)
350
+ await self.expert_distribution_communicator(req)
327
351
 
328
352
  async def stop_expert_distribution_record(self: TokenizerManager):
329
353
  self.auto_create_handle_loop()
330
- await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
354
+ req = ExpertDistributionReq(action=ExpertDistributionReqType.STOP_RECORD)
355
+ await self.expert_distribution_communicator(req)
331
356
 
332
357
  async def dump_expert_distribution_record(self: TokenizerManager):
333
358
  self.auto_create_handle_loop()
334
- await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
359
+ req = ExpertDistributionReq(action=ExpertDistributionReqType.DUMP_RECORD)
360
+ await self.expert_distribution_communicator(req)
335
361
 
336
362
  async def init_weights_update_group(
337
363
  self: TokenizerManager,
@@ -340,10 +366,24 @@ class TokenizerCommunicatorMixin:
340
366
  ) -> Tuple[bool, str]:
341
367
  self.auto_create_handle_loop()
342
368
  assert (
343
- self.server_args.dp_size == 1
344
- ), "dp_size must be 1 for init parameter update group"
345
- result = (await self.init_weights_update_group_communicator(obj))[0]
346
- return result.success, result.message
369
+ self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
370
+ ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
371
+
372
+ results = await self.init_weights_update_group_communicator(obj)
373
+ return _Communicator.merge_results(results)
374
+
375
+ async def destroy_weights_update_group(
376
+ self,
377
+ obj: DestroyWeightsUpdateGroupReqInput,
378
+ request: Optional[fastapi.Request] = None,
379
+ ) -> Tuple[bool, str]:
380
+ self.auto_create_handle_loop()
381
+ assert (
382
+ self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
383
+ ), "dp_size must be 1 or dp attention must be enabled for destroy parameter update group"
384
+
385
+ results = await self.destroy_weights_update_group_communicator(obj)
386
+ return _Communicator.merge_results(results)
347
387
 
348
388
  async def update_weights_from_distributed(
349
389
  self: TokenizerManager,
@@ -361,8 +401,8 @@ class TokenizerCommunicatorMixin:
361
401
  # This means that weight sync
362
402
  # cannot run while requests are in progress.
363
403
  async with self.model_update_lock.writer_lock:
364
- result = (await self.update_weights_from_distributed_communicator(obj))[0]
365
- return result.success, result.message
404
+ results = await self.update_weights_from_distributed_communicator(obj)
405
+ return _Communicator.merge_results(results)
366
406
 
367
407
  async def init_weights_send_group_for_remote_instance(
368
408
  self,
@@ -411,6 +451,28 @@ class TokenizerCommunicatorMixin:
411
451
  result = (await self.update_weights_from_tensor_communicator(obj))[0]
412
452
  return result.success, result.message
413
453
 
454
+ async def update_weights_from_ipc(
455
+ self,
456
+ obj: UpdateWeightsFromIPCReqInput,
457
+ request: Optional[fastapi.Request] = None,
458
+ ) -> Tuple[bool, str]:
459
+ """Update weights via IPC for checkpoint-engine integration."""
460
+ self.auto_create_handle_loop()
461
+ try:
462
+ # For now, we only support single data parallel instance
463
+ assert (
464
+ self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
465
+ ), "dp_size must be 1 or dp attention must be enabled for update weights from IPC"
466
+ logger.info("Starting IPC weight update")
467
+ # This means that weight sync cannot run while requests are in progress.
468
+ async with self.model_update_lock.writer_lock:
469
+ result = (await self.update_weights_from_ipc_communicator(obj))[0]
470
+ return result.success, result.message
471
+ except Exception as e:
472
+ error_msg = f"IPC weight update failed: {str(e)}"
473
+ logger.error(error_msg)
474
+ return False, error_msg
475
+
414
476
  async def load_lora_adapter(
415
477
  self: TokenizerManager,
416
478
  obj: LoadLoRAAdapterReqInput,
@@ -567,3 +629,63 @@ class TokenizerCommunicatorMixin:
567
629
  async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
568
630
  req = GetLoadReqInput()
569
631
  return await self.get_load_communicator(req)
632
+
633
+ async def open_session(
634
+ self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
635
+ ):
636
+ self.auto_create_handle_loop()
637
+
638
+ if obj.session_id is None:
639
+ obj.session_id = uuid.uuid4().hex
640
+ elif obj.session_id in self.session_futures:
641
+ return None
642
+
643
+ self.send_to_scheduler.send_pyobj(obj)
644
+
645
+ self.session_futures[obj.session_id] = asyncio.Future()
646
+ session_id = await self.session_futures[obj.session_id]
647
+ del self.session_futures[obj.session_id]
648
+ return session_id
649
+
650
+ async def close_session(
651
+ self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
652
+ ):
653
+ await self.send_to_scheduler.send_pyobj(obj)
654
+
655
+ def get_log_request_metadata(self):
656
+ max_length = None
657
+ skip_names = None
658
+ out_skip_names = None
659
+ if self.log_requests:
660
+ if self.log_requests_level == 0:
661
+ max_length = 1 << 30
662
+ skip_names = {
663
+ "text",
664
+ "input_ids",
665
+ "input_embeds",
666
+ "image_data",
667
+ "audio_data",
668
+ "lora_path",
669
+ "sampling_params",
670
+ }
671
+ out_skip_names = {"text", "output_ids", "embedding"}
672
+ elif self.log_requests_level == 1:
673
+ max_length = 1 << 30
674
+ skip_names = {
675
+ "text",
676
+ "input_ids",
677
+ "input_embeds",
678
+ "image_data",
679
+ "audio_data",
680
+ "lora_path",
681
+ }
682
+ out_skip_names = {"text", "output_ids", "embedding"}
683
+ elif self.log_requests_level == 2:
684
+ max_length = 2048
685
+ elif self.log_requests_level == 3:
686
+ max_length = 1 << 30
687
+ else:
688
+ raise ValueError(
689
+ f"Invalid --log-requests-level: {self.log_requests_level=}"
690
+ )
691
+ return max_length, skip_names, out_skip_names