sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/check_env.py CHANGED
@@ -47,7 +47,7 @@ PACKAGE_LIST = [
47
47
  "tiktoken",
48
48
  "anthropic",
49
49
  "litellm",
50
- "decord",
50
+ "decord2",
51
51
  ]
52
52
 
53
53
 
sglang/compile_deep_gemm.py CHANGED
@@ -19,6 +19,7 @@ import requests
19
19
 
20
20
  from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
21
21
  from sglang.srt.entrypoints.http_server import launch_server
22
+ from sglang.srt.environ import envs
22
23
  from sglang.srt.managers.io_struct import GenerateReqInput
23
24
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
24
25
  from sglang.srt.server_args import ServerArgs
@@ -28,9 +29,9 @@ from sglang.srt.warmup import warmup
28
29
  multiprocessing.set_start_method("spawn", force=True)
29
30
 
30
31
  # Reduce warning
31
- os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
32
+ envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)
32
33
  # Force enable deep gemm
33
- os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
34
+ envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)
34
35
  # Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
35
36
  os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
36
37
 
@@ -141,6 +142,9 @@ def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
141
142
  server_args.enable_torch_compile = False
142
143
  print(f"Disable CUDA Graph and Torch Compile to save time...")
143
144
 
145
+ server_args.load_format = "dummy"
146
+ print(f"Set load format to dummy to save time...")
147
+
144
148
  # Set watchdog timeout to compile_args.timeout because compilation will take a long time
145
149
  server_args.watchdog_timeout = compile_args.timeout
146
150
  server_args.warmups = "compile-deep-gemm"
sglang/global_config.py CHANGED
@@ -1,14 +1,11 @@
1
1
  """Global configurations"""
2
2
 
3
- import os
3
+ # FIXME: deprecate this file and move all usage to sglang.srt.environ or sglang.__init__.py
4
4
 
5
5
 
6
6
  class GlobalConfig:
7
7
  """
8
8
  Store some global constants.
9
-
10
- See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
11
- many global runtime arguments as well.
12
9
  """
13
10
 
14
11
  def __init__(self):
@@ -20,27 +17,6 @@ class GlobalConfig:
20
17
  # Default backend of the language
21
18
  self.default_backend = None
22
19
 
23
- # Runtime constants: New generation token ratio estimation
24
- self.default_init_new_token_ratio = float(
25
- os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
26
- )
27
- self.default_min_new_token_ratio_factor = float(
28
- os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
29
- )
30
- self.default_new_token_ratio_decay_steps = float(
31
- os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
32
- )
33
- self.torch_empty_cache_interval = float(
34
- os.environ.get(
35
- "SGLANG_EMPTY_CACHE_INTERVAL", -1
36
- ) # in seconds. Set if you observe high memory accumulation over a long serving period.
37
- )
38
- # Runtime constants: others
39
- self.retract_decode_steps = 20
40
- self.flashinfer_workspace_size = int(
41
- os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
42
- )
43
-
44
20
  # Output tokenization configs
45
21
  self.skip_special_tokens_in_output = True
46
22
  self.spaces_between_special_tokens_in_out = True
