sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -172,6 +172,20 @@ def is_blackwell():
     return torch.cuda.get_device_capability()[0] == 10


+@lru_cache(maxsize=1)
+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
+
+
 _warned_bool_env_var_keys = set()

@@ -216,8 +230,16 @@ except:
     is_intel_amx_backend_available = False


+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+
+
 def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+    return is_amx_tile_supported and is_intel_amx_backend_available


 def use_intel_amx_backend(layer):
@@ -412,7 +434,9 @@ def get_available_gpu_memory(

     elif device == "cpu":
         # TODO: rename the variables in the current function to be not GPU specific
-        free_gpu_memory = psutil.virtual_memory().available
+        total_free_memory = psutil.virtual_memory().available
+        n_numa_node: int = len(get_cpu_ids_by_node())
+        free_gpu_memory = round(total_free_memory / n_numa_node, 3)
     elif device == "npu":
         num_gpus = torch.npu.device_count()
         assert gpu_id < num_gpus
@@ -1665,9 +1689,29 @@ def direct_register_custom_op(
     IMPORTANT: the lifetime of the operator is tied to the lifetime of the
     library object. If you want to bind the operator to a different library,
     make sure the library object is alive when the operator is used.
+
+    Note: This function will silently skip registration if the operator
+    with the same name is already registered to avoid RuntimeError in
+    multi-engine scenarios (e.g., VERL framework).
     """
     import torch.library

+    my_lib = target_lib or sglang_lib
+
+    # Check if operator is already registered to avoid duplicate registration
+    # This is important for scenarios where multiple SGLang engines run in the same process
+    try:
+        # Try to access the operator to see if it's already registered
+        lib_name = my_lib.m.name if hasattr(my_lib.m, "name") else "sglang"
+        if hasattr(torch.ops, lib_name) and hasattr(
+            getattr(torch.ops, lib_name), op_name
+        ):
+            # Operator already exists, skip registration
+            return
+    except (AttributeError, RuntimeError):
+        # Operator doesn't exist, proceed with registration
+        pass
+
     if hasattr(torch.library, "infer_schema"):
         schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
     else:
@@ -1676,11 +1720,22 @@ def direct_register_custom_op(

         schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

-    my_lib = target_lib or sglang_lib
-    my_lib.define(op_name + schema_str)
-    my_lib.impl(op_name, op_func, "CUDA")
-    if fake_impl is not None:
-        my_lib._register_fake(op_name, fake_impl)
+    try:
+        my_lib.define(op_name + schema_str)
+        my_lib.impl(op_name, op_func, "CUDA")
+        if fake_impl is not None:
+            my_lib._register_fake(op_name, fake_impl)
+    except RuntimeError as error:
+        if "Tried to register an operator" in str(error) and "multiple times" in str(error):
+            # Silently ignore duplicate registration errors
+            # This can happen in multi-engine scenarios
+            pass
+        else:
+            # Re-raise other RuntimeErrors
+            raise error
+    except AttributeError as error:
+        # Always re-raise AttributeError as it indicates missing dependencies
+        raise error


 def set_gpu_proc_affinity(
@@ -1919,6 +1974,15 @@ def get_ip() -> str:
     except Exception:
         pass

+    # try using hostname
+    hostname = socket.gethostname()
+    try:
+        ip_addr = socket.gethostbyname(hostname)
+        warnings.warn("using local ip address: {}".format(ip_addr))
+        return ip_addr
+    except Exception:
+        pass
+
     warnings.warn(
         "Failed to get the IP address, using 0.0.0.0 by default."
         "The value can be set by the environment variable"
@@ -2733,6 +2797,10 @@ def lru_cache_frozenset(maxsize=128):
     return decorator


+def get_origin_rid(rid):
+    return rid.split("_", 1)[1] if "_" in rid else rid
+
+
 def apply_module_patch(target_module, target_function, wrappers):
     original_module, original_function = parse_module_path(
         target_module, target_function, False
@@ -2842,6 +2910,18 @@ def mxfp_supported():
     return False


+@lru_cache(maxsize=1)
+def is_gfx95_supported():
+    """
+    Returns whether the current platform supports MX types.
+    """
+    if torch.version.hip:
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        return any(gfx in gcn_arch for gfx in ["gfx95"])
+    else:
+        return False
+
+
 # LoRA-related constants and utilities
 SUPPORTED_LORA_TARGET_MODULES = [
     "q_proj",
@@ -2957,3 +3037,12 @@ def check_cuda_result(raw_output):
         raise Exception(f"CUDA error: {err}")

     return results
+
+
+def numa_bind_to_node(node: int):
+    libnuma = ctypes.CDLL("libnuma.so")
+    if libnuma.numa_available() < 0:
+        raise SystemError("numa not available on this system")
+
+    libnuma.numa_run_on_node(ctypes.c_int(node))
+    libnuma.numa_set_localalloc()
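The helpers added above are plain module-level functions in sglang.srt.utils, so they can be exercised directly. A minimal usage sketch, assuming a CUDA build of PyTorch with a visible GPU for the capability checks and libnuma.so on the host for the NUMA call (illustrative only, not code from the package):

from sglang.srt.utils import is_sm90_supported, is_sm100_supported, numa_bind_to_node

# Capability gates: SM100 (Blackwell) needs CUDA >= 12.8, SM90 (Hopper) needs CUDA >= 12.3.
if is_sm100_supported():
    print("SM100 kernels usable")
elif is_sm90_supported():
    print("SM90 kernels usable")

# Pin the current process to NUMA node 0; raises SystemError if NUMA is unavailable.
numa_bind_to_node(0)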
sglang/srt/weight_sync/utils.py CHANGED
@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor

 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
+from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.utils import MultiprocessingSerializer

sglang/test/attention/test_trtllm_mla_backend.py CHANGED
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
     "v_head_dim": 512,
     "num_kv_heads": 1,
     "layer_id": 0,
+    "tp_q_head_num": 128,
+    "tp_k_head_num": 128,
+    "prefill_head_dim": 192,
+    "prefill_v_head_dim": 128,
 }

 ROPE_BASE = 10000
@@ -92,7 +96,7 @@
             "description": "Medium-scale batch",
         },
     ],
-    "decode_output_match": [
+    "output_match": [
         {
             "name": "single_fp16",
             "batch_size": 1,
@@ -208,6 +212,15 @@ class MockModelRunner:
         self.kv_cache_dtype = config["kv_cache_dtype"]
         self.page_size = config["page_size"]

+        # Server args stub - needed by attention backends
+        self.server_args = type(
+            "ServerArgs",
+            (),
+            {
+                "enable_dp_attention": False,  # Default value for testing
+            },
+        )
+
         # Model-config stub with MLA attributes
         self.model_config = type(
             "ModelConfig",
@@ -313,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config.update(test_case)
         return config

-    def _create_model_components(self, config):
+    def _create_model_components(self, config, is_prefill=False):
         """Create model runners, backends, and layer for testing."""
         # Create model runners
         model_runner_trtllm = MockModelRunner(config)
@@ -323,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
         trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
         reference_backend = FlashInferMLAAttnBackend(model_runner_reference)

+        head_dim = (
+            config["kv_lora_rank"] + config["qk_rope_head_dim"]
+            if not is_prefill
+            else config["prefill_head_dim"]
+        )
+        v_head_dim = (
+            config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
+        )
+
         # Create RadixAttention layer
         layer = RadixAttention(
             num_heads=config["num_attention_heads"],
-            head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
+            head_dim=head_dim,
             scaling=model_runner_trtllm.model_config.scaling,
             num_kv_heads=config["num_kv_heads"],
             layer_id=config["layer_id"],
-            v_head_dim=config["v_head_dim"],
+            v_head_dim=v_head_dim,
             prefix="attn_mqa",
         )

@@ -515,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
         """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
         print(f"\nRunning decode output matching tests...")

-        for test_case in TEST_CASES["decode_output_match"]:
+        for test_case in TEST_CASES["output_match"]:
             with self.subTest(test_case=test_case["name"]):
                 print(f" Testing {test_case['name']}: {test_case['description']}")

@@ -833,7 +855,7 @@ class TestTRTLLMMLA(CustomTestCase):

         # Test workspace properties
         self.assertEqual(metadata.workspace.device.type, "cuda")
-        self.assertEqual(metadata.workspace.dtype, torch.int8)
+        self.assertEqual(metadata.workspace.dtype, torch.uint8)
         self.assertGreater(
             metadata.workspace.numel(), 0, "Workspace should have non-zero size"
         )
@@ -993,8 +1015,8 @@ class TestTRTLLMMLA(CustomTestCase):
         )

         # Verify CUDA graph buffers are allocated
-        self.assertIsNotNone(backend.cuda_graph_kv_indices)
-        self.assertIsNotNone(backend.cuda_graph_workspace)
+        self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
+        self.assertIsNotNone(backend.decode_cuda_graph_workspace)

         # Test capture metadata
         seq_lens = torch.full(
@@ -1090,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
         self.assertIsNotNone(metadata_3.block_kv_indices)
         self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])

+    def test_prefill_output_match_self_attention(self):
+        """Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
+        print(f"\nRunning prefill output tests...")
+
+        for test_case in TEST_CASES["output_match"][:2]:  # Just a subset for speed
+            with self.subTest(test_case=test_case["name"]):
+                print(
+                    f"Prefill Testing {test_case['name']}: {test_case['description']}"
+                )
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                (
+                    model_runner_trtllm,
+                    model_runner_reference,
+                    trtllm_backend,
+                    reference_backend,
+                    layer,
+                ) = self._create_model_components(config, is_prefill=True)
+
+                # Prefill uses full sequences
+                seq_lens = torch.full(
+                    (batch_size,), max_seq_len, device=config["device"]
+                )
+
+                def _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens,
+                    extend_prefix_lens,
+                    backend,
+                    model_runner,
+                    config,
+                ):
+                    """Create a forward batch for the given backend."""
+
+                    fb = ForwardBatch(
+                        batch_size=batch_size,
+                        input_ids=torch.randint(
+                            0, 100, (batch_size, 1), device=config["device"]
+                        ),
+                        out_cache_loc=torch.arange(batch_size, device=config["device"]),
+                        seq_lens_sum=int(seq_lens.sum().item()),
+                        extend_prefix_lens=extend_prefix_lens,
+                        extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
+                        extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
+                        .cpu()
+                        .int()
+                        .tolist(),
+                        forward_mode=ForwardMode.EXTEND,
+                        req_pool_indices=torch.arange(
+                            batch_size, device=config["device"]
+                        ),
+                        seq_lens=seq_lens,
+                        seq_lens_cpu=seq_lens.cpu(),
+                        attn_attend_prefix_cache=False,
+                        mha_return_lse=False,
+                        attn_backend=backend,
+                    )
+                    fb.req_to_token_pool = model_runner.req_to_token_pool
+                    fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+                    # Add position information for RoPE
+                    fb.positions = torch.arange(batch_size, device=config["device"])
+
+                    return fb
+
+                # Create forward batches
+                fb_trtllm = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    trtllm_backend,
+                    model_runner_trtllm,
+                    config,
+                )
+                fb_reference = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    reference_backend,
+                    model_runner_reference,
+                    config,
+                )
+
+                # Initialize metadata for both backends
+                trtllm_backend.init_forward_metadata(fb_trtllm)
+                reference_backend.init_forward_metadata(fb_reference)
+
+                # Create Q, K, V tensors for prefill
+                torch.manual_seed(config["seed_qkv"])
+
+                def _create_qkv_tensors_prefill(
+                    batch_size, seq_len, config, dtype_override=None
+                ):
+                    """Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
+                    device = config["device"]
+                    dtype = dtype_override or config["dtype"]
+
+                    total_tokens = batch_size * seq_len
+
+                    tp_q_head_num = config["tp_q_head_num"]
+                    tp_k_head_num = config["tp_k_head_num"]
+                    head_dim = config["prefill_head_dim"]
+                    v_head_dim = config["prefill_v_head_dim"]
+
+                    q = torch.randn(
+                        (total_tokens, tp_q_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    k = torch.randn(
+                        (total_tokens, tp_k_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    v = torch.randn(
+                        (total_tokens, tp_k_head_num * v_head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+
+                    # Reshape as requested
+                    q = q.view(-1, tp_q_head_num, head_dim)
+                    k = k.view(-1, tp_k_head_num, head_dim)
+                    v = v.view(-1, tp_k_head_num, v_head_dim)
+
+                    return q, k, v
+
+                q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
+                # Run prefill on both backends
+                out_trtllm = trtllm_backend.forward_extend(
+                    q, k, v, layer, fb_trtllm, False
+                ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
+                out_reference = reference_backend.forward_extend(
+                    q, k, v, layer, fb_reference, False
+                )
+
+                tolerance = config.get("tolerance", 1e-2)
+                comparison_passed = compare_outputs(
+                    out_trtllm, out_reference, tolerance=tolerance
+                )
+                self.assertTrue(
+                    comparison_passed,
+                    f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
+                    f"Config: {test_case['name']}, "
+                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
+                )
+

 if __name__ == "__main__":
     unittest.main()
sglang/test/few_shot_gsm8k.py CHANGED
@@ -129,6 +129,7 @@ def run_eval(args):

     return {
         "accuracy": acc,
+        "invalid": invalid,
         "latency": latency,
         "output_throughput": output_throughput,
     }
sglang/test/runners.py CHANGED
@@ -505,6 +505,7 @@ class SRTRunner:
         mem_fraction_static: float = 0.65,
         trust_remote_code: bool = False,
         speculative_draft_model_path: Optional[str] = None,
+        speculative_draft_model_revision: Optional[str] = None,
         speculative_algorithm: Optional[str] = None,
         speculative_num_steps: Optional[int] = None,
         speculative_eagle_topk: Optional[int] = None,
@@ -526,6 +527,9 @@ class SRTRunner:
         spec_kwargs = {}
         if speculative_draft_model_path:
             spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
+            spec_kwargs["speculative_draft_model_revision"] = (
+                speculative_draft_model_revision
+            )
             spec_kwargs["speculative_algorithm"] = speculative_algorithm
             spec_kwargs["speculative_num_steps"] = speculative_num_steps
             spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
sglang/test/test_cutlass_moe.py CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput


 # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
@@ -21,7 +22,7 @@ def calc_diff(x, y):


 def get_model_config(tp_size: int):
     config = AutoConfig.from_pretrained(
-        "deepseek-ai/deepseek-R1", trust_remote_code=True
+        "deepseek-ai/Deepseek-R1", trust_remote_code=True
     )
     E = config.n_routed_experts
     topk = config.num_experts_per_tok
@@ -152,14 +153,31 @@ def run_test(tp_size, batch_size, model_config, check=False):
         problem_sizes2,
     )

+    topk_output = StandardTopKOutput(
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        router_logits=torch.randn(
+            (batch_size, topk), device=topk_weights.device, dtype=dtype
+        ),
+    )
+
+    moe_runner_config = MoeRunnerConfig(
+        num_experts=E,
+        top_k=topk,
+        hidden_size=H,
+        intermediate_size_per_partition=I,
+        params_dtype=dtype,
+        activation="silu",
+        inplace=False,
+    )
+
     # Note: Triton expects non-transposed weights
-    moe_config = MoeRunnerConfig(inplace=False)
     triton_lambda = lambda: fused_experts(
         x,
         w1,
         w2,
-        (topk_weights, topk_ids, "dummy"),
-        moe_config,
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
@@ -224,8 +242,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
         x,
         w1,  # Original shape
         w2,  # Original shape
-        (topk_weights, topk_ids, "dummy"),
-        moe_config,
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
sglang/test/test_cutlass_w4a8_moe.py CHANGED
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Optional
+from typing import Literal, Optional

 import pytest
 import torch
@@ -25,7 +25,7 @@ def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Ten
     return packed_tensor.to(torch.int8)


-def pack_interleave(num_experts, ref_weight, ref_scale):
+def pack_interleave(num_experts, ref_weight, ref_scale, alignment=4):
     n, k = ref_weight.shape[1], ref_weight.shape[2]

     weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
@@ -33,11 +33,16 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
     w_q = w_q.contiguous()

     scale_interleaved = ref_scale.reshape(
-        ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
+        ref_scale.shape[0],
+        ref_scale.shape[1],
+        (ref_scale.shape[2] // alignment),
+        alignment,
     )  # [E, N, K/4, 4]
     scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
     scale_interleaved = scale_interleaved.reshape(
-        ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
+        ref_scale.shape[0],
+        ref_scale.shape[2] // alignment,
+        ref_scale.shape[1] * alignment,
     )  # [E, K/4, N*4]
     w_scale = scale_interleaved.contiguous()

@@ -48,12 +53,17 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
 @pytest.mark.parametrize("N", [2048])
 @pytest.mark.parametrize("K", [7168])
 @pytest.mark.parametrize("E", [256])
-@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.parametrize("tp_size", [8])
+@pytest.mark.parametrize("use_ep_moe", [True, False])
 @pytest.mark.parametrize("topk", [8])
 @pytest.mark.parametrize("group_size", [128])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
-    local_e = E // ep_size
+def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dtype):
+    if use_ep_moe:
+        local_e = E // tp_size
+    else:  # tp mode
+        local_e = E
+        N = N // tp_size

     debug = False
     if debug:
@@ -87,7 +97,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
     )

     w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
-    w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+    if use_ep_moe:
+        w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+    else:
+        w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2, 1)

     device = "cuda"
     a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
@@ -265,7 +278,9 @@ def ref(

         gate, fc1 = fc1.chunk(2, dim=-1)
         fc1 = fc1 * torch.nn.functional.silu(gate)
-        act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
+        act = torch.clamp((fc1 / pre_quant_scale_2.float()), -448.0, 448.0).to(
+            torch.float8_e4m3fn
+        )
         act = act.to(dtype)

         w2 = ref_weight_2[e_idx]
sglang/test/test_disaggregation_utils.py ADDED
@@ -0,0 +1,66 @@
+import time
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    CustomTestCase,
+    popen_with_error_check,
+)
+
+
+class TestDisaggregationBase(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
+        pass
+
+    @classmethod
+    def launch_lb(cls):
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang_router.launch_router",
+            "--pd-disaggregation",
+            "--mini-lb",  # FIXME: remove this
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            cls.lb_port,
+        ]
+        print("Starting load balancer:", " ".join(lb_command))
+        cls.process_lb = popen_with_error_check(lb_command)
+        cls.wait_server_ready(cls.lb_url + "/health")
+
+    @classmethod
+    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+
+        # wait for 5 seconds
+        time.sleep(5)
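The base class above leaves the endpoint attributes (prefill_url, decode_url, base_host, lb_port, lb_url) to be set by concrete test classes before launch_lb is called. A hypothetical subclass sketch; the class name, hosts, and ports are placeholders for illustration only, while the attribute names follow the code above:

# Illustrative only: shows how a test might build on TestDisaggregationBase.
from sglang.test.test_disaggregation_utils import TestDisaggregationBase


class TestMyPDDeployment(TestDisaggregationBase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.base_host = "127.0.0.1"
        cls.lb_port = "8000"
        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
        cls.prefill_url = "http://127.0.0.1:30000"  # prefill server assumed already running
        cls.decode_url = "http://127.0.0.1:30001"  # decode server assumed already running
        cls.launch_lb()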