sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -15,12 +15,14 @@

 from __future__ import annotations

+import argparse
 import asyncio
 import builtins
 import ctypes
 import dataclasses
 import functools
 import importlib
+import inspect
 import io
 import ipaddress
 import itertools
@@ -81,11 +83,9 @@ from packaging import version as pkg_version
 from PIL import Image
 from starlette.routing import Mount
 from torch import nn
-from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import FileCacheManager
 from typing_extensions import Literal

 from sglang.srt.metrics.func_timer import enable_func_timer
@@ -166,6 +166,7 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)


+@lru_cache(maxsize=1)
 def is_blackwell():
     if not is_cuda():
         return False
@@ -174,6 +175,8 @@ def is_blackwell():

 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
@@ -181,6 +184,8 @@ def is_sm100_supported(device=None) -> bool:

 @lru_cache(maxsize=1)
 def is_sm90_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 9) and (
         torch.version.cuda >= "12.3"
     )
@@ -190,6 +195,7 @@ _warned_bool_env_var_keys = set()


 def get_bool_env_var(name: str, default: str = "false") -> bool:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name, default)
     value = value.lower()

@@ -207,6 +213,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:


 def get_int_env_var(name: str, default: int = 0) -> int:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name)
     if value is None or not value.strip():
         return default
@@ -230,8 +237,16 @@ except:
     is_intel_amx_backend_available = False


+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+
+
 def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+    return is_amx_tile_supported and is_intel_amx_backend_available


 def use_intel_amx_backend(layer):
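Per its comment, the hunk above hoists the torch._C._cpu._is_amx_tile_supported() probe to import time so cpu_has_amx_support() only reads module-level booleans. A minimal sketch of the pattern this enables, not taken from the sglang sources: branching on a plain global inside a torch.compile'd function lets the compiler specialize the trace instead of calling back into torch._C on every invocation.

import torch

# Probe the CPU capability once at import time (same pattern as the hunk above).
try:
    _AMX_SUPPORTED = torch._C._cpu._is_amx_tile_supported()
except Exception:
    _AMX_SUPPORTED = False


@torch.compile
def dispatch(x: torch.Tensor) -> torch.Tensor:
    # _AMX_SUPPORTED is an ordinary Python constant by the time this function
    # is traced, so the branch is resolved during compilation rather than
    # re-evaluated per call.
    if _AMX_SUPPORTED:
        return x * 2.0  # stand-in for the AMX fast path
    return x + 2.0      # stand-in for the fallback path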
@@ -426,7 +441,9 @@ def get_available_gpu_memory(

     elif device == "cpu":
         # TODO: rename the variables in the current function to be not GPU specific
-        free_gpu_memory = psutil.virtual_memory().available
+        total_free_memory = psutil.virtual_memory().available
+        n_numa_node: int = len(get_cpu_ids_by_node())
+        free_gpu_memory = round(total_free_memory / n_numa_node, 3)
     elif device == "npu":
         num_gpus = torch.npu.device_count()
         assert gpu_id < num_gpus
@@ -454,7 +471,7 @@ def is_pin_memory_available() -> bool:

 class LayerFn(Protocol):

-    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


 def make_layers(
@@ -465,7 +482,7 @@ def make_layers(
     prefix: str = "",
     return_tuple: bool = False,
     offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[int, int, torch.nn.ModuleList]:
+) -> Tuple[torch.nn.Module, int, int]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices
@@ -501,6 +518,50 @@
     return modules, start_layer, end_layer


+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation(CMO).
+    Launch a new stream to prefetch the weight of matmul when running other
+    AIV or communication kernels, aiming to overlap the memory access time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1GB, a large value to prefetch entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
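The new CMO helpers implement the usual side-stream overlap pattern: prepare_weight_cache forks the CMO stream off the current NPU stream and issues torch_npu.npu_prefetch there, and wait_cmo_stream joins before the consumer kernel runs. A hedged usage sketch follows; it assumes an Ascend build with torch_npu, that the helpers are re-exported from sglang.srt.utils, and the mlp_forward/handle names are placeholders, not sglang code.

import torch
from sglang.srt.utils import prepare_weight_cache, wait_cmo_stream  # assumed re-export

def mlp_forward(hidden: torch.Tensor, weight: torch.Tensor, handle) -> torch.Tensor:
    # Fork: start prefetching `weight` on the CMO side stream; this overlaps
    # with the kernels queued on the current stream below.
    prepare_weight_cache(handle, weight)

    hidden = hidden * 2  # stand-in for attention / communication kernels

    # Join: the current stream waits for the prefetch to finish before the
    # matmul that actually consumes the weight.
    wait_cmo_stream()
    return hidden @ weight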
@@ -738,6 +799,25 @@ def load_image(
     return image, image_size


+def get_image_bytes(image_file: Union[str, bytes]):
+    if isinstance(image_file, bytes):
+        return image_file
+    elif image_file.startswith("http://") or image_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+        response = requests.get(image_file, timeout=timeout)
+        return response.content
+    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        with open(image_file, "rb") as f:
+            return f.read()
+    elif image_file.startswith("data:"):
+        image_file = image_file.split(",")[1]
+        return pybase64.b64decode(image_file)
+    elif isinstance(image_file, str):
+        return pybase64.b64decode(image_file)
+    else:
+        raise NotImplementedError(f"Invalid image: {image_file}")
+
+
 def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
     from decord import VideoReader, cpu, gpu
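get_image_bytes dispatches on the argument's form, in branch order: raw bytes, http(s) URLs, paths ending in a known image suffix, data: URIs, and finally any other string, which is treated as bare base64. A runnable sketch of the three local branches; the payload is synthetic, and the import path assumes the helper is re-exported via sglang.srt.utils (this release moves utils.py to utils/common.py and adds a package __init__).

import pybase64
from sglang.srt.utils import get_image_bytes  # assumed re-export

payload = b"\x89PNG\r\n\x1a\n" + b"not-a-real-image-body"
b64 = pybase64.b64encode(payload).decode()

assert get_image_bytes(b"raw") == b"raw"                           # bytes passthrough
assert get_image_bytes("data:image/png;base64," + b64) == payload  # data: URI branch
assert get_image_bytes(b64) == payload                             # bare-base64 branch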
@@ -793,6 +873,33 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
         os.unlink(tmp_file.name)


+def encode_video(video_path, frame_count_limit=None):
+    # Lazy import because decord is not available on some arm platforms.
+    from decord import VideoReader, cpu
+
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -935,6 +1042,13 @@ def set_ulimit(target_soft_limit=65535):
         logger.warning(f"Fail to set RLIMIT_STACK: {e}")


+def rank0_log(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(msg)
+
+
 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
@@ -1149,7 +1263,7 @@ def pytorch_profile(name, func, *args, data_size=-1):

 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-):
+) -> zmq.Socket:
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -1393,6 +1507,32 @@ def get_npu_memory_capacity():
         raise ImportError("torch_npu is required when run on npu device.")


+def get_cpu_memory_capacity():
+    # Per-rank memory capacity cannot be determined for customized core settings
+    if os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", ""):
+        return None
+    n_numa_node: int = len(get_cpu_ids_by_node())
+    if n_numa_node == 0:
+        # Cannot determine NUMA config, fallback to total memory and avoid ZeroDivisionError.
+        return float(psutil.virtual_memory().total // (1 << 20))
+    try:
+        numa_mem_list = list()
+        file_prefix = "/sys/devices/system/node/"
+        for numa_id in range(n_numa_node):
+            file_meminfo = f"node{numa_id}/meminfo"
+            with open(os.path.join(file_prefix, file_meminfo), "r") as f:
+                # 1st line contains 'MemTotal'
+                line = f.read().split("\n")[0]
+            numa_mem_list.append(int(line.split()[3]))
+        # Retrieved value in KB, need MB
+        numa_mem = float(min(numa_mem_list) // 1024)
+        return numa_mem
+    except FileNotFoundError:
+        numa_mem = psutil.virtual_memory().total / n_numa_node
+        # Retrieved value in Byte, need MB
+        return float(numa_mem // (1 << 20))
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
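The per-node figure is parsed from the first line of /sys/devices/system/node/node<N>/meminfo, whose fourth whitespace-separated field is the node's MemTotal in kB; the helper then takes the minimum across nodes and floors to MB. The parse on a representative line (the number is made up):

# First line of a node meminfo file looks like this on Linux:
line = "Node 0 MemTotal:       32617924 kB"

fields = line.split()           # ['Node', '0', 'MemTotal:', '32617924', 'kB']
mem_kb = int(fields[3])
mem_mb = float(mem_kb // 1024)  # same kB -> MB floor as the helper
print(mem_mb)                   # 31853.0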
@@ -1402,6 +1542,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_hpu_memory_capacity()
     elif device == "npu":
         gpu_mem = get_npu_memory_capacity()
+    elif device == "cpu":
+        gpu_mem = get_cpu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
@@ -1421,6 +1563,7 @@ def init_custom_process_group(
     store=None,
     group_name=None,
     pg_options=None,
+    device_id=None,
 ):
     from torch.distributed.distributed_c10d import (
         Backend,
@@ -1474,6 +1617,7 @@ def init_custom_process_group(
         group_name=group_name,
         **{pg_options_param_name: pg_options},
         timeout=timeout,
+        device_id=device_id,
     )

     _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
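The new device_id keyword is forwarded to torch.distributed's process-group constructor, mirroring the device_id argument of torch.distributed.init_process_group that binds a group to one accelerator. A hedged call sketch; the endpoint, ranks, and group name are placeholders, and the exact set of other keyword arguments accepted by init_custom_process_group is assumed, not confirmed by this diff.

import torch
from sglang.srt.utils import init_custom_process_group  # assumed re-export

pg = init_custom_process_group(
    backend="nccl",
    init_method="tcp://127.0.0.1:29500",  # placeholder rendezvous endpoint
    world_size=2,
    rank=0,
    group_name="custom_weight_sync",      # placeholder name
    device_id=torch.device("cuda:0"),     # new parameter in this release
)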
@@ -1938,50 +2082,6 @@ def set_uvicorn_logging_configs():
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"


-def get_ip() -> str:
-    # SGLANG_HOST_IP env can be ignore
-    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-    if host_ip:
-        return host_ip
-
-    # IP is not set, try to get it from the network interface
-
-    # try ipv4
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    try:
-        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try ipv6
-    try:
-        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        # Google's public DNS server, see
-        # https://developers.google.com/speed/public-dns/docs/using#addresses
-        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try using hostname
-    hostname = socket.gethostname()
-    try:
-        ip_addr = socket.gethostbyname(hostname)
-        warnings.warn("using local ip address: {}".format(ip_addr))
-        return ip_addr
-    except Exception:
-        pass
-
-    warnings.warn(
-        "Failed to get the IP address, using 0.0.0.0 by default."
-        "The value can be set by the environment variable"
-        " SGLANG_HOST_IP or HOST_IP.",
-        stacklevel=2,
-    )
-    return "0.0.0.0"
-
-
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:
@@ -2238,16 +2338,9 @@ def bind_or_assign(target, source):
         return source


-def get_local_ip_auto() -> str:
-    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
-    return (
-        get_local_ip_by_nic(interface)
-        if interface is not None
-        else get_local_ip_by_remote()
-    )
-
-
-def get_local_ip_by_nic(interface: str) -> str:
+def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
+    if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
+        return None
     try:
         import netifaces
     except ImportError as e:
@@ -2268,15 +2361,13 @@ def get_local_ip_by_nic(interface: str) -> str:
                 if ip and not ip.startswith("fe80::") and ip != "::1":
                     return ip.split("%")[0]
     except (ValueError, OSError) as e:
-        raise ValueError(
-            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        logger.warning(
+            f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
         )
-
-    # Fallback
-    return get_local_ip_by_remote()
+    return None


-def get_local_ip_by_remote() -> str:
+def get_local_ip_by_remote() -> Optional[str]:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
@@ -2301,7 +2392,51 @@ def get_local_ip_by_remote() -> str:
         s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
     except Exception:
-        raise ValueError("Can not get local ip")
+        logger.warning("Can not get local ip by remote")
+        return None
+
+
+def get_local_ip_auto(fallback: str = None) -> str:
+    """
+    Automatically detect the local IP address using multiple fallback strategies.
+
+    This function attempts to obtain the local IP address through several methods.
+    If all methods fail, it returns the specified fallback value or raises an exception.
+
+    Args:
+        fallback (str, optional): Fallback IP address to return if all detection
+            methods fail. For server applications, explicitly set this to
+            "0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
+            Defaults to None.
+
+    Returns:
+        str: The detected local IP address, or the fallback value if detection fails.
+
+    Raises:
+        ValueError: If IP detection fails and no fallback value is provided.
+
+    Note:
+        The function tries detection methods in the following order:
+        1. Direct IP detection via get_ip()
+        2. Network interface enumeration via get_local_ip_by_nic()
+        3. Remote connection method via get_local_ip_by_remote()
+    """
+    # Try environment variable
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+    logger.debug("get_ip failed")
+    # Fallback
+    if ip := get_local_ip_by_nic():
+        return ip
+    logger.debug("get_local_ip_by_nic failed")
+    # Fallback
+    if ip := get_local_ip_by_remote():
+        return ip
+    logger.debug("get_local_ip_by_remote failed")
+    if fallback:
+        return fallback
+    raise ValueError("Can not get local ip")


 def is_page_size_one(server_args):
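With both helpers now returning None instead of raising, get_local_ip_auto can walk its strategies in order: the SGLANG_HOST_IP/HOST_IP environment variables, then get_local_ip_by_nic (driven by SGLANG_LOCAL_IP_NIC), then the UDP-connect probe in get_local_ip_by_remote. A usage sketch of the two calling modes, assuming the utils re-export:

from sglang.srt.utils import get_local_ip_auto  # assumed re-export

# Server-style: never raises; binds to the wildcard address when detection fails.
bind_host = get_local_ip_auto(fallback="0.0.0.0")

# Strict: raises ValueError when no strategy yields an address.
advertised_host = get_local_ip_auto()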
@@ -2353,7 +2488,7 @@ class BumpAllocator:
 def log_info_on_rank0(logger, msg):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

-    if get_tensor_model_parallel_rank() == 0:
+    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
         logger.info(msg)


@@ -2483,14 +2618,6 @@ def read_system_prompt_from_file(model_name: str) -> str:
     return ""


-def bind_or_assign(target, source):
-    if target is not None:
-        target.copy_(source)
-        return target
-    else:
-        return source
-
-
 def prepack_weight_if_needed(weight):
     if weight.device != torch.device("cpu"):
         return weight
@@ -3027,3 +3154,232 @@ def check_cuda_result(raw_output):
         raise Exception(f"CUDA error: {err}")

     return results
+
+
+def get_physical_device_id(pytorch_device_id: int) -> int:
+    """
+    Convert PyTorch logical device ID to physical device ID.
+    """
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    assert (
+        cuda_visible_devices is not None
+    ), "CUDA_VISIBLE_DEVICES should be set in a scheduler"
+    device_list = cuda_visible_devices.split(",")
+    assert (
+        len(device_list) == 1
+    ), "CUDA_VISIBLE_DEVICES should be set to a single device in a scheduler"
+    return int(device_list[0])
+
+
+def get_device_sm_nvidia_smi():
+    try:
+        # Run nvidia-smi command and capture output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+
+        # Get the first line of output (assuming at least one GPU exists)
+        compute_cap_str = result.stdout.strip().split("\n")[0]
+
+        # Convert string (e.g., "9.0") to tuple of integers (9, 0)
+        major, minor = map(int, compute_cap_str.split("."))
+        return (major, minor)
+
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
+        # Handle cases where nvidia-smi isn't available or output is unexpected
+        print(f"Error getting compute capability: {e}")
+        return (0, 0)  # Default/fallback value
+
+
+def numa_bind_to_node(node: int):
+    libnuma = ctypes.CDLL("libnuma.so")
+    if libnuma.numa_available() < 0:
+        raise SystemError("numa not available on this system")
+
+    libnuma.numa_run_on_node(ctypes.c_int(node))
+    libnuma.numa_set_localalloc()
+
+
+def json_list_type(value):
+    try:
+        return json.loads(value)
+    except json.JSONDecodeError:
+        raise argparse.ArgumentTypeError(
+            f"Invalid JSON list: {value}. Please provide a valid JSON list."
+        )
+
+
+@contextmanager
+def temp_set_cuda_visible_devices(gpu_id: int):
+    original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if original_cuda_visible_devices:
+        cuda_visible_devices = original_cuda_visible_devices.split(",")
+    else:
+        cuda_visible_devices = []
+
+    str_gpu_id = cuda_visible_devices[gpu_id] if cuda_visible_devices else str(gpu_id)
+    os.environ["CUDA_VISIBLE_DEVICES"] = str_gpu_id
+    yield
+    if original_cuda_visible_devices:
+        os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+    else:
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+
+
+def get_extend_input_len_swa_limit(
+    sliding_window_size: int, chunked_prefill_size: int, page_size: int
+) -> int:
+    # 1. a factor of 2x is because each prefill contains chunked_prefill_size tokens,
+    #    and between prefills, we run swa_radix_cache.cache_unfinished_req(),
+    #    so we unlock the previously locked nodes.
+    # 2. max is to handle the case that chunked_prefill_size is larger than sliding_window_size.
+    #    in that case, each prefill contains chunked_prefill_size tokens,
+    #    and we can only free out-of-sliding-window kv indices after each prefill.
+    # 3. page_size is because we want to have 1 token extra for generated tokens.
+    return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
+
+
+def get_num_new_pages(
+    seq_lens: torch.Tensor,
+    page_size: int,
+    prefix_lens: Optional[torch.Tensor] = None,
+    decode: bool = False,
+) -> torch.Tensor:
+    """
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use cpu tensors to avoid blocking kernel launch.
+    """
+    cpu_device = torch.device("cpu")
+    assert seq_lens.device == cpu_device
+
+    if prefix_lens is None or decode:
+        # NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
+        assert decode
+        return (seq_lens % page_size == 1).int().sum().item()
+
+    assert prefix_lens.device == cpu_device
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (prefix_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+    sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
+    return sum_num_new_pages.item()
+
+
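For prefill, the get_num_new_pages helper just added sums ceil(seq/page) - ceil(prefix/page) over the batch; for decode, the prefix is implicitly seq_lens - 1, so a request crosses a page boundary exactly when seq_len % page_size == 1. A worked check on CPU tensors (page_size=4 is arbitrary; the import path assumes the utils re-export):

import torch
from sglang.srt.utils import get_num_new_pages  # assumed re-export

page_size = 4
seq_lens = torch.tensor([5, 8, 9])     # lengths including the new tokens
prefix_lens = torch.tensor([4, 8, 6])  # already-cached prefix lengths

# Prefill: (ceil(5/4)-ceil(4/4)) + (ceil(8/4)-ceil(8/4)) + (ceil(9/4)-ceil(6/4))
#        = (2-1) + (2-2) + (3-2) = 2 new pages
assert get_num_new_pages(seq_lens, page_size, prefix_lens=prefix_lens) == 2

# Decode: a page is needed where seq_len % 4 == 1, i.e. for 5 and 9.
assert get_num_new_pages(torch.tensor([5, 6, 9]), page_size, decode=True) == 2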
+class CachedKernel:
+    """
+    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
+
+    This wrapper caches compiled Triton kernels based on keys extracted by a
+    user-provided key function to avoid redundant compilations.
+    """
+
+    def __init__(self, fn, key_fn=None):
+        self.fn = fn
+        assert isinstance(fn, triton.runtime.jit.JITFunction)
+
+        original_fn = fn.fn
+        self.signature = inspect.signature(original_fn)
+        self.param_names = tuple(self.signature.parameters.keys())
+        self.num_args = len(self.param_names)
+
+        # Check that no parameters have default values
+        for name, param in self.signature.parameters.items():
+            assert (
+                param.default is inspect.Parameter.empty
+            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
+
+        functools.update_wrapper(self, original_fn)
+        self.kernel_cache = {}
+
+        # Store the key function
+        self.key_fn = key_fn
+
+    def __getitem__(self, grid):
+        """
+        Index with grid to get a launcher function.
+        Returns a launcher that will handle caching based on the key function.
+        """
+        assert (
+            isinstance(grid, tuple) and len(grid) <= 3
+        ), "Grid must be a tuple with at most 3 dimensions."
+
+        # Normalize grid once
+        if len(grid) < 3:
+            grid = grid + (1,) * (3 - len(grid))
+
+        def launcher(*args, **kwargs):
+            cache_key = self.key_fn(args, kwargs)
+
+            cached_kernel = self.kernel_cache.get(cache_key)
+
+            if cached_kernel is None:
+                # First time: compile and cache the kernel
+                cached_kernel = self.fn[grid](*args, **kwargs)
+                self.kernel_cache[cache_key] = cached_kernel
+                return cached_kernel
+            else:
+                # Use cached kernel
+                all_args = self._build_args(args, kwargs)
+                cached_kernel[grid](*all_args)
+                return cached_kernel
+
+        return launcher
+
+    def _build_args(self, args, kwargs):
+        """
+        Build the complete argument list for kernel invocation.
+        """
+        complete_args = list(args)
+
+        for i in range(len(args), self.num_args):
+            name = self.param_names[i]
+            value = kwargs.get(name, inspect.Parameter.empty)
+            if value is not inspect.Parameter.empty:
+                complete_args.append(value)
+            else:
+                raise ValueError(f"Missing argument: {name}")
+
+        return complete_args
+
+    def _clear_cache(self):
+        """
+        Clear the kernel cache for testing purposes.
+        """
+        self.kernel_cache.clear()
+
+
+def cached_triton_kernel(key_fn=None):
+    """
+    Decorator that enables key-based caching for Triton kernels using a key function.
+
+    It essentially bypasses Triton's built-in caching mechanism, allowing users to
+    define their own caching strategy based on kernel parameters. This helps reduce
+    the heavy overheads of Triton kernel launch when the kernel specialization dispatch
+    is simple.
+
+    Usage:
+        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
+        @triton.jit
+        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+            ...
+
+        # Invoke normally
+        my_kernel[grid](x, y, BLOCK_SIZE=1024)
+
+    Args:
+        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
+            The key can be a single value or a tuple of values.
+
+    Returns:
+        A decorator that wraps the kernel with caching functionality.
+
+    Note: Kernels with default parameter values are not supported and will raise an assertion error.
+    """
+
+    def decorator(fn):
+        return CachedKernel(fn, key_fn)
+
+    return decorator
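An end-to-end sketch of the decorator on a toy kernel, expanding the Usage block in the docstring above. It requires triton and a CUDA device; the kernel, sizes, and import path (assumed utils re-export) are illustrative.

import torch
import triton
import triton.language as tl

from sglang.srt.utils import cached_triton_kernel  # assumed re-export

@cached_triton_kernel(key_fn=lambda args, kwargs: kwargs["BLOCK_SIZE"])
@triton.jit
def add_one_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)

x = torch.zeros(1024, device="cuda")
grid = (triton.cdiv(x.numel(), 256),)

# The first launch compiles and caches under key 256; the second reuses the
# compiled kernel and skips Triton's own specialization dispatch.
add_one_kernel[grid](x, x.numel(), BLOCK_SIZE=256)
add_one_kernel[grid](x, x.numel(), BLOCK_SIZE=256)
assert x.sum().item() == 2048.0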