sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/test/test_deterministic.py
@@ -0,0 +1,313 @@
+ """
+ Batch the same prompt in random batch sizes, and test if the results are consistent across different trials.
+
+ Usage:
+ python3 -m sglang.test.test_deterministic --n-trials <number_of_trials> --test-mode <single|mixed|prefix> --profile
+ """
+
+ import argparse
+ import dataclasses
+ import json
+ import os
+ import random
+ from typing import List
+
+ import requests
+
+ from sglang.profiler import run_profile
+
+ PROMPT_1 = "Tell me about Richard Feynman: "
+ PROMPT_2 = "Generate 1000 random numbers. Go directly into it, don't say Sure and don't say here are numbers. Just start with a number."
+ dirpath = os.path.dirname(__file__)
+ with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
+     LONG_PROMPT = f.read()
+
+
+ @dataclasses.dataclass
+ class BenchArgs:
+     host: str = "localhost"
+     port: int = 30000
+     batch_size: int = 1
+     temperature: float = 0.0
+     sampling_seed: int = 42
+     max_new_tokens: int = 100
+     frequency_penalty: float = 0.0
+     presence_penalty: float = 0.0
+     return_logprob: bool = False
+     stream: bool = False
+     profile: bool = False
+     profile_steps: int = 3
+     profile_by_stage: bool = False
+     test_mode: str = "single"
+     n_trials: int = 50
+     n_start: int = 1
+
+     @staticmethod
+     def add_cli_args(parser: argparse.ArgumentParser):
+         parser.add_argument("--host", type=str, default=BenchArgs.host)
+         parser.add_argument("--port", type=int, default=BenchArgs.port)
+         parser.add_argument("--n-trials", type=int, default=BenchArgs.n_trials)
+         parser.add_argument("--n-start", type=int, default=BenchArgs.n_start)
+         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+         parser.add_argument(
+             "--sampling-seed", type=int, default=BenchArgs.sampling_seed
+         )
+         parser.add_argument(
+             "--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
+         )
+         parser.add_argument(
+             "--frequency-penalty", type=float, default=BenchArgs.frequency_penalty
+         )
+         parser.add_argument(
+             "--presence-penalty", type=float, default=BenchArgs.presence_penalty
+         )
+         parser.add_argument("--return-logprob", action="store_true")
+         parser.add_argument("--stream", action="store_true")
+         parser.add_argument(
+             "--test-mode",
+             type=str,
+             default=BenchArgs.test_mode,
+             choices=["single", "mixed", "prefix"],
+         )
+         parser.add_argument("--profile", action="store_true")
+         parser.add_argument(
+             "--profile-steps", type=int, default=BenchArgs.profile_steps
+         )
+         parser.add_argument("--profile-by-stage", action="store_true")
+
+     @classmethod
+     def from_cli_args(cls, args: argparse.Namespace):
+         attrs = [attr.name for attr in dataclasses.fields(cls)]
+         return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+ def send_single(
+     args,
+     batch_size: int,
+     profile: bool = False,
+     profile_steps: int = 3,
+     profile_by_stage: bool = False,
+ ):
+
+     base_url = f"http://{args.host}:{args.port}"
+     prompt = [PROMPT_1] * batch_size
+
+     json_data = {
+         "text": prompt,
+         "sampling_params": {
+             "temperature": args.temperature,
+             "max_new_tokens": args.max_new_tokens,
+             "frequency_penalty": args.frequency_penalty,
+             "presence_penalty": args.presence_penalty,
+         },
+         "return_logprob": args.return_logprob,
+         "stream": args.stream,
+     }
+
+     if args.sampling_seed is not None:
+         # sglang server cannot parse None value for sampling_seed
+         json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
+
+     if profile:
+         run_profile(
+             base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
+         )
+
+     response = requests.post(
+         f"{base_url}/generate",
+         json=json_data,
+         stream=args.stream,
+     )
+
+     if args.stream:
+         for chunk in response.iter_lines(decode_unicode=False):
+             chunk = chunk.decode("utf-8")
+             if chunk and chunk.startswith("data:"):
+                 if chunk == "data: [DONE]":
+                     break
+                 ret = json.loads(chunk[5:].strip("\n"))
+     else:
+         ret = response.json()
+         ret = ret[0]
+
+     if response.status_code != 200:
+         print(ret)
+         return -1
+
+     return ret["text"]
+
+
+ def send_mixed(args, batch_size: int):
+     num_long_prompt = 0 if batch_size <= 10 else random.randint(1, 10)
+     num_prompt_1 = random.randint(1, batch_size - num_long_prompt)
+     num_prompt_2 = batch_size - num_prompt_1 - num_long_prompt
+
+     json_data = {
+         "text": [PROMPT_1] * num_prompt_1
+         + [PROMPT_2] * num_prompt_2
+         + [LONG_PROMPT] * num_long_prompt,
+         "sampling_params": {
+             "temperature": args.temperature,
+             "max_new_tokens": args.max_new_tokens,
+             "frequency_penalty": args.frequency_penalty,
+             "presence_penalty": args.presence_penalty,
+         },
+         "return_logprob": args.return_logprob,
+         "stream": args.stream,
+     }
+
+     if args.sampling_seed is not None:
+         json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
+
+     response = requests.post(
+         f"http://{args.host}:{args.port}/generate",
+         json=json_data,
+         stream=args.stream,
+     )
+     ret = response.json()
+     if response.status_code != 200:
+         print(ret)
+         return -1, -1, -1
+
+     prompt_1_ret = [ret[i]["text"] for i in range(num_prompt_1)]
+     prompt_2_ret = [
+         ret[i]["text"] for i in range(num_prompt_1, num_prompt_1 + num_prompt_2)
+     ]
+     long_prompt_ret = [
+         ret[i]["text"]
+         for i in range(
+             num_prompt_1 + num_prompt_2, num_prompt_1 + num_prompt_2 + num_long_prompt
+         )
+     ]
+
+     return prompt_1_ret, prompt_2_ret, long_prompt_ret
+
+
+ def send_prefix(args, batch_size: int, prompts: List[str]):
+     requests.post(f"http://{args.host}:{args.port}/flush_cache")
+
+     batch_data = []
+     sampled_indices = []
+     for _ in range(batch_size):
+         sampled_index = random.randint(0, len(prompts) - 1)
+         sampled_indices.append(sampled_index)
+         batch_data.append(prompts[sampled_index])
+
+     json_data = {
+         "text": batch_data,
+         "sampling_params": {
+             "temperature": args.temperature,
+             "max_new_tokens": args.max_new_tokens,
+             "frequency_penalty": args.frequency_penalty,
+             "presence_penalty": args.presence_penalty,
+         },
+         "return_logprob": args.return_logprob,
+         "stream": args.stream,
+     }
+
+     if args.sampling_seed is not None:
+         json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
+
+     response = requests.post(
+         f"http://{args.host}:{args.port}/generate",
+         json=json_data,
+         stream=args.stream,
+     )
+     ret = response.json()
+     if response.status_code != 200:
+         print(ret)
+         return -1, -1, -1
+
+     ret_dict = {i: [] for i in range(len(prompts))}
+     for i in range(batch_size):
+         ret_dict[sampled_indices[i]].append(ret[i]["text"])
+
+     return ret_dict
+
+
+ def test_deterministic(args):
+     # First do some warmups
+     for i in range(3):
+         send_single(args, 16, args.profile)
+
+     if args.test_mode == "single":
+         # In single mode, we test the deterministic behavior by sending the same prompt in batch sizes ranging from 1 to n_trials.
+         texts = []
+         for i in range(1, args.n_trials + 1):
+             batch_size = i
+             text = send_single(args, batch_size, args.profile)
+             text = text.replace("\n", " ")
+             print(f"Trial {i} with batch size {batch_size}: {text}")
+             texts.append(text)
+
+         print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}")
+         return [len(set(texts))]
+
+     elif args.test_mode == "mixed":
+         # In mixed mode, we send a mixture of two short prompts and one long prompt in the same batch with batch size ranging from 1 to n_trials.
+         output_prompt_1 = []
+         output_prompt_2 = []
+         output_long_prompt = []
+         for i in range(1, args.n_trials + 1):
+             batch_size = i
+             ret_prompt_1, ret_prompt_2, ret_long_prompt = send_mixed(args, batch_size)
+             output_prompt_1.extend(ret_prompt_1)
+             output_prompt_2.extend(ret_prompt_2)
+             output_long_prompt.extend(ret_long_prompt)
+
+             print(
+                 f"Testing Trial {i} with batch size {batch_size}, number of prompt 1: {len(ret_prompt_1)}, number of prompt 2: {len(ret_prompt_2)}, number of long prompt: {len(ret_long_prompt)}"
+             )
+
+         print(
+             f"Prompt 1: total samples: {len(output_prompt_1)}, Unique samples: {len(set(output_prompt_1))}"
+         )
+         print(
+             f"Prompt 2: total samples: {len(output_prompt_2)}, Unique samples: {len(set(output_prompt_2))}"
+         )
+         print(
+             f"Long prompt: total samples: {len(output_long_prompt)}, Unique samples: {len(set(output_long_prompt))}"
+         )
+
+         return [
+             len(set(output_prompt_1)),
+             len(set(output_prompt_2)),
+             len(set(output_long_prompt)),
+         ]
+
+     elif args.test_mode == "prefix":
+         # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
+         len_prefix = [1, 511, 2048, 4097]
+         num_prompts = len(len_prefix)
+         outputs = {i: [] for i in range(4)}
+         prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
+         for i in range(args.n_start, args.n_start + args.n_trials):
+             batch_size = i
+             ret_dict = send_prefix(args, batch_size, prompts)
+             msg = f"Testing Trial {i} with batch size {batch_size},"
+             for i in range(num_prompts):
+                 msg += f" # prefix length {len_prefix[i]}: {len(ret_dict[i])},"
+             print(msg)
+             for i in range(num_prompts):
+                 outputs[i].extend(ret_dict[i])
+
+         for i in range(num_prompts):
+             print(
+                 f"Prompt {i} with prefix length {len_prefix[i]}: total samples: {len(outputs[i])}, Unique samples: {len(set(outputs[i]))}"
+             )
+
+         results = []
+         for i in range(num_prompts):
+             results.append(len(set(outputs[i])))
+         return results
+
+     else:
+         raise ValueError(f"Invalid test mode: {args.test_mode}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     BenchArgs.add_cli_args(parser)
+     args = parser.parse_args()
+
+     test_deterministic(args)
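
For orientation, the script above can also be driven from Python rather than the CLI; this is how the test class in the next file uses it. A minimal sketch, assuming an sglang server launched with --enable-deterministic-inference is already listening on the default localhost:30000 (the trial count below is illustrative):

    from sglang.test.test_deterministic import BenchArgs, test_deterministic

    # Reuse the dataclass defaults (localhost:30000, temperature 0.0, sampling_seed 42)
    # and only override the mode and trial count for a quick consistency check.
    args = BenchArgs(test_mode="single", n_trials=10)
    unique_counts = test_deterministic(args)

    # With deterministic inference enabled, every trial should return identical text,
    # so each entry in unique_counts is expected to be 1.
    print(unique_counts)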

sglang/test/test_deterministic_utils.py
@@ -0,0 +1,81 @@
+ import time
+ import unittest
+
+ import requests
+
+ from sglang.srt.utils import kill_process_tree
+ from sglang.test.test_deterministic import BenchArgs, test_deterministic
+ from sglang.test.test_utils import (
+     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+     DEFAULT_URL_FOR_TEST,
+     CustomTestCase,
+     popen_launch_server,
+ )
+
+ DEFAULT_MODEL = "Qwen/Qwen3-8B"
+ COMMON_SERVER_ARGS = [
+     "--trust-remote-code",
+     "--cuda-graph-max-bs",
+     "32",
+     "--enable-deterministic-inference",
+ ]
+
+
+ class TestDeterministicBase(CustomTestCase):
+     @classmethod
+     def get_server_args(cls):
+         return COMMON_SERVER_ARGS
+
+     @classmethod
+     def setUpClass(cls):
+         cls.model = DEFAULT_MODEL
+         cls.base_url = DEFAULT_URL_FOR_TEST
+         if "--attention-backend" not in cls.get_server_args():
+             raise unittest.SkipTest("Skip the base test class")
+
+         cls.process = popen_launch_server(
+             cls.model,
+             cls.base_url,
+             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+             other_args=cls.get_server_args(),
+         )
+
+     @classmethod
+     def tearDownClass(cls):
+         kill_process_tree(cls.process.pid)
+
+     def _extract_host_and_port(self, url):
+         return url.split("://")[-1].split(":")[0], int(url.split(":")[-1])
+
+     def test_single(self):
+         args = BenchArgs()
+         url = DEFAULT_URL_FOR_TEST
+         args.host, args.port = self._extract_host_and_port(url)
+         args.test_mode = "single"
+         args.n_start = 10
+         args.n_trials = 20
+         results = test_deterministic(args)
+         for result in results:
+             assert result == 1
+
+     def test_mixed(self):
+         args = BenchArgs()
+         url = DEFAULT_URL_FOR_TEST
+         args.host, args.port = self._extract_host_and_port(url)
+         args.test_mode = "mixed"
+         args.n_start = 10
+         args.n_trials = 20
+         results = test_deterministic(args)
+         for result in results:
+             assert result == 1
+
+     def test_prefix(self):
+         args = BenchArgs()
+         url = DEFAULT_URL_FOR_TEST
+         args.host, args.port = self._extract_host_and_port(url)
+         args.test_mode = "prefix"
+         args.n_start = 10
+         args.n_trials = 10
+         results = test_deterministic(args)
+         for result in results:
+             assert result == 1
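
TestDeterministicBase skips itself unless "--attention-backend" appears in get_server_args(), so concrete test cases are expected to subclass it and pin a backend. A hedged sketch of such a subclass (the backend name is illustrative; use whichever value your deployment passes to --attention-backend):

    class TestDeterministicTritonBackend(TestDeterministicBase):
        @classmethod
        def get_server_args(cls):
            # Extend the shared args with an attention backend so the base
            # class does not raise SkipTest for this case.
            return COMMON_SERVER_ARGS + ["--attention-backend", "triton"]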

sglang/test/test_disaggregation_utils.py
@@ -0,0 +1,140 @@
+ import os
+ import time
+ import warnings
+ from urllib.parse import urlparse
+
+ import requests
+
+ from sglang.srt.environ import envs
+ from sglang.srt.utils import kill_process_tree
+ from sglang.test.test_utils import (
+     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+     DEFAULT_URL_FOR_TEST,
+     CustomTestCase,
+     is_in_ci,
+     popen_with_error_check,
+ )
+
+
+ class TestDisaggregationBase(CustomTestCase):
+     @classmethod
+     def setUpClass(cls):
+         parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+         cls.base_host = parsed_url.hostname
+         base_port = str(parsed_url.port)
+         cls.lb_port = base_port
+         cls.prefill_port = f"{int(base_port) + 100}"
+         cls.decode_port = f"{int(base_port) + 200}"
+         cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+         cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+         cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+         print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+         cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
+
+         # config transfer backend and rdma devices
+         if is_in_ci():
+             cls.transfer_backend = ["--disaggregation-transfer-backend", "mooncake"]
+             cls.rdma_devices = ["--disaggregation-ib-device", get_rdma_devices_args()]
+         else:
+             cls.transfer_backend = [
+                 "--disaggregation-transfer-backend",
+                 envs.SGLANG_TEST_PD_DISAGG_BACKEND.get(),
+             ]
+             cls.rdma_devices = [
+                 "--disaggregation-ib-device",
+                 envs.SGLANG_TEST_PD_DISAGG_DEVICES.get(),
+             ]
+             if cls.rdma_devices[1] is None:
+                 cls.rdma_devices = []
+                 msg = "No RDMA devices specified for disaggregation test, using default settings."
+                 warnings.warn(msg)
+
+     @classmethod
+     def launch_lb(cls):
+         lb_command = [
+             "python3",
+             "-m",
+             "sglang_router.launch_router",
+             "--pd-disaggregation",
+             "--mini-lb",  # FIXME: remove this
+             "--prefill",
+             cls.prefill_url,
+             "--decode",
+             cls.decode_url,
+             "--host",
+             cls.base_host,
+             "--port",
+             cls.lb_port,
+         ]
+         print("Starting load balancer:", " ".join(lb_command))
+         cls.process_lb = popen_with_error_check(lb_command)
+         cls.wait_server_ready(cls.lb_url + "/health")
+
+     @classmethod
+     def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
+         start_time = time.perf_counter()
+         while True:
+             try:
+                 response = requests.get(url)
+                 if response.status_code == 200:
+                     print(f"Server {url} is ready")
+                     return
+             except Exception:
+                 pass
+
+             if time.perf_counter() - start_time > timeout:
+                 raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+             time.sleep(1)
+
+     @classmethod
+     def tearDownClass(cls):
+         for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+             if process:
+                 try:
+                     kill_process_tree(process.pid)
+                 except Exception as e:
+                     print(f"Error killing process {process.pid}: {e}")
+
+         # wait for 5 seconds
+         time.sleep(5)
+
+
+ def get_rdma_devices_args():
+     # 1. Get visible GPU indices
+     cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
+     if not cuda_visible_devices:
+         warnings.warn("CUDA_VISIBLE_DEVICES is not set. Using default RDMA devices.")
+         return "mlx5_roce0,mlx5_roce4"
+
+     try:
+         # Convert to list of integers (handling possible spaces and empty strings)
+         gpu_indices = [
+             int(idx.strip()) for idx in cuda_visible_devices.split(",") if idx.strip()
+         ]
+         if not gpu_indices or len(gpu_indices) > 4:
+             return "mlx5_roce0,mlx5_roce4"
+     except ValueError:
+         warnings.warn(f"Invalid CUDA_VISIBLE_DEVICES format: {cuda_visible_devices}")
+         return "mlx5_roce0,mlx5_roce4"
+
+     # 2. Calculate base RDMA index group (each group of 4 GPUs uses consecutive devices)
+     base_rdma_group = min(gpu_indices) // 4 * 4
+
+     # 3. Generate RDMA device names
+     rdma_devices = []
+     for gpu_idx in gpu_indices:
+         # Validate GPU index within expected range
+         if gpu_idx < base_rdma_group or gpu_idx >= base_rdma_group + 4:
+             warnings.warn(
+                 f"GPU index {gpu_idx} is outside expected group {base_rdma_group}-{base_rdma_group+3}"
+             )
+             continue
+
+         # Map GPU index to RDMA device index
+         rdma_index = base_rdma_group // 4 * 4 + (gpu_idx % 4)
+         rdma_devices.append(f"mlx5_roce{rdma_index}")
+
+     if not rdma_devices:
+         return "mlx5_roce0,mlx5_roce4"
+
+     return ",".join(rdma_devices)