sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -15,12 +15,14 @@
 
 from __future__ import annotations
 
+import argparse
 import asyncio
 import builtins
 import ctypes
 import dataclasses
 import functools
 import importlib
+import inspect
 import io
 import ipaddress
 import itertools
@@ -81,11 +83,9 @@ from packaging import version as pkg_version
 from PIL import Image
 from starlette.routing import Mount
 from torch import nn
-from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import FileCacheManager
 from typing_extensions import Literal
 
 from sglang.srt.metrics.func_timer import enable_func_timer
@@ -166,6 +166,7 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)
 
 
+@lru_cache(maxsize=1)
 def is_blackwell():
     if not is_cuda():
         return False
@@ -174,6 +175,8 @@ def is_blackwell():
 
 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
@@ -181,6 +184,8 @@ def is_sm100_supported(device=None) -> bool:
 
 @lru_cache(maxsize=1)
 def is_sm90_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 9) and (
         torch.version.cuda >= "12.3"
     )
@@ -190,6 +195,7 @@ _warned_bool_env_var_keys = set()
 
 
 def get_bool_env_var(name: str, default: str = "false") -> bool:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name, default)
     value = value.lower()
 
@@ -207,6 +213,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
 
 
 def get_int_env_var(name: str, default: int = 0) -> int:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name)
     if value is None or not value.strip():
         return default
@@ -230,8 +237,16 @@ except:
     is_intel_amx_backend_available = False
 
 
+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+
+
 def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+    return is_amx_tile_supported and is_intel_amx_backend_available
 
 
 def use_intel_amx_backend(layer):
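
The pattern above is worth noting: the C-extension probe now runs once at import time, so cpu_has_amx_support() reduces to reading two module-level booleans, which (per the diff's own comment) is what lets torch.compile trace it. A minimal sketch of the same idiom, with a hypothetical feature name:

    import torch

    try:
        _HAS_TILE = torch._C._cpu._is_amx_tile_supported()  # probe once at import
    except Exception:
        _HAS_TILE = False  # older builds may lack the hook

    def has_tile_support() -> bool:
        # Plain global read; no C-extension call inside traced code.
        return _HAS_TILE
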
@@ -426,7 +441,9 @@ def get_available_gpu_memory(
 
     elif device == "cpu":
         # TODO: rename the variables in the current function to be not GPU specific
-        free_gpu_memory = psutil.virtual_memory().available
+        total_free_memory = psutil.virtual_memory().available
+        n_numa_node: int = len(get_cpu_ids_by_node())
+        free_gpu_memory = round(total_free_memory / n_numa_node, 3)
     elif device == "npu":
        num_gpus = torch.npu.device_count()
        assert gpu_id < num_gpus
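
Worked example for the new cpu branch (numbers illustrative): if psutil reports 512 GiB available and get_cpu_ids_by_node() detects 2 NUMA nodes, each rank now budgets 512 / 2 = 256 GiB, the per-node share it will be bound to, rather than claiming the whole machine.
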
@@ -454,7 +471,7 @@ def is_pin_memory_available() -> bool:
 
 class LayerFn(Protocol):
 
-    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...
 
 
 def make_layers(
@@ -465,7 +482,7 @@
     prefix: str = "",
     return_tuple: bool = False,
     offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[int, int, torch.nn.ModuleList]:
+) -> Tuple[torch.nn.Module, int, int]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices
@@ -501,6 +518,68 @@
     return modules, start_layer, end_layer
 
 
+def make_layers_non_pp(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> torch.nn.ModuleList:
+    from sglang.srt.offloader import get_offloader
+
+    layers = torch.nn.ModuleList(
+        get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(num_hidden_layers)
+            )
+        )
+    )
+    return layers
+
+
+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation(CMO).
+    Launch a new stream to prefetch the weight of matmul when running other
+    AIV or communication kernels, aiming to overlap the memory access time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1GB, a large value to prefetch entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
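
A minimal usage sketch for the new make_layers_non_pp helper. Block is a hypothetical layer class, the lambda satisfies the updated LayerFn protocol (idx, prefix), and the import assumes the helper is re-exported from sglang.srt.utils after this release's utils/ package split; the call also assumes the offloader has been initialized, as it is during normal engine startup:

    import torch
    from sglang.srt.utils import make_layers_non_pp  # assumed re-export location

    class Block(torch.nn.Module):  # hypothetical decoder layer
        def __init__(self, idx: int, prefix: str = ""):
            super().__init__()
            self.idx = idx

    # One module per layer, no pipeline-parallel slicing; modules are routed
    # through the active offloader.
    layers = make_layers_non_pp(
        num_hidden_layers=4,
        layer_fn=lambda idx, prefix: Block(idx, prefix),
        prefix="model.layers",
    )
    assert len(layers) == 4
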
@@ -738,6 +817,25 @@ def load_image(
     return image, image_size
 
 
+def get_image_bytes(image_file: Union[str, bytes]):
+    if isinstance(image_file, bytes):
+        return image_file
+    elif image_file.startswith("http://") or image_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+        response = requests.get(image_file, timeout=timeout)
+        return response.content
+    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        with open(image_file, "rb") as f:
+            return f.read()
+    elif image_file.startswith("data:"):
+        image_file = image_file.split(",")[1]
+        return pybase64.b64decode(image_file)
+    elif isinstance(image_file, str):
+        return pybase64.b64decode(image_file)
+    else:
+        raise NotImplementedError(f"Invalid image: {image_file}")
+
+
 def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
     from decord import VideoReader, cpu, gpu
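
For orientation, the inputs get_image_bytes accepts, matching its branches above (URLs and paths illustrative):

    raw = get_image_bytes(b"\x89PNG...")                      # bytes pass through unchanged
    raw = get_image_bytes("https://example.com/cat.png")      # fetched via requests; REQUEST_TIMEOUT env, default 3s
    raw = get_image_bytes("/tmp/cat.jpg")                     # read from disk (extension-based check)
    raw = get_image_bytes("data:image/png;base64,iVBORw...")  # data URI: base64 payload after the comma
    raw = get_image_bytes("iVBORw0KGgo...")                   # any other string is treated as bare base64
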
@@ -793,6 +891,33 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     os.unlink(tmp_file.name)
 
 
+def encode_video(video_path, frame_count_limit=None):
+    # Lazy import because decord is not available on some arm platforms.
+    from decord import VideoReader, cpu
+
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -935,6 +1060,13 @@ def set_ulimit(target_soft_limit=65535):
         logger.warning(f"Fail to set RLIMIT_STACK: {e}")
 
 
+def rank0_log(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(msg)
+
+
 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
@@ -1149,7 +1281,7 @@ def pytorch_profile(name, func, *args, data_size=-1):
 
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-):
+) -> zmq.Socket:
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -1393,6 +1525,32 @@ def get_npu_memory_capacity():
         raise ImportError("torch_npu is required when run on npu device.")
 
 
+def get_cpu_memory_capacity():
+    # Per-rank memory capacity cannot be determined for customized core settings
+    if os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", ""):
+        return None
+    n_numa_node: int = len(get_cpu_ids_by_node())
+    if n_numa_node == 0:
+        # Cannot determine NUMA config, fallback to total memory and avoid ZeroDivisionError.
+        return float(psutil.virtual_memory().total // (1 << 20))
+    try:
+        numa_mem_list = list()
+        file_prefix = "/sys/devices/system/node/"
+        for numa_id in range(n_numa_node):
+            file_meminfo = f"node{numa_id}/meminfo"
+            with open(os.path.join(file_prefix, file_meminfo), "r") as f:
+                # 1st line contains 'MemTotal'
+                line = f.read().split("\n")[0]
+            numa_mem_list.append(int(line.split()[3]))
+        # Retrieved value in KB, need MB
+        numa_mem = float(min(numa_mem_list) // 1024)
+        return numa_mem
+    except FileNotFoundError:
+        numa_mem = psutil.virtual_memory().total / n_numa_node
+        # Retrieved value in Byte, need MB
+        return float(numa_mem // (1 << 20))
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
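
The sysfs parse above leans on the fixed layout of the first per-node meminfo line; with an illustrative value:

    line = "Node 0 MemTotal:       263754884 kB"
    line.split()[3]   # -> '263754884', the size in KB; // 1024 converts to MB

and taking min() across nodes keeps the per-rank budget within the smallest NUMA node.
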
@@ -1402,6 +1560,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_hpu_memory_capacity()
     elif device == "npu":
         gpu_mem = get_npu_memory_capacity()
+    elif device == "cpu":
+        gpu_mem = get_cpu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
@@ -1421,6 +1581,7 @@ init_custom_process_group(
     store=None,
     group_name=None,
     pg_options=None,
+    device_id=None,
 ):
     from torch.distributed.distributed_c10d import (
         Backend,
@@ -1474,6 +1635,7 @@
         group_name=group_name,
         **{pg_options_param_name: pg_options},
         timeout=timeout,
+        device_id=device_id,
     )
 
     _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
@@ -1938,50 +2100,6 @@ def set_uvicorn_logging_configs():
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
 
 
-def get_ip() -> str:
-    # SGLANG_HOST_IP env can be ignore
-    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-    if host_ip:
-        return host_ip
-
-    # IP is not set, try to get it from the network interface
-
-    # try ipv4
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    try:
-        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try ipv6
-    try:
-        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        # Google's public DNS server, see
-        # https://developers.google.com/speed/public-dns/docs/using#addresses
-        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try using hostname
-    hostname = socket.gethostname()
-    try:
-        ip_addr = socket.gethostbyname(hostname)
-        warnings.warn("using local ip address: {}".format(ip_addr))
-        return ip_addr
-    except Exception:
-        pass
-
-    warnings.warn(
-        "Failed to get the IP address, using 0.0.0.0 by default."
-        "The value can be set by the environment variable"
-        " SGLANG_HOST_IP or HOST_IP.",
-        stacklevel=2,
-    )
-    return "0.0.0.0"
-
-
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:
@@ -2238,16 +2356,9 @@
     return source
 
 
-def get_local_ip_auto() -> str:
-    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
-    return (
-        get_local_ip_by_nic(interface)
-        if interface is not None
-        else get_local_ip_by_remote()
-    )
-
-
-def get_local_ip_by_nic(interface: str) -> str:
+def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
+    if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
+        return None
     try:
         import netifaces
     except ImportError as e:
@@ -2268,15 +2379,13 @@ def get_local_ip_by_nic(interface: str) -> str:
                 if ip and not ip.startswith("fe80::") and ip != "::1":
                     return ip.split("%")[0]
     except (ValueError, OSError) as e:
-        raise ValueError(
-            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        logger.warning(
+            f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
         )
-
-    # Fallback
-    return get_local_ip_by_remote()
+    return None
 
 
-def get_local_ip_by_remote() -> str:
+def get_local_ip_by_remote() -> Optional[str]:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
@@ -2301,7 +2410,51 @@
         s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
     except Exception:
-        raise ValueError("Can not get local ip")
+        logger.warning("Can not get local ip by remote")
+        return None
+
+
+def get_local_ip_auto(fallback: str = None) -> str:
+    """
+    Automatically detect the local IP address using multiple fallback strategies.
+
+    This function attempts to obtain the local IP address through several methods.
+    If all methods fail, it returns the specified fallback value or raises an exception.
+
+    Args:
+        fallback (str, optional): Fallback IP address to return if all detection
+            methods fail. For server applications, explicitly set this to
+            "0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
+            Defaults to None.
+
+    Returns:
+        str: The detected local IP address, or the fallback value if detection fails.
+
+    Raises:
+        ValueError: If IP detection fails and no fallback value is provided.
+
+    Note:
+        The function tries detection methods in the following order:
+        1. Direct IP detection via get_ip()
+        2. Network interface enumeration via get_local_ip_by_nic()
+        3. Remote connection method via get_local_ip_by_remote()
+    """
+    # Try environment variable
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+    logger.debug("get_ip failed")
+    # Fallback
+    if ip := get_local_ip_by_nic():
+        return ip
+    logger.debug("get_local_ip_by_nic failed")
+    # Fallback
+    if ip := get_local_ip_by_remote():
+        return ip
+    logger.debug("get_local_ip_by_remote failed")
+    if fallback:
+        return fallback
+    raise ValueError("Can not get local ip")
 
 
 def is_page_size_one(server_args):
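
A usage sketch matching the docstring's fallback chain (the address literal is illustrative):

    ip = get_local_ip_auto(fallback="0.0.0.0")
    # Order: SGLANG_HOST_IP / HOST_IP env vars, then the NIC named by
    # SGLANG_LOCAL_IP_NIC via get_local_ip_by_nic(), then get_local_ip_by_remote().
    # With no fallback set, exhausting all three raises ValueError.
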
@@ -2353,7 +2506,7 @@ class BumpAllocator:
 def log_info_on_rank0(logger, msg):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
-    if get_tensor_model_parallel_rank() == 0:
+    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
 
 
@@ -2483,14 +2636,6 @@ def read_system_prompt_from_file(model_name: str) -> str:
     return ""
 
 
-def bind_or_assign(target, source):
-    if target is not None:
-        target.copy_(source)
-        return target
-    else:
-        return source
-
-
 def prepack_weight_if_needed(weight):
     if weight.device != torch.device("cpu"):
         return weight
@@ -3027,3 +3172,232 @@ def check_cuda_result(raw_output):
         raise Exception(f"CUDA error: {err}")
 
     return results
+
+
+def get_physical_device_id(pytorch_device_id: int) -> int:
+    """
+    Convert PyTorch logical device ID to physical device ID.
+    """
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    assert (
+        cuda_visible_devices is not None
+    ), "CUDA_VISIBLE_DEVICES should be set in a scheduler"
+    device_list = cuda_visible_devices.split(",")
+    assert (
+        len(device_list) == 1
+    ), "CUDA_VISIBLE_DEVICES should be set to a single device in a scheduler"
+    return int(device_list[0])
+
+
+def get_device_sm_nvidia_smi():
+    try:
+        # Run nvidia-smi command and capture output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+
+        # Get the first line of output (assuming at least one GPU exists)
+        compute_cap_str = result.stdout.strip().split("\n")[0]
+
+        # Convert string (e.g., "9.0") to tuple of integers (9, 0)
+        major, minor = map(int, compute_cap_str.split("."))
+        return (major, minor)
+
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
+        # Handle cases where nvidia-smi isn't available or output is unexpected
+        print(f"Error getting compute capability: {e}")
+        return (0, 0)  # Default/fallback value
+
+
+def numa_bind_to_node(node: int):
+    libnuma = ctypes.CDLL("libnuma.so")
+    if libnuma.numa_available() < 0:
+        raise SystemError("numa not available on this system")
+
+    libnuma.numa_run_on_node(ctypes.c_int(node))
+    libnuma.numa_set_localalloc()
+
+
+def json_list_type(value):
+    try:
+        return json.loads(value)
+    except json.JSONDecodeError:
+        raise argparse.ArgumentTypeError(
+            f"Invalid JSON list: {value}. Please provide a valid JSON list."
+        )
+
+
+@contextmanager
+def temp_set_cuda_visible_devices(gpu_id: int):
+    original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if original_cuda_visible_devices:
+        cuda_visible_devices = original_cuda_visible_devices.split(",")
+    else:
+        cuda_visible_devices = []
+
+    str_gpu_id = cuda_visible_devices[gpu_id] if cuda_visible_devices else str(gpu_id)
+    os.environ["CUDA_VISIBLE_DEVICES"] = str_gpu_id
+    yield
+    if original_cuda_visible_devices:
+        os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+    else:
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+
+
+def get_extend_input_len_swa_limit(
+    sliding_window_size: int, chunked_prefill_size: int, page_size: int
+) -> int:
+    # 1. a factor of 2x is because each prefill contains chunked_prefill_size tokens,
+    #    and between prefills, we run swa_radix_cache.cache_unfinished_req(),
+    #    so we unlock the previously locked nodes.
+    # 2. max is to handle the case that chunked_prefill_size is larger than sliding_window_size.
+    #    in that case, each prefill contains chunked_prefill_size tokens,
+    #    and we can only free out-of-sliding-window kv indices after each prefill.
+    # 3. page_size is because we want to have 1 token extra for generated tokens.
+    return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
+
+
+def get_num_new_pages(
+    seq_lens: torch.Tensor,
+    page_size: int,
+    prefix_lens: Optional[torch.Tensor] = None,
+    decode: bool = False,
+) -> torch.Tensor:
+    """
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use cpu tensors to avoid blocking kernel launch.
+    """
+    cpu_device = torch.device("cpu")
+    assert seq_lens.device == cpu_device
+
+    if prefix_lens is None or decode:
+        # NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
+        assert decode
+        return (seq_lens % page_size == 1).int().sum().item()
+
+    assert prefix_lens.device == cpu_device
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (prefix_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+    sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
+    return sum_num_new_pages.item()
+
+
+class CachedKernel:
+    """
+    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
+
+    This wrapper caches compiled Triton kernels based on keys extracted by a
+    user-provided key function to avoid redundant compilations.
+    """
+
+    def __init__(self, fn, key_fn=None):
+        self.fn = fn
+        assert isinstance(fn, triton.runtime.jit.JITFunction)
+
+        original_fn = fn.fn
+        self.signature = inspect.signature(original_fn)
+        self.param_names = tuple(self.signature.parameters.keys())
+        self.num_args = len(self.param_names)
+
+        # Check that no parameters have default values
+        for name, param in self.signature.parameters.items():
+            assert (
+                param.default is inspect.Parameter.empty
+            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
+
+        functools.update_wrapper(self, original_fn)
+        self.kernel_cache = {}
+
+        # Store the key function
+        self.key_fn = key_fn
+
+    def __getitem__(self, grid):
+        """
+        Index with grid to get a launcher function.
+        Returns a launcher that will handle caching based on the key function.
+        """
+        assert (
+            isinstance(grid, tuple) and len(grid) <= 3
+        ), "Grid must be a tuple with at most 3 dimensions."
+
+        # Normalize grid once
+        if len(grid) < 3:
+            grid = grid + (1,) * (3 - len(grid))
+
+        def launcher(*args, **kwargs):
+            cache_key = self.key_fn(args, kwargs)
+
+            cached_kernel = self.kernel_cache.get(cache_key)
+
+            if cached_kernel is None:
+                # First time: compile and cache the kernel
+                cached_kernel = self.fn[grid](*args, **kwargs)
+                self.kernel_cache[cache_key] = cached_kernel
+                return cached_kernel
+            else:
+                # Use cached kernel
+                all_args = self._build_args(args, kwargs)
+                cached_kernel[grid](*all_args)
+                return cached_kernel
+
+        return launcher
+
+    def _build_args(self, args, kwargs):
+        """
+        Build the complete argument list for kernel invocation.
+        """
+        complete_args = list(args)
+
+        for i in range(len(args), self.num_args):
+            name = self.param_names[i]
+            value = kwargs.get(name, inspect.Parameter.empty)
+            if value is not inspect.Parameter.empty:
+                complete_args.append(value)
+            else:
+                raise ValueError(f"Missing argument: {name}")
+
+        return complete_args
+
+    def _clear_cache(self):
+        """
+        Clear the kernel cache for testing purposes.
+        """
+        self.kernel_cache.clear()
+
+
+def cached_triton_kernel(key_fn=None):
+    """
+    Decorator that enables key-based caching for Triton kernels using a key function.
+
+    It essentially bypasses Triton's built-in caching mechanism, allowing users to
+    define their own caching strategy based on kernel parameters. This helps reduce
+    the heavy overheads of Triton kernel launch when the kernel specialization dispatch
+    is simple.
+
+    Usage:
+        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
+        @triton.jit
+        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+            ...
+
+        # Invoke normally
+        my_kernel[grid](x, y, BLOCK_SIZE=1024)
+
+    Args:
+        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
+                The key can be a single value or a tuple of values.
+
+    Returns:
+        A decorator that wraps the kernel with caching functionality.
+
+    Note: Kernels with default parameter values are not supported and will raise an assertion error.
+    """
+
+    def decorator(fn):
+        return CachedKernel(fn, key_fn)
+
+    return decorator
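
To make the page accounting in get_num_new_pages concrete, a small worked example (values illustrative):

    import torch

    seq_lens = torch.tensor([10, 17])    # target lengths per request
    prefix_lens = torch.tensor([4, 16])  # tokens already cached
    page_size = 8

    # ceil-div: pages needed after = [2, 3]; pages already held = [1, 2],
    # so the batch allocates (2 - 1) + (3 - 2) = 2 new pages.
    get_num_new_pages(seq_lens, page_size, prefix_lens)  # -> 2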