sglang/lang/api.py CHANGED
@@ -79,6 +79,7 @@ def gen(
79
79
  n: Optional[int] = None,
80
80
  stop: Optional[Union[str, List[str]]] = None,
81
81
  stop_token_ids: Optional[List[int]] = None,
82
+ stop_regex: Optional[Union[str, List[str]]] = None,
82
83
  temperature: Optional[float] = None,
83
84
  top_p: Optional[float] = None,
84
85
  top_k: Optional[int] = None,
@@ -120,6 +121,7 @@ def gen(
120
121
  n,
121
122
  stop,
122
123
  stop_token_ids,
124
+ stop_regex,
123
125
  temperature,
124
126
  top_p,
125
127
  top_k,
@@ -143,6 +145,7 @@ def gen_int(
143
145
  n: Optional[int] = None,
144
146
  stop: Optional[Union[str, List[str]]] = None,
145
147
  stop_token_ids: Optional[List[int]] = None,
148
+ stop_regex: Optional[Union[str, List[str]]] = None,
146
149
  temperature: Optional[float] = None,
147
150
  top_p: Optional[float] = None,
148
151
  top_k: Optional[int] = None,
@@ -162,6 +165,7 @@ def gen_int(
162
165
  n,
163
166
  stop,
164
167
  stop_token_ids,
168
+ stop_regex,
165
169
  temperature,
166
170
  top_p,
167
171
  top_k,
@@ -184,6 +188,7 @@ def gen_string(
184
188
  n: Optional[int] = None,
185
189
  stop: Optional[Union[str, List[str]]] = None,
186
190
  stop_token_ids: Optional[List[int]] = None,
191
+ stop_regex: Optional[Union[str, List[str]]] = None,
187
192
  temperature: Optional[float] = None,
188
193
  top_p: Optional[float] = None,
189
194
  top_k: Optional[int] = None,
@@ -203,6 +208,7 @@ def gen_string(
203
208
  n,
204
209
  stop,
205
210
  stop_token_ids,
211
+ stop_regex,
206
212
  temperature,
207
213
  top_p,
208
214
  top_k,
sglang/lang/interpreter.py CHANGED
@@ -792,6 +792,7 @@ class StreamExecutor:
792
792
  "n",
793
793
  "stop",
794
794
  "stop_token_ids",
795
+ "stop_regex",
795
796
  "temperature",
796
797
  "top_p",
797
798
  "top_k",
sglang/lang/ir.py CHANGED
@@ -21,6 +21,7 @@ class SglSamplingParams:
21
21
  n: int = 1
22
22
  stop: Union[str, List[str]] = ()
23
23
  stop_token_ids: Optional[List[int]] = ()
24
+ stop_regex: Optional[Union[str, List[str]]] = ()
24
25
  temperature: float = 1.0
25
26
  top_p: float = 1.0
26
27
  top_k: int = -1 # -1 means disable
@@ -45,6 +46,7 @@ class SglSamplingParams:
45
46
  self.n,
46
47
  self.stop,
47
48
  self.stop_token_ids,
49
+ self.stop_regex,
48
50
  self.temperature,
49
51
  self.top_p,
50
52
  self.top_k,
@@ -123,6 +125,7 @@ class SglSamplingParams:
123
125
  "n": self.n,
124
126
  "stop": self.stop,
125
127
  "stop_token_ids": self.stop_token_ids,
128
+ "stop_regex": self.stop_regex,
126
129
  "temperature": self.temperature,
127
130
  "top_p": self.top_p,
128
131
  "top_k": self.top_k,
@@ -161,6 +164,7 @@ class SglFunction:
161
164
  n: int = 1,
162
165
  stop: Optional[Union[str, List[str]]] = None,
163
166
  stop_token_ids: Optional[List[int]] = None,
167
+ stop_regex: Optional[Union[str, List[str]]] = None,
164
168
  temperature: float = 1.0,
165
169
  top_p: float = 1.0,
166
170
  top_k: int = -1,
@@ -184,12 +188,15 @@ class SglFunction:
184
188
  stop = []
185
189
  if stop_token_ids is None:
186
190
  stop_token_ids = []
191
+ if stop_regex is None:
192
+ stop_regex = []
187
193
 
188
194
  default_sampling_para = SglSamplingParams(
189
195
  max_new_tokens=max_new_tokens,
190
196
  n=n,
191
197
  stop=stop,
192
198
  stop_token_ids=stop_token_ids,
199
+ stop_regex=stop_regex,
193
200
  temperature=temperature,
194
201
  top_p=top_p,
195
202
  top_k=top_k,
@@ -221,6 +228,7 @@ class SglFunction:
221
228
  n: int = 1,
222
229
  stop: Optional[Union[str, List[str]]] = None,
223
230
  stop_token_ids: Optional[List[int]] = None,
231
+ stop_regex: Optional[Union[str, List[str]]] = None,
224
232
  temperature: float = 1.0,
225
233
  top_p: float = 1.0,
226
234
  top_k: int = -1,
@@ -243,6 +251,8 @@ class SglFunction:
243
251
  stop = []
244
252
  if stop_token_ids is None:
245
253
  stop_token_ids = []
254
+ if stop_regex is None:
255
+ stop_regex = []
246
256
 
247
257
  assert isinstance(batch_kwargs, (list, tuple))
248
258
  if len(batch_kwargs) == 0:
@@ -267,6 +277,7 @@ class SglFunction:
267
277
  n=n,
268
278
  stop=stop,
269
279
  stop_token_ids=stop_token_ids,
280
+ stop_regex=stop_regex,
270
281
  temperature=temperature,
271
282
  top_p=top_p,
272
283
  top_k=top_k,
@@ -451,6 +462,7 @@ class SglGen(SglExpr):
451
462
  n: Optional[int] = None,
452
463
  stop: Optional[Union[str, List[str]]] = None,
453
464
  stop_token_ids: Optional[List[int]] = None,
465
+ stop_regex: Optional[Union[str, List[str]]] = None,
454
466
  temperature: Optional[float] = None,
455
467
  top_p: Optional[float] = None,
456
468
  top_k: Optional[int] = None,
@@ -474,6 +486,7 @@ class SglGen(SglExpr):
474
486
  min_new_tokens=min_new_tokens,
475
487
  n=n,
476
488
  stop=stop,
489
+ stop_regex=stop_regex,
477
490
  stop_token_ids=stop_token_ids,
478
491
  temperature=temperature,
479
492
  top_p=top_p,
sglang/launch_server.py CHANGED
@@ -1,30 +1,23 @@
1
1
  """Launch the inference server."""
2
2
 
3
+ import asyncio
3
4
  import os
4
5
  import sys
5
6
 
6
- from sglang.srt.entrypoints.http_server import launch_server
7
7
  from sglang.srt.server_args import prepare_server_args
8
8
  from sglang.srt.utils import kill_process_tree
9
9
 
10
- MOVE_ENVS_WARN = """
11
- ########################################################################
12
- # For contributors and developers: #
13
- # Please move environment variable definitions to sglang.srt.environ #
14
- # using the following pattern: #
15
- # SGLANG_XXX = EnvBool(False) #
16
- # #
17
- ########################################################################
18
- """
19
-
20
10
  if __name__ == "__main__":
21
11
  server_args = prepare_server_args(sys.argv[1:])
22
12
 
23
- from sglang.srt.server_args import print_deprecated_warning
13
+ try:
14
+ if server_args.grpc_mode:
15
+ from sglang.srt.entrypoints.grpc_server import serve_grpc
24
16
 
25
- print_deprecated_warning(MOVE_ENVS_WARN)
17
+ asyncio.run(serve_grpc(server_args))
18
+ else:
19
+ from sglang.srt.entrypoints.http_server import launch_server
26
20
 
27
- try:
28
- launch_server(server_args)
21
+ launch_server(server_args)
29
22
  finally:
30
23
  kill_process_tree(os.getpid(), include_parent=False)
sglang/profiler.py CHANGED
@@ -25,6 +25,7 @@ def _run_profile(
25
25
  output_dir: Optional[str] = None,
26
26
  profile_name: Optional[str] = None,
27
27
  profile_by_stage: bool = False,
28
+ merge_profiles: bool = False,
28
29
  ) -> str:
29
30
  if output_dir is None:
30
31
  output_dir = PROFILER_DIR
@@ -60,6 +61,7 @@ def _run_profile(
60
61
  "num_steps": str(num_steps),
61
62
  "activities": activities,
62
63
  "profile_by_stage": profile_by_stage,
64
+ "merge_profiles": merge_profiles,
63
65
  }
64
66
 
65
67
  response = requests.post(url=url + "/start_profile", json=json_data)
@@ -76,10 +78,17 @@ def run_profile(
76
78
  output_dir: Optional[str] = None,
77
79
  profile_name: Optional[str] = None,
78
80
  profile_by_stage: bool = False,
81
+ merge_profiles: bool = False,
79
82
  ):
80
83
  # step based profile will self terminate on num_steps constraints
81
84
  link = _run_profile(
82
- url, num_steps, activities, output_dir, profile_name, profile_by_stage
85
+ url,
86
+ num_steps,
87
+ activities,
88
+ output_dir,
89
+ profile_name,
90
+ profile_by_stage,
91
+ merge_profiles,
83
92
  )
84
93
  return link
85
94
 
@@ -145,6 +154,13 @@ if __name__ == "__main__":
145
154
  default=False,
146
155
  help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
147
156
  )
157
+ parser.add_argument(
158
+ "--merge-profiles",
159
+ action=argparse.BooleanOptionalAction,
160
+ type=bool,
161
+ default=False,
162
+ help="Whether to merge profiles from all ranks into a single trace file",
163
+ )
148
164
 
149
165
  args = parser.parse_args()
150
166
  activities = []
@@ -163,4 +179,5 @@ if __name__ == "__main__":
163
179
  args.output_dir,
164
180
  args.profile_name,
165
181
  args.profile_by_stage,
182
+ args.merge_profiles,
166
183
  )
sglang/srt/_custom_ops.py CHANGED
@@ -15,7 +15,7 @@ if not is_hpu():
15
15
  # ROCm does not use vllm custom allreduce
16
16
  if use_vllm_custom_allreduce and not is_hip():
17
17
  try:
18
- import vllm._C
18
+ import vllm._C # noqa: F401
19
19
  except ImportError as e:
20
20
  logger.warning("Failed to import from vllm._C with %r", e)
21
21
  else:
@@ -77,8 +77,6 @@ def matmul_kernel_persistent(
77
77
  k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
78
78
  num_tiles = num_pid_m * num_pid_n
79
79
 
80
- tile_id_c = start_pid - NUM_SMS
81
-
82
80
  offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
83
81
  num_pid_in_group = GROUP_SIZE_M * num_pid_n
84
82
 
@@ -120,10 +118,6 @@ def matmul_kernel_persistent(
120
118
  )
121
119
  accumulator = tl.dot(a, b, accumulator)
122
120
 
123
- tile_id_c += NUM_SMS
124
- pid_m, pid_n = _compute_pid(
125
- tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
126
- )
127
121
  offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
128
122
  offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
129
123
  if C_LARGE:
@@ -137,6 +131,10 @@ def matmul_kernel_persistent(
137
131
  accumulator += bias
138
132
  if c_ptr.dtype.element_ty == tl.float8e4nv:
139
133
  c = accumulator.to(tl.float8e4nv)
134
+ elif c_ptr.dtype.element_ty == tl.bfloat16:
135
+ c = accumulator.to(tl.bfloat16)
136
+ elif c_ptr.dtype.element_ty == tl.float32:
137
+ c = accumulator.to(tl.float32)
140
138
  else:
141
139
  c = accumulator.to(tl.float16)
142
140
  tl.store(c_ptrs, c, mask=c_mask)
@@ -0,0 +1,142 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ """
15
+ Checkpoint-engine integration for SGLang.
16
+ This module provides weight update functionality via IPC for checkpoint-engine compatibility.
17
+ """
18
+ import logging
19
+ from typing import Callable, Dict, Optional
20
+
21
+ import torch
22
+ import zmq
23
+
24
+ try:
25
+ from checkpoint_engine.worker import update_weights_from_ipc
26
+ except ImportError:
27
+ raise ImportError(
28
+ "checkpoint-engine is not installed. "
29
+ "Please install it with: pip install sglang[checkpoint-engine]"
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class SGLangCheckpointEngineWorkerExtension:
36
+ """
37
+ Worker extension for SGLang to support checkpoint-engine IPC weight updates.
38
+ This class provides the interface needed for checkpoint-engine integration.
39
+ """
40
+
41
+ def __init__(self):
42
+ self._zmq_ctx: Optional[zmq.Context] = None
43
+
44
+ def get_device_uuid(self) -> str:
45
+ """Get the UUID of current device."""
46
+ # We need to implement this to get the device UUID
47
+ # This will be overridden when integrated into SGLang's worker
48
+ raise NotImplementedError(
49
+ "This method should be overridden by SGLang integration"
50
+ )
51
+
52
+ def get_device_id(self) -> int:
53
+ """Get the device ID."""
54
+ raise NotImplementedError(
55
+ "This method should be overridden by SGLang integration"
56
+ )
57
+
58
+ def get_model_loader(self) -> Callable:
59
+ """Get the model weight loader function."""
60
+ raise NotImplementedError(
61
+ "This method should be overridden by SGLang integration"
62
+ )
63
+
64
+ def get_post_hook(self) -> Optional[Callable]:
65
+ """Get the post-processing hook after weight loading."""
66
+ return None
67
+
68
+ def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
69
+ """
70
+ Update weights from IPC communication.
71
+ Args:
72
+ zmq_handles: Dict mapping device UUID to ZMQ socket path
73
+ """
74
+ if self._zmq_ctx is None:
75
+ self._zmq_ctx = zmq.Context()
76
+ device_uuid = self.get_device_uuid()
77
+ device_id = self.get_device_id()
78
+ if device_uuid not in zmq_handles:
79
+ raise ValueError(
80
+ f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
81
+ )
82
+ update_weights_from_ipc(
83
+ self._zmq_ctx,
84
+ zmq_handles[device_uuid],
85
+ device_id=device_id,
86
+ run=self.get_model_loader(),
87
+ post_hook=self.get_post_hook(),
88
+ )
89
+
90
+
91
+ class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
92
+ """
93
+ Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner.
94
+ This class provides the concrete implementation for checkpoint-engine IPC weight updates.
95
+ """
96
+
97
+ def __init__(self, model_runner):
98
+ super().__init__()
99
+ self.model_runner = model_runner
100
+
101
+ def get_device_uuid(self) -> str:
102
+ """Get the UUID of current device."""
103
+ # Get device UUID for current device
104
+ device_id = torch.cuda.current_device()
105
+ try:
106
+ return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
107
+ except AssertionError as e:
108
+ raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e
109
+
110
+ def get_device_id(self) -> int:
111
+ """Get the device ID."""
112
+ return torch.cuda.current_device()
113
+
114
+ def get_model_loader(self) -> Callable:
115
+ """Get the model weight loader function."""
116
+ return self.model_runner.model.load_weights
117
+
118
+ def get_post_hook(self) -> Optional[Callable]:
119
+ """Get the post-processing hook after weight loading."""
120
+
121
+ def post_hook():
122
+ # Perform post-processing after weight loading similar to DefaultModelLoader
123
+ try:
124
+ from sglang.srt.model_loader.loader import device_loading_context
125
+
126
+ # Process quantization methods after loading weights
127
+ for _, module in self.model_runner.model.named_modules():
128
+ quant_method = getattr(module, "quant_method", None)
129
+ if quant_method is not None:
130
+ # Move parameters to device if needed for quantization processing
131
+ target_device = torch.device(
132
+ "cuda", torch.cuda.current_device()
133
+ )
134
+ with device_loading_context(module, target_device):
135
+ quant_method.process_weights_after_loading(module)
136
+ # Call model-specific post-loading hook if available
137
+ if hasattr(self.model_runner.model, "post_load_weights"):
138
+ self.model_runner.model.post_load_weights()
139
+ except Exception as e:
140
+ logger.warning(f"Post-hook processing failed: {e}")
141
+
142
+ return post_hook