sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/{utils.py → utils/common.py} (item 426 above) +583 -113

@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 """Common utilities."""
-
 from __future__ import annotations

 import argparse
@@ -22,6 +21,7 @@ import ctypes
 import dataclasses
 import functools
 import importlib
+import inspect
 import io
 import ipaddress
 import itertools
@@ -42,6 +42,7 @@ import tempfile
 import threading
 import time
 import traceback
+import types
 import uuid
 import warnings
 from collections import OrderedDict, defaultdict
@@ -55,6 +56,7 @@ from json import JSONDecodeError
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -62,6 +64,7 @@ from typing import (
     List,
     Optional,
     Protocol,
+    Sequence,
     Set,
     Tuple,
     TypeVar,
@@ -69,6 +72,7 @@ from typing import (
 )

 import numpy as np
+import orjson
 import psutil
 import pybase64
 import requests
@@ -82,15 +86,17 @@ from packaging import version as pkg_version
 from PIL import Image
 from starlette.routing import Mount
 from torch import nn
-from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import FileCacheManager
 from typing_extensions import Literal

+from sglang.srt.environ import envs
 from sglang.srt.metrics.func_timer import enable_func_timer

+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizeMethodBase
+
 logger = logging.getLogger(__name__)

 show_time_cost = False
@@ -163,18 +169,44 @@ def _check(cc_major):
     ) >= (12, 3)


+@contextmanager
+def device_context(device: torch.device):
+    if device.type == "cpu" and is_cpu():
+        with torch.device("cpu"):
+            yield
+    else:
+        module = torch.get_device_module(device)
+        if module is not None:
+            with module.device(device.index):
+                yield
+        else:
+            raise ValueError(f"Unknown device module: {device}")
+
+
 is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)


+@lru_cache(maxsize=1)
 def is_blackwell():
     if not is_cuda():
         return False
     return torch.cuda.get_device_capability()[0] == 10


+@lru_cache(maxsize=1)
+def is_sm120_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
+    return (torch.cuda.get_device_capability(device)[0] == 12) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
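
The new device_context helper dispatches to the matching torch device module. A minimal usage sketch, assuming a CUDA build (the tensor shape is arbitrary):

    import torch

    dev = torch.device("cuda", 0)
    with device_context(dev):
        # Enters torch.cuda.device(0) under the hood, so allocations and
        # kernels inside the block target cuda:0.
        x = torch.empty(16, device=dev)
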
@@ -182,6 +214,8 @@ def is_sm100_supported(device=None) -> bool:

 @lru_cache(maxsize=1)
 def is_sm90_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 9) and (
         torch.version.cuda >= "12.3"
     )
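
Downstream code typically branches on these cached guards once at startup. A hypothetical sketch (the path names and priority order are illustrative, not taken from this diff):

    # @lru_cache above makes the underlying CUDA queries one-time costs.
    def pick_gemm_path() -> str:
        if is_sm120_supported():  # SM 12.x + CUDA >= 12.8
            return "sm120_path"
        if is_sm100_supported():  # SM 10.x (Blackwell) + CUDA >= 12.8
            return "sm100_path"
        if is_sm90_supported():   # SM 9.x (Hopper) + CUDA >= 12.3
            return "sm90_path"
        return "generic_path"     # older GPUs or non-CUDA builds
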
@@ -191,6 +225,7 @@ _warned_bool_env_var_keys = set()


 def get_bool_env_var(name: str, default: str = "false") -> bool:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name, default)
     value = value.lower()

@@ -208,6 +243,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:


 def get_int_env_var(name: str, default: int = 0) -> int:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name)
     if value is None or not value.strip():
         return default
@@ -222,7 +258,7 @@ def support_triton(backend: str) -> bool:


 try:
-    import sgl_kernel
+    import sgl_kernel  # noqa: F401

     is_intel_amx_backend_available = hasattr(
         torch.ops.sgl_kernel, "convert_weight_packed"
@@ -247,6 +283,14 @@ def use_intel_amx_backend(layer):
     return getattr(layer, "use_intel_amx_backend", False)


+def xpu_has_xmx_support():
+    # TODO: update with XPU capability query
+    if is_xpu():
+        # currently only PVC/LNL/BMG supports F64, so we only support these now
+        return torch.xpu.get_device_properties().has_fp64
+    return False
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
@@ -257,6 +301,17 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()


+def is_nvidia_cublas_cu12_version_ge_12_9():
+    """
+    temporary fix for issue #11272
+    """
+    try:
+        installed_version = version("nvidia-cublas-cu12")
+    except PackageNotFoundError:
+        return False
+    return pkg_version.parse(installed_version) >= pkg_version.parse("12.9")
+
+
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)

@@ -403,7 +458,15 @@ def get_available_gpu_memory(

         if empty_cache:
             torch.cuda.empty_cache()
-        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+        SHARED_SYSMEM_DEVICE_MEM_SMS = (87, 110, 121)  # Orin, Thor, Spark
+        if get_device_sm() in SHARED_SYSMEM_DEVICE_MEM_SMS:
+            # On these devices, which use sysmem as device mem, torch.cuda.mem_get_info()
+            # only reports "free" memory, which can be lower than what is actually
+            # available due to not including cache memory. So we use the system available
+            # memory metric instead.
+            free_gpu_memory = psutil.virtual_memory().available
+        else:
+            free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

     elif device == "xpu":
         num_gpus = torch.xpu.device_count()
@@ -447,6 +510,8 @@ def get_available_gpu_memory(
             f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
             "which may cause useless memory allocation for torch NPU context.",
         )
+        if empty_cache:
+            torch.npu.empty_cache()
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

     if distributed:
@@ -465,7 +530,7 @@ def is_pin_memory_available() -> bool:

 class LayerFn(Protocol):

-    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


 def make_layers(
@@ -475,13 +540,13 @@ def make_layers(
     pp_size: Optional[int] = None,
     prefix: str = "",
     return_tuple: bool = False,
-    offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[int, int, torch.nn.ModuleList]:
+    offloader_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.nn.Module, int, int]:
     """Make a list of layers with the given layer function"""
     # circular imports
     from sglang.srt.distributed import get_pp_indices
     from sglang.srt.layers.utils import PPMissingLayer
-    from sglang.srt.offloader import get_offloader
+    from sglang.srt.utils.offloader import get_offloader

     assert not pp_size or num_hidden_layers >= pp_size
     start_layer, end_layer = (
@@ -500,7 +565,7 @@ def make_layers(
                 layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
                 for idx in range(start_layer, end_layer)
             ),
-            **offloader_kwargs,
+            **(offloader_kwargs or {}),
         )
         + [
             PPMissingLayer(return_tuple=return_tuple)
@@ -512,6 +577,68 @@ def make_layers(
     return modules, start_layer, end_layer


+def make_layers_non_pp(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> torch.nn.ModuleList:
+    from sglang.srt.utils.offloader import get_offloader
+
+    layers = torch.nn.ModuleList(
+        get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(num_hidden_layers)
+            )
+        )
+    )
+    return layers
+
+
+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation (CMO).
+    Launch a new stream to prefetch the weight of matmul when running other
+    AIV or communication kernels, aiming to overlap the memory access time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1GB, a large value to prefetch entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
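
The CMO helpers form a prefetch/compute/join triple around a matmul. A minimal sketch, assuming an Ascend NPU build where torch_npu is importable (the layer and tensor names are placeholders):

    # Start an async weight prefetch on cmo_stream, overlap it with the
    # current layer's compute, then join before the weights are consumed.
    def forward_with_prefetch(layer, hidden_states, next_layer_weights):
        prepare_weight_cache(hidden_states, next_layer_weights)
        out = layer(hidden_states)  # runs concurrently with the prefetch
        wait_cmo_stream()           # current stream waits on cmo_stream
        return out
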
@@ -749,6 +876,25 @@ def load_image(
     return image, image_size


+def get_image_bytes(image_file: Union[str, bytes]):
+    if isinstance(image_file, bytes):
+        return image_file
+    elif image_file.startswith("http://") or image_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+        response = requests.get(image_file, timeout=timeout)
+        return response.content
+    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        with open(image_file, "rb") as f:
+            return f.read()
+    elif image_file.startswith("data:"):
+        image_file = image_file.split(",")[1]
+        return pybase64.b64decode(image_file, validate=True)
+    elif isinstance(image_file, str):
+        return pybase64.b64decode(image_file, validate=True)
+    else:
+        raise NotImplementedError(f"Invalid image: {image_file}")
+
+
 def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
     from decord import VideoReader, cpu, gpu
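
get_image_bytes normalizes every supported input form to raw bytes; its branches cover, in order (all values below are placeholders):

    raw = get_image_bytes(b"\x89PNG...")                       # bytes pass through
    raw = get_image_bytes("https://example.com/cat.png")       # fetched over HTTP(S)
    raw = get_image_bytes("/tmp/cat.jpg")                      # read from a local file
    raw = get_image_bytes("data:image/png;base64,iVBORw0...")  # data URI payload
    raw = get_image_bytes("iVBORw0KGgo...")                    # bare base64 fallback
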
@@ -781,7 +927,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
         vr = VideoReader(tmp_file.name, ctx=ctx)
     elif video_file.startswith("data:"):
         _, encoded = video_file.split(",", 1)
-        video_bytes = pybase64.b64decode(encoded)
+        video_bytes = pybase64.b64decode(encoded, validate=True)
         tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
         tmp_file.write(video_bytes)
         tmp_file.close()
@@ -789,7 +935,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     elif os.path.isfile(video_file):
         vr = VideoReader(video_file, ctx=ctx)
     else:
-        video_bytes = pybase64.b64decode(video_file)
+        video_bytes = pybase64.b64decode(video_file, validate=True)
         tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
         tmp_file.write(video_bytes)
         tmp_file.close()
@@ -804,6 +950,33 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
         os.unlink(tmp_file.name)


+def encode_video(video_path, frame_count_limit=None):
+    # Lazy import because decord is not available on some arm platforms.
+    from decord import VideoReader, cpu
+
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -911,7 +1084,7 @@ def monkey_patch_vllm_gguf_config():

     def get_quant_method_with_embedding_replaced(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
         if isinstance(layer, LinearBase):
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
@@ -946,6 +1119,13 @@ def set_ulimit(target_soft_limit=65535):
         logger.warning(f"Fail to set RLIMIT_STACK: {e}")


+def rank0_log(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(msg)
+
+
 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
@@ -980,7 +1160,7 @@ def configure_logger(server_args, prefix: str = ""):
                 f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
             )
         with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
-            custom_config = json.loads(file.read())
+            custom_config = orjson.loads(file.read())
         logging.config.dictConfig(custom_config)
         return
     format = f"[%(asctime)s{prefix}] %(message)s"
@@ -1159,8 +1339,46 @@ def pytorch_profile(name, func, *args, data_size=-1):


 def get_zmq_socket(
-    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-) -> zmq.Socket:
+    context: zmq.Context,
+    socket_type: zmq.SocketType,
+    endpoint: Optional[str] = None,
+    bind: bool = True,
+) -> Union[zmq.Socket, Tuple[int, zmq.Socket]]:
+    """Create and configure a ZeroMQ socket.
+
+    Args:
+        context: ZeroMQ context to create the socket from.
+        socket_type: Type of ZeroMQ socket to create.
+        endpoint: Optional endpoint to bind/connect to. If None, binds to a random TCP port.
+        bind: Whether to bind (True) or connect (False) to the endpoint. Ignored if endpoint is None.
+
+    Returns:
+        If endpoint is None: Tuple of (port, socket) where port is the randomly assigned TCP port.
+        If endpoint is provided: The configured ZeroMQ socket.
+    """
+    socket = context.socket(socket_type)
+
+    if endpoint is None:
+        # Bind to random TCP port
+        config_socket(socket, socket_type)
+        port = socket.bind_to_random_port("tcp://*")
+        return port, socket
+    else:
+        # Handle IPv6 if endpoint contains brackets
+        if endpoint.find("[") != -1:
+            socket.setsockopt(zmq.IPV6, 1)
+
+        config_socket(socket, socket_type)
+
+        if bind:
+            socket.bind(endpoint)
+        else:
+            socket.connect(endpoint)
+
+        return socket
+
+
+def config_socket(socket, socket_type: zmq.SocketType):
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -1169,10 +1387,6 @@ def get_zmq_socket(
     else:
         buf_size = -1

-    socket = context.socket(socket_type)
-    if endpoint.find("[") != -1:
-        socket.setsockopt(zmq.IPV6, 1)
-
     def set_send_opt():
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
@@ -1185,19 +1399,12 @@ def get_zmq_socket(
         set_send_opt()
     elif socket_type == zmq.PULL:
         set_recv_opt()
-    elif socket_type == zmq.DEALER:
+    elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP]:
         set_send_opt()
         set_recv_opt()
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")

-    if bind:
-        socket.bind(endpoint)
-    else:
-        socket.connect(endpoint)
-
-    return socket
-

 def dump_to_file(dirpath, name, value):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
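
Both calling conventions of the refactored get_zmq_socket in one sketch (the ipc endpoint is a placeholder):

    import zmq

    ctx = zmq.Context.instance()

    # Explicit endpoint: returns just the configured socket, as before.
    pull_sock = get_zmq_socket(ctx, zmq.PULL, "ipc:///tmp/sglang_demo", bind=True)

    # New random-port mode: omit the endpoint and receive (port, socket).
    port, push_sock = get_zmq_socket(ctx, zmq.PUSH)
    print(f"push socket bound to tcp://*:{port}")
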
@@ -1397,13 +1604,44 @@ def get_hpu_memory_capacity():

 def get_npu_memory_capacity():
     try:
-        import torch_npu
+        import torch_npu  # noqa: F401

         return torch.npu.mem_get_info()[1] // 1024 // 1024  # unit: MB
     except ImportError as e:
         raise ImportError("torch_npu is required when run on npu device.")


+def get_cpu_memory_capacity():
+    # Per-rank memory capacity cannot be determined for customized core settings
+    if os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", ""):
+        return None
+    n_numa_node: int = len(get_cpu_ids_by_node())
+    if n_numa_node == 0:
+        # Cannot determine NUMA config, fallback to total memory and avoid ZeroDivisionError.
+        return float(psutil.virtual_memory().total // (1 << 20))
+    try:
+        numa_mem_list = list()
+        file_prefix = "/sys/devices/system/node/"
+        for numa_id in range(n_numa_node):
+            file_meminfo = f"node{numa_id}/meminfo"
+            with open(os.path.join(file_prefix, file_meminfo), "r") as f:
+                # MemTotal info is at the 1st line
+                line = f.readline()
+                # Expected format: "Node 0 MemTotal: 100000000 kB"
+                parts = line.split()
+                if len(parts) >= 4 and parts[2] == "MemTotal:":
+                    numa_mem_list.append(int(parts[3]))
+                else:
+                    raise ValueError(f"Unexpected format in {file_meminfo}: {line}")
+        # Retrieved value in KB, need MB
+        numa_mem = float(min(numa_mem_list) // 1024)
+        return numa_mem
+    except (FileNotFoundError, ValueError, IndexError):
+        numa_mem = psutil.virtual_memory().total / n_numa_node
+        # Retrieved value in Byte, need MB
+        return float(numa_mem // (1 << 20))
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
@@ -1413,6 +1651,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_hpu_memory_capacity()
     elif device == "npu":
         gpu_mem = get_npu_memory_capacity()
+    elif device == "cpu":
+        gpu_mem = get_cpu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
@@ -1556,7 +1796,7 @@ def get_device(device_id: Optional[int] = None) -> str:

     if is_habana_available():
         try:
-            import habana_frameworks.torch.hpu
+            import habana_frameworks.torch.hpu  # noqa: F401

             if torch.hpu.is_available():
                 if device_id == None:
@@ -1586,7 +1826,7 @@ def get_device_count() -> int:

     if is_habana_available():
         try:
-            import habana_frameworks.torch.hpu
+            import habana_frameworks.torch.hpu  # noqa: F401

             if torch.hpu.is_available():
                 return torch.hpu.device_count()
@@ -1729,7 +1969,9 @@ def direct_register_custom_op(
         if fake_impl is not None:
             my_lib._register_fake(op_name, fake_impl)
     except RuntimeError as error:
-        if "Tried to register an operator" in str(e) and "multiple times" in str(e):
+        if "Tried to register an operator" in str(error) and "multiple times" in str(
+            error
+        ):
             # Silently ignore duplicate registration errors
             # This can happen in multi-engine scenarios
             pass
@@ -1742,6 +1984,7 @@ def direct_register_custom_op(


 def set_gpu_proc_affinity(
+    pp_size: int,
     tp_size: int,
     nnodes: int,
     gpu_id: int,
@@ -1750,7 +1993,8 @@ def set_gpu_proc_affinity(
     pid = os.getpid()
     p = psutil.Process(pid)

-    tp_size_per_node = tp_size // nnodes
+    nnodes_per_tp_group = max(nnodes // pp_size, 1)
+    tp_size_per_node = tp_size // nnodes_per_tp_group

     # total physical cores
     total_pcores = psutil.cpu_count(logical=False)
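
A worked example of the corrected affinity math: with tp_size=8, pp_size=2, and nnodes=4, each TP group spans nnodes // pp_size = 2 nodes, so each node hosts 8 // 2 = 4 TP ranks; the old formula tp_size // nnodes would have computed 2 and bound each rank to half as many cores:

    nnodes_per_tp_group = max(4 // 2, 1)         # 2 nodes per TP group
    tp_size_per_node = 8 // nnodes_per_tp_group  # 4 TP ranks per node
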
@@ -1862,7 +2106,7 @@ class MultiprocessingSerializer:

         if output_str:
             # Convert bytes to base64-encoded string
-            output = pybase64.b64encode(output).decode("utf-8")
+            pybase64.b64encode(output).decode("utf-8")

         return output

@@ -1951,50 +2195,6 @@ def set_uvicorn_logging_configs():
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"


-def get_ip() -> str:
-    # SGLANG_HOST_IP env can be ignore
-    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-    if host_ip:
-        return host_ip
-
-    # IP is not set, try to get it from the network interface
-
-    # try ipv4
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    try:
-        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try ipv6
-    try:
-        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        # Google's public DNS server, see
-        # https://developers.google.com/speed/public-dns/docs/using#addresses
-        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try using hostname
-    hostname = socket.gethostname()
-    try:
-        ip_addr = socket.gethostbyname(hostname)
-        warnings.warn("using local ip address: {}".format(ip_addr))
-        return ip_addr
-    except Exception:
-        pass
-
-    warnings.warn(
-        "Failed to get the IP address, using 0.0.0.0 by default."
-        "The value can be set by the environment variable"
-        " SGLANG_HOST_IP or HOST_IP.",
-        stacklevel=2,
-    )
-    return "0.0.0.0"
-
-
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:
@@ -2077,6 +2277,11 @@ def launch_dummy_health_check_server(host, port, enable_metrics):

     app = FastAPI()

+    @app.get("/ping")
+    async def ping():
+        """Could be used by the checkpoint-engine update script to confirm the server is up."""
+        return Response(status_code=200)
+
     @app.get("/health")
     async def health():
         """Check the health of the http server."""
@@ -2199,6 +2404,8 @@ def retry(
         try:
             return fn()
         except Exception as e:
+            traceback.print_exc()
+
             if try_index >= max_retry:
                 raise Exception(f"retry() exceed maximum number of retries.")

@@ -2212,11 +2419,30 @@ def retry(
             logger.warning(
                 f"retry() failed once ({try_index}th try, maximum {max_retry} retries). Will delay {delay:.2f}s and retry. Error: {e}"
             )
-            traceback.print_exc()

             time.sleep(delay)


+def has_hf_quant_config(model_path: str) -> bool:
+    """Check if the model path contains hf_quant_config.json file.
+
+    Args:
+        model_path: Path to the model, can be local path or remote URL.
+
+    Returns:
+        True if hf_quant_config.json exists, False otherwise.
+    """
+    if os.path.exists(os.path.join(model_path, "hf_quant_config.json")):
+        return True
+    try:
+        from huggingface_hub import HfApi
+
+        hf_api = HfApi()
+        return hf_api.file_exists(model_path, "hf_quant_config.json")
+    except Exception:
+        return False
+
+
 def flatten_nested_list(nested_list):
     if isinstance(nested_list, list):
         return [
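
has_hf_quant_config above checks the local path first, then falls back to a Hub lookup. A usage sketch (both model references are placeholders):

    has_hf_quant_config("/models/llama-fp8")  # local hf_quant_config.json?
    has_hf_quant_config("org/model-fp8")      # falls back to HfApi.file_exists()
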
@@ -2251,16 +2477,9 @@ def bind_or_assign(target, source):
     return source


-def get_local_ip_auto() -> str:
-    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
-    return (
-        get_local_ip_by_nic(interface)
-        if interface is not None
-        else get_local_ip_by_remote()
-    )
-
-
-def get_local_ip_by_nic(interface: str) -> str:
+def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
+    if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
+        return None
     try:
         import netifaces
     except ImportError as e:
@@ -2281,15 +2500,13 @@ def get_local_ip_by_nic(interface: str) -> str:
             if ip and not ip.startswith("fe80::") and ip != "::1":
                 return ip.split("%")[0]
     except (ValueError, OSError) as e:
-        raise ValueError(
-            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        logger.warning(
+            f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
         )
-
-    # Fallback
-    return get_local_ip_by_remote()
+    return None


-def get_local_ip_by_remote() -> str:
+def get_local_ip_by_remote() -> Optional[str]:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
@@ -2314,7 +2531,51 @@ def get_local_ip_by_remote() -> str:
         s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
     except Exception:
-        raise ValueError("Can not get local ip")
+        logger.warning("Can not get local ip by remote")
+        return None
+
+
+def get_local_ip_auto(fallback: str = None) -> str:
+    """
+    Automatically detect the local IP address using multiple fallback strategies.
+
+    This function attempts to obtain the local IP address through several methods.
+    If all methods fail, it returns the specified fallback value or raises an exception.
+
+    Args:
+        fallback (str, optional): Fallback IP address to return if all detection
+            methods fail. For server applications, explicitly set this to
+            "0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
+            Defaults to None.
+
+    Returns:
+        str: The detected local IP address, or the fallback value if detection fails.
+
+    Raises:
+        ValueError: If IP detection fails and no fallback value is provided.
+
+    Note:
+        The function tries detection methods in the following order:
+        1. Environment variables (SGLANG_HOST_IP / HOST_IP)
+        2. Network interface enumeration via get_local_ip_by_nic()
+        3. Remote connection method via get_local_ip_by_remote()
+    """
+    # Try environment variable
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+    logger.debug("get_ip failed")
+    # Fallback
+    if ip := get_local_ip_by_nic():
+        return ip
+    logger.debug("get_local_ip_by_nic failed")
+    # Fallback
+    if ip := get_local_ip_by_remote():
+        return ip
+    logger.debug("get_local_ip_by_remote failed")
+    if fallback:
+        return fallback
+    raise ValueError("Can not get local ip")


 def is_page_size_one(server_args):
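
The new fallback chain is environment override, NIC enumeration, then a remote-connect probe; callers that must always bind somewhere pass an explicit default (a sketch):

    # Without a fallback, get_local_ip_auto raises ValueError when all three
    # detection methods fail; servers can opt into binding all interfaces.
    host = get_local_ip_auto(fallback="0.0.0.0")
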
@@ -2339,6 +2600,7 @@ def is_fa3_default_architecture(hf_config):
         "Qwen2ForCausalLM",
         "Llama4ForConditionalGeneration",
         "LlamaForCausalLM",
+        "Olmo2ForCausalLM",
         "Gemma2ForCausalLM",
         "Gemma3ForConditionalGeneration",
         "Qwen3ForCausalLM",
@@ -2366,15 +2628,15 @@ class BumpAllocator:
 def log_info_on_rank0(logger, msg):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
-    if get_tensor_model_parallel_rank() == 0:
+    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
 
 
 def load_json_config(data: str):
     try:
-        return json.loads(data)
+        return orjson.loads(data)
     except JSONDecodeError:
-        return json.loads(Path(data).read_text())
+        return orjson.loads(Path(data).read_text())
 
 
 def dispose_tensor(x: torch.Tensor):
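
The json to orjson swap keeps load_json_config's two-step contract: parse the argument as inline JSON first, and only then treat it as a file path. orjson.loads accepts both str and bytes, and orjson's JSONDecodeError subclasses json.JSONDecodeError, so the existing except clause still fires. Illustrative calls (the path is a placeholder):

    cfg = load_json_config('{"chunked_prefill_size": 4096}')  # inline JSON
    cfg = load_json_config("/path/to/config.json")  # falls back to reading the file
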
@@ -2496,14 +2758,6 @@ def read_system_prompt_from_file(model_name: str) -> str:
     return ""
 
 
-def bind_or_assign(target, source):
-    if target is not None:
-        target.copy_(source)
-        return target
-    else:
-        return source
-
-
 def prepack_weight_if_needed(weight):
     if weight.device != torch.device("cpu"):
         return weight
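
bind_or_assign is deleted here only because it moved earlier in the file (it appears as context at new line 2477 in the first hunk). Its contract is unchanged: copy into a preallocated tensor when one is given, otherwise adopt the source. A minimal sketch, assuming torch is imported:

    buf = torch.zeros(4)
    src = torch.arange(4.0)

    out = bind_or_assign(buf, src)   # copies in place; out is buf
    assert out is buf and torch.equal(buf, src)

    out = bind_or_assign(None, src)  # no target: returns src itself
    assert out is src
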
@@ -2749,7 +3003,7 @@ def get_cpu_ids_by_node():
 def is_shm_available(dtype, world_size, local_size):
     return (
         cpu_has_amx_support()
-        and dtype in [torch.bfloat16, torch.float]
+        and dtype in [torch.bfloat16, torch.float16, torch.float]
         and world_size >= 1
         and world_size == local_size
     )
@@ -2800,10 +3054,6 @@ def lru_cache_frozenset(maxsize=128):
     return decorator
 
 
-def get_origin_rid(rid):
-    return rid.split("_", 1)[1] if "_" in rid else rid
-
-
 def apply_module_patch(target_module, target_function, wrappers):
     original_module, original_function = parse_module_path(
         target_module, target_function, False
@@ -3042,6 +3292,44 @@ def check_cuda_result(raw_output):
     return results
 
 
+def get_physical_device_id(pytorch_device_id: int) -> int:
+    """
+    Convert PyTorch logical device ID to physical device ID.
+    """
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    assert (
+        cuda_visible_devices is not None
+    ), "CUDA_VISIBLE_DEVICES should be set in a scheduler"
+    device_list = cuda_visible_devices.split(",")
+    assert (
+        len(device_list) == 1
+    ), "CUDA_VISIBLE_DEVICES should be set to a single device in a scheduler"
+    return int(device_list[0])
+
+
+def get_device_sm_nvidia_smi():
+    try:
+        # Run nvidia-smi command and capture output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+
+        # Get the first line of output (assuming at least one GPU exists)
+        compute_cap_str = result.stdout.strip().split("\n")[0]
+
+        # Convert string (e.g., "9.0") to tuple of integers (9, 0)
+        major, minor = map(int, compute_cap_str.split("."))
+        return (major, minor)
+
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
+        # Handle cases where nvidia-smi isn't available or output is unexpected
+        print(f"Error getting compute capability: {e}")
+        return (0, 0)  # Default/fallback value
+
+
 def numa_bind_to_node(node: int):
     libnuma = ctypes.CDLL("libnuma.so")
     if libnuma.numa_available() < 0:
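
Unlike torch.cuda.get_device_capability, the new get_device_sm_nvidia_smi shells out to nvidia-smi, so it can run before any CUDA context exists and degrades to (0, 0) when the tool is missing. A hedged usage sketch:

    major, minor = get_device_sm_nvidia_smi()
    if (major, minor) == (0, 0):
        print("nvidia-smi unavailable; assuming no NVIDIA GPU")
    elif (major, minor) >= (9, 0):
        print(f"SM {major}.{minor}: Hopper-class or newer")
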
@@ -3053,8 +3341,190 @@ def numa_bind_to_node(node: int):
 
 def json_list_type(value):
     try:
-        return json.loads(value)
+        return orjson.loads(value)
     except json.JSONDecodeError:
         raise argparse.ArgumentTypeError(
             f"Invalid JSON list: {value}. Please provide a valid JSON list."
         )
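
json_list_type is an argparse type= converter (it raises ArgumentTypeError on bad input), and the orjson swap is safe here because orjson.JSONDecodeError subclasses json.JSONDecodeError. Illustrative wiring with a hypothetical flag name:

    parser = argparse.ArgumentParser()
    parser.add_argument("--hosts", type=json_list_type, default=[])
    args = parser.parse_args(['--hosts', '["10.0.0.1", "10.0.0.2"]'])
    assert args.hosts == ["10.0.0.1", "10.0.0.2"]
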
+
+
+@contextmanager
+def maybe_reindex_device_id(gpu_id: int):
+
+    if envs.SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS.get() is False or not is_cuda_alike():
+        yield gpu_id
+        return
+
+    original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if original_cuda_visible_devices:
+        cuda_visible_devices = original_cuda_visible_devices.split(",")
+    else:
+        cuda_visible_devices = []
+
+    str_gpu_id = cuda_visible_devices[gpu_id] if cuda_visible_devices else str(gpu_id)
+    os.environ["CUDA_VISIBLE_DEVICES"] = str_gpu_id
+
+    logger.debug(f"Set CUDA_VISIBLE_DEVICES to {str_gpu_id}")
+
+    yield 0
+
+    if original_cuda_visible_devices:
+        os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+    else:
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+
+
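
Usage sketch for the new context manager: with SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS enabled on a CUDA-like device, it narrows CUDA_VISIBLE_DEVICES to the one selected GPU and yields logical index 0, restoring the environment on exit; otherwise it yields gpu_id unchanged. The torch call is illustrative:

    with maybe_reindex_device_id(gpu_id=3) as logical_id:
        # Reindexing on: CUDA_VISIBLE_DEVICES == "3" and logical_id == 0.
        # Reindexing off: logical_id == 3.
        device = torch.device(f"cuda:{logical_id}")
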
+def get_extend_input_len_swa_limit(
+    sliding_window_size: int, chunked_prefill_size: int, page_size: int
+) -> int:
+    # 1. a factor of 2x is because each prefill contains chunked_prefill_size tokens,
+    # and between prefills, we run swa_radix_cache.cache_unfinished_req(),
+    # so we unlock the previously locked nodes.
+    # 2. max is to handle the case that chunked_prefill_size is larger than sliding_window_size.
+    # in that case, each prefill contains chunked_prefill_size tokens,
+    # and we can only free out-of-sliding-window kv indices after each prefill.
+    # 3. page_size is because we want to have 1 token extra for generated tokens.
+    return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
+
+
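
A worked instance of the bound: with a 4096-token sliding window, 8192-token prefill chunks, and 64-token pages, the limit is 64 + 2 * max(4096, 8192) = 16448; max() covers chunks larger than the window, the 2x covers tokens still locked from the previous chunk, and the extra page leaves room for generated tokens.

    assert get_extend_input_len_swa_limit(4096, 8192, 64) == 64 + 2 * 8192 == 16448
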
+def get_num_new_pages(
+    seq_lens: torch.Tensor,
+    page_size: int,
+    prefix_lens: Optional[torch.Tensor] = None,
+    decode: bool = False,
+) -> torch.Tensor:
+    """
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use cpu tensors to avoid blocking kernel launch.
+    """
+    cpu_device = torch.device("cpu")
+    assert seq_lens.device == cpu_device
+
+    if prefix_lens is None or decode:
+        # NOTE: Special case for decode, where the prefix lens are `seq_lens - 1`.
+        assert decode
+        return (seq_lens % page_size == 1).int().sum().item()
+
+    assert prefix_lens.device == cpu_device
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (prefix_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+    sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
+    return sum_num_new_pages.item()
+
+
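
A small check of the page arithmetic on CPU tensors: with page_size=4, extending a sequence from 5 to 9 tokens goes from ceil(5/4)=2 pages to ceil(9/4)=3, i.e. one new page, while decode mode counts sequences whose new length lands on the first slot of a fresh page (seq_len % page_size == 1):

    prefix = torch.tensor([5, 8])  # pages before: [2, 2]
    seqs = torch.tensor([9, 8])    # pages after:  [3, 2]
    assert get_num_new_pages(seqs, page_size=4, prefix_lens=prefix) == 1

    # Decode: prefix lens are implicitly seq_lens - 1; only 9 % 4 == 1 opens a page.
    assert get_num_new_pages(torch.tensor([9, 8]), page_size=4, decode=True) == 1
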
+class CachedKernel:
+    """
+    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
+
+    This wrapper caches compiled Triton kernels based on keys extracted by a
+    user-provided key function to avoid redundant compilations.
+    """
+
+    def __init__(self, fn, key_fn=None):
+        self.fn = fn
+        assert isinstance(fn, triton.runtime.jit.JITFunction)
+
+        original_fn = fn.fn
+        self.signature = inspect.signature(original_fn)
+        self.param_names = tuple(self.signature.parameters.keys())
+        self.num_args = len(self.param_names)
+
+        # Check that no parameters have default values
+        for name, param in self.signature.parameters.items():
+            assert (
+                param.default is inspect.Parameter.empty
+            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
+
+        functools.update_wrapper(self, original_fn)
+        self.kernel_cache = {}
+
+        # Store the key function
+        self.key_fn = key_fn
+
+    def __getitem__(self, grid):
+        """
+        Index with grid to get a launcher function.
+        Returns a launcher that will handle caching based on the key function.
+        """
+        assert (
+            isinstance(grid, tuple) and len(grid) <= 3
+        ), "Grid must be a tuple with at most 3 dimensions."
+
+        # Normalize grid once
+        if len(grid) < 3:
+            grid = grid + (1,) * (3 - len(grid))
+
+        def launcher(*args, **kwargs):
+            cache_key = self.key_fn(args, kwargs)
+
+            cached_kernel = self.kernel_cache.get(cache_key)
+
+            if cached_kernel is None:
+                # First time: compile and cache the kernel
+                cached_kernel = self.fn[grid](*args, **kwargs)
+                self.kernel_cache[cache_key] = cached_kernel
+                return cached_kernel
+            else:
+                # Use cached kernel
+                all_args = self._build_args(args, kwargs)
+                cached_kernel[grid](*all_args)
+                return cached_kernel
+
+        return launcher
+
+    def _build_args(self, args, kwargs):
+        """
+        Build the complete argument list for kernel invocation.
+        """
+        complete_args = list(args)
+
+        for i in range(len(args), self.num_args):
+            name = self.param_names[i]
+            value = kwargs.get(name, inspect.Parameter.empty)
+            if value is not inspect.Parameter.empty:
+                complete_args.append(value)
+            else:
+                raise ValueError(f"Missing argument: {name}")
+
+        return complete_args
+
+    def _clear_cache(self):
+        """
+        Clear the kernel cache for testing purposes.
+        """
+        self.kernel_cache.clear()
+
+
+def cached_triton_kernel(key_fn=None):
+    """
+    Decorator that enables key-based caching for Triton kernels using a key function.
+
+    It essentially bypasses Triton's built-in caching mechanism, allowing users to
+    define their own caching strategy based on kernel parameters. This helps reduce
+    the heavy overheads of Triton kernel launch when the kernel specialization dispatch
+    is simple.
+
+    Usage:
+        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
+        @triton.jit
+        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+            ...
+
+        # Invoke normally
+        my_kernel[grid](x, y, BLOCK_SIZE=1024)
+
+    Args:
+        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
+            The key can be a single value or a tuple of values.
+
+    Returns:
+        A decorator that wraps the kernel with caching functionality.
+
+    Note: Kernels with default parameter values are not supported and will raise an assertion error.
+    """
+
+    def decorator(fn):
+        return CachedKernel(fn, key_fn)
+
+    return decorator
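
A runnable sketch of the decorator in use, assuming Triton and a CUDA device are available; the kernel and key function are illustrative, not from this diff. Keying on BLOCK_SIZE alone means all tensor shapes share one compiled kernel per block size, sidestepping Triton's heavier per-launch dispatch:

    import torch
    import triton
    import triton.language as tl

    @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get("BLOCK_SIZE", 1024))
    @triton.jit
    def _add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)

    x = torch.randn(4096, device="cuda")
    y = torch.randn_like(x)
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    _add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)  # compiles, caches under key 1024
    _add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)  # cache hit: relaunch only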