sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,452 @@
1
+ # https://raw.githubusercontent.com/ROCm/rocmProfileData/refs/heads/master/tools/rpd2tracing.py
2
+ # commit 92d13a08328625463e9ba944cece82fc5eea36e6
3
+ def rpd_to_chrome_trace(
4
+ input_rpd, output_json=None, start="0%", end="100%", format="object"
5
+ ):
6
+ import gzip
7
+ import sqlite3
8
+
9
+ if output_json is None:
10
+ import pathlib
11
+
12
+ output_json = pathlib.PurePath(input_rpd).with_suffix(".trace.json.gz")
13
+
14
+ connection = sqlite3.connect(input_rpd)
15
+
16
+ outfile = gzip.open(output_json, "wt", encoding="utf-8")
17
+
18
+ if format == "object":
19
+ outfile.write('{"traceEvents": ')
20
+
21
+ outfile.write("[ {}\n")
22
+
23
+ for row in connection.execute("select distinct gpuId from rocpd_op"):
24
+ try:
25
+ outfile.write(
26
+ ',{"name": "process_name", "ph": "M", "pid":"%s","args":{"name":"%s"}}\n'
27
+ % (row[0], "GPU" + str(row[0]))
28
+ )
29
+ outfile.write(
30
+ ',{"name": "process_sort_index", "ph": "M", "pid":"%s","args":{"sort_index":"%s"}}\n'
31
+ % (row[0], row[0] + 1000000)
32
+ )
33
+ except ValueError:
34
+ outfile.write("")
35
+
36
+ for row in connection.execute("select distinct pid, tid from rocpd_api"):
37
+ try:
38
+ outfile.write(
39
+ ',{"name":"thread_name","ph":"M","pid":"%s","tid":"%s","args":{"name":"%s"}}\n'
40
+ % (row[0], row[1], "Hip " + str(row[1]))
41
+ )
42
+ outfile.write(
43
+ ',{"name":"thread_sort_index","ph":"M","pid":"%s","tid":"%s","args":{"sort_index":"%s"}}\n'
44
+ % (row[0], row[1], row[1] * 2)
45
+ )
46
+ except ValueError:
47
+ outfile.write("")
48
+
49
+ try:
50
+ # FIXME - these aren't rendering correctly in chrome://tracing
51
+ for row in connection.execute("select distinct pid, tid from rocpd_hsaApi"):
52
+ try:
53
+ outfile.write(
54
+ ',{"name":"thread_name","ph":"M","pid":"%s","tid":"%s","args":{"name":"%s"}}\n'
55
+ % (row[0], row[1], "HSA " + str(row[1]))
56
+ )
57
+ outfile.write(
58
+ ',{"name":"thread_sort_index","ph":"M","pid":"%s","tid":"%s","args":{"sort_index":"%s"}}\n'
59
+ % (row[0], row[1], row[1] * 2 - 1)
60
+ )
61
+ except ValueError:
62
+ outfile.write("")
63
+ except:
64
+ pass
65
+
66
+ rangeStringApi = ""
67
+ rangeStringOp = ""
68
+ rangeStringMonitor = ""
69
+ min_time = connection.execute("select MIN(start) from rocpd_api;").fetchall()[0][0]
70
+ max_time = connection.execute("select MAX(end) from rocpd_api;").fetchall()[0][0]
71
+ if min_time == None:
72
+ raise Exception("Trace file is empty.")
73
+
74
+ print("Timestamps:")
75
+ print(f"\t first: \t{min_time/1000} us")
76
+ print(f"\t last: \t{max_time/1000} us")
77
+ print(f"\t duration: \t{(max_time-min_time) / 1000000000} seconds")
78
+
79
+ start_time = min_time / 1000
80
+ end_time = max_time / 1000
81
+
82
+ if start:
83
+ if "%" in start:
84
+ start_time = (
85
+ (max_time - min_time) * (int(start.replace("%", "")) / 100) + min_time
86
+ ) / 1000
87
+ else:
88
+ start_time = int(start)
89
+ rangeStringApi = "where rocpd_api.start/1000 >= %s" % (start_time)
90
+ rangeStringOp = "where rocpd_op.start/1000 >= %s" % (start_time)
91
+ rangeStringMonitor = "where start/1000 >= %s" % (start_time)
92
+ if end:
93
+ if "%" in end:
94
+ end_time = (
95
+ (max_time - min_time) * (int(end.replace("%", "")) / 100) + min_time
96
+ ) / 1000
97
+ else:
98
+ end_time = int(end)
99
+
100
+ rangeStringApi = (
101
+ rangeStringApi + " and rocpd_api.start/1000 <= %s" % (end_time)
102
+ if start != None
103
+ else "where rocpd_api.start/1000 <= %s" % (end_time)
104
+ )
105
+ rangeStringOp = (
106
+ rangeStringOp + " and rocpd_op.start/1000 <= %s" % (end_time)
107
+ if start != None
108
+ else "where rocpd_op.start/1000 <= %s" % (end_time)
109
+ )
110
+ rangeStringMonitor = (
111
+ rangeStringMonitor + " and start/1000 <= %s" % (end_time)
112
+ if start != None
113
+ else "where start/1000 <= %s" % (end_time)
114
+ )
115
+
116
+ print("\nFilter: %s" % (rangeStringApi))
117
+ print(f"Output duration: {(end_time-start_time)/1000000} seconds")
118
+
119
+ # Output Ops
120
+
121
+ for row in connection.execute(
122
+ "select A.string as optype, B.string as description, gpuId, queueId, rocpd_op.start/1000.0, (rocpd_op.end-rocpd_op.start) / 1000.0 from rocpd_op INNER JOIN rocpd_string A on A.id = rocpd_op.opType_id INNER Join rocpd_string B on B.id = rocpd_op.description_id %s"
123
+ % (rangeStringOp)
124
+ ):
125
+ try:
126
+ name = row[0] if len(row[1]) == 0 else row[1]
127
+ outfile.write(
128
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
129
+ % (row[2], row[3], name, row[4], row[5], row[0])
130
+ )
131
+ except ValueError:
132
+ outfile.write("")
133
+
134
+ # Output Graph executions on GPU
135
+ try:
136
+ for row in connection.execute(
137
+ "select graphExec, gpuId, queueId, min(start)/1000.0, (max(end)-min(start))/1000.0, count(*) from rocpd_graphLaunchapi A join rocpd_api_ops B on B.api_id = A.api_ptr_id join rocpd_op C on C.id = B.op_id %s group by api_ptr_id"
138
+ % (rangeStringMonitor)
139
+ ):
140
+ try:
141
+ outfile.write(
142
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"kernels":"%s"}}\n'
143
+ % (row[1], row[2], f"Graph {row[0]}", row[3], row[4], row[5])
144
+ )
145
+ except ValueError:
146
+ outfile.write("")
147
+ except:
148
+ pass
149
+
150
+ # Output apis
151
+ for row in connection.execute(
152
+ "select A.string as apiName, B.string as args, pid, tid, rocpd_api.start/1000.0, (rocpd_api.end-rocpd_api.start) / 1000.0, (rocpd_api.end != rocpd_api.start) as has_duration from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id INNER Join rocpd_string B on B.id = rocpd_api.args_id %s order by rocpd_api.id"
153
+ % (rangeStringApi)
154
+ ):
155
+ try:
156
+ if row[0] == "UserMarker":
157
+ if row[6] == 0: # instantanuous "mark" messages
158
+ outfile.write(
159
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","ph":"i","s":"p","args":{"desc":"%s"}}\n'
160
+ % (
161
+ row[2],
162
+ row[3],
163
+ row[1].replace('"', ""),
164
+ row[4],
165
+ row[1].replace('"', ""),
166
+ )
167
+ )
168
+ else:
169
+ outfile.write(
170
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
171
+ % (
172
+ row[2],
173
+ row[3],
174
+ row[1].replace('"', ""),
175
+ row[4],
176
+ row[5],
177
+ row[1].replace('"', ""),
178
+ )
179
+ )
180
+ else:
181
+ outfile.write(
182
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
183
+ % (
184
+ row[2],
185
+ row[3],
186
+ row[0],
187
+ row[4],
188
+ row[5],
189
+ row[1].replace('"', "").replace("\t", ""),
190
+ )
191
+ )
192
+ except ValueError:
193
+ outfile.write("")
194
+
195
+ # Output api->op linkage
196
+ for row in connection.execute(
197
+ "select rocpd_api_ops.id, pid, tid, gpuId, queueId, rocpd_api.end/1000.0 - 2, rocpd_op.start/1000.0 from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id %s"
198
+ % (rangeStringApi)
199
+ ):
200
+ try:
201
+ fromtime = row[5] if row[5] < row[6] else row[6]
202
+ outfile.write(
203
+ ',{"pid":"%s","tid":"%s","cat":"api_op","name":"api_op","ts":"%s","id":"%s","ph":"s"}\n'
204
+ % (row[1], row[2], fromtime, row[0])
205
+ )
206
+ outfile.write(
207
+ ',{"pid":"%s","tid":"%s","cat":"api_op","name":"api_op","ts":"%s","id":"%s","ph":"f", "bp":"e"}\n'
208
+ % (row[3], row[4], row[6], row[0])
209
+ )
210
+ except ValueError:
211
+ outfile.write("")
212
+
213
+ try:
214
+ for row in connection.execute(
215
+ "select A.string as apiName, B.string as args, pid, tid, rocpd_hsaApi.start/1000.0, (rocpd_hsaApi.end-rocpd_hsaApi.start) / 1000.0 from rocpd_hsaApi INNER JOIN rocpd_string A on A.id = rocpd_hsaApi.apiName_id INNER Join rocpd_string B on B.id = rocpd_hsaApi.args_id %s order by rocpd_hsaApi.id"
216
+ % (rangeStringApi)
217
+ ):
218
+ try:
219
+ outfile.write(
220
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
221
+ % (
222
+ row[2],
223
+ row[3] + 1,
224
+ row[0],
225
+ row[4],
226
+ row[5],
227
+ row[1].replace('"', ""),
228
+ )
229
+ )
230
+ except ValueError:
231
+ outfile.write("")
232
+ except:
233
+ pass
234
+
235
+ #
236
+ # Counters
237
+ #
238
+
239
+ # Counters should extend to the last event in the trace. This means they need to have a value at Tend.
240
+ # Figure out when that is
241
+
242
+ T_end = 0
243
+ for row in connection.execute(
244
+ "SELECT max(end)/1000 from (SELECT end from rocpd_api UNION ALL SELECT end from rocpd_op)"
245
+ ):
246
+ T_end = int(row[0])
247
+ if end:
248
+ T_end = end_time
249
+
250
+ # Loop over GPU for per-gpu counters
251
+ gpuIdsPresent = []
252
+ for row in connection.execute("SELECT DISTINCT gpuId FROM rocpd_op"):
253
+ gpuIdsPresent.append(row[0])
254
+
255
+ for gpuId in gpuIdsPresent:
256
+ # print(f"Creating counters for: {gpuId}")
257
+
258
+ # Create the queue depth counter
259
+ depth = 0
260
+ idle = 1
261
+ for row in connection.execute(
262
+ 'select * from (select rocpd_api.start/1000.0 as ts, "1" from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id AND rocpd_op.gpuId = %s %s UNION ALL select rocpd_op.end/1000.0, "-1" from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id AND rocpd_op.gpuId = %s %s) order by ts'
263
+ % (gpuId, rangeStringOp, gpuId, rangeStringOp)
264
+ ):
265
+ try:
266
+ if idle and int(row[1]) > 0:
267
+ idle = 0
268
+ outfile.write(
269
+ ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
270
+ % (gpuId, row[0], idle)
271
+ )
272
+ if depth == 1 and int(row[1]) < 0:
273
+ idle = 1
274
+ outfile.write(
275
+ ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
276
+ % (gpuId, row[0], idle)
277
+ )
278
+ depth = depth + int(row[1])
279
+ outfile.write(
280
+ ',{"pid":"%s","name":"QueueDepth","ph":"C","ts":%s,"args":{"depth":%s}}\n'
281
+ % (gpuId, row[0], depth)
282
+ )
283
+ except ValueError:
284
+ outfile.write("")
285
+ if T_end > 0:
286
+ outfile.write(
287
+ ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
288
+ % (gpuId, T_end, idle)
289
+ )
290
+ outfile.write(
291
+ ',{"pid":"%s","name":"QueueDepth","ph":"C","ts":%s,"args":{"depth":%s}}\n'
292
+ % (gpuId, T_end, depth)
293
+ )
294
+
295
+ # Create SMI counters
296
+ try:
297
+ for row in connection.execute(
298
+ "select deviceId, monitorType, start/1000.0, value from rocpd_monitor %s"
299
+ % (rangeStringMonitor)
300
+ ):
301
+ outfile.write(
302
+ ',{"pid":"%s","name":"%s","ph":"C","ts":%s,"args":{"%s":%s}}\n'
303
+ % (row[0], row[1], row[2], row[1], row[3])
304
+ )
305
+ # Output the endpoints of the last range
306
+ for row in connection.execute(
307
+ "select distinct deviceId, monitorType, max(end)/1000.0, value from rocpd_monitor %s group by deviceId, monitorType"
308
+ % (rangeStringMonitor)
309
+ ):
310
+ outfile.write(
311
+ ',{"pid":"%s","name":"%s","ph":"C","ts":%s,"args":{"%s":%s}}\n'
312
+ % (row[0], row[1], row[2], row[1], row[3])
313
+ )
314
+ except:
315
+ print("Did not find SMI data")
316
+
317
+ # Create the (global) memory counter
318
+ """
319
+ sizes = {} # address -> size
320
+ totalSize = 0
321
+ exp = re.compile("^ptr\((.*)\)\s+size\((.*)\)$")
322
+ exp2 = re.compile("^ptr\((.*)\)$")
323
+ for row in connection.execute("SELECT rocpd_api.end/1000.0 as ts, B.string, '1' FROM rocpd_api INNER JOIN rocpd_string A ON A.id=rocpd_api.apiName_id INNER JOIN rocpd_string B ON B.id=rocpd_api.args_id WHERE A.string='hipFree' UNION ALL SELECT rocpd_api.start/1000.0, B.string, '0' FROM rocpd_api INNER JOIN rocpd_string A ON A.id=rocpd_api.apiName_id INNER JOIN rocpd_string B ON B.id=rocpd_api.args_id WHERE A.string='hipMalloc' ORDER BY ts asc"):
324
+ try:
325
+ if row[2] == '0': #malloc
326
+ m = exp.match(row[1])
327
+ if m:
328
+ size = int(m.group(2), 16)
329
+ totalSize = totalSize + size
330
+ sizes[m.group(1)] = size
331
+ outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(row[0],totalSize))
332
+ else: #free
333
+ m = exp2.match(row[1])
334
+ if m:
335
+ try: # Sometimes free addresses are not valid or listed
336
+ size = sizes[m.group(1)]
337
+ sizes[m.group(1)] = 0
338
+ totalSize = totalSize - size;
339
+ outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(row[0],totalSize))
340
+ except KeyError:
341
+ pass
342
+ except ValueError:
343
+ outfile.write("")
344
+ if T_end > 0:
345
+ outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(T_end,totalSize))
346
+ """
347
+
348
+ # Create "faux calling stack frame" on gpu ops traceS
349
+ stacks = {} # Call stacks built from UserMarker entres. Key is 'pid,tid'
350
+ currentFrame = {} # "Current GPU frame" (id, name, start, end). Key is 'pid,tid'
351
+
352
class GpuFrame:
    """Merged GPU span for the ops issued under one UserMarker frame.

    The conversion loop fills these fields in:
      * ``id`` / ``name``: the enclosing marker frame's start timestamp and label
        (taken from the ``(row[1], row[4])`` tuple pushed on the marker stack).
      * ``start`` / ``end``: bounds of the merged op span (rocpd_op start/end / 1000).
      * ``gpus``: distinct ``(gpuId, queueId)`` pairs the merged ops ran on.
      * ``totalOps``: number of ops merged into this frame so far.
    """

    def __init__(self):
        self.id, self.name = 0, ""
        self.start = self.end = 0
        # Fresh list per instance; frames must not share their destination list.
        self.gpus = []
        self.totalOps = 0
360
+
361
+ # FIXME: include 'start' (in ns) so we can ORDER BY it and break ties?
362
+ for row in connection.execute(
363
+ "SELECT '0', start/1000.0, pid, tid, B.string as label, '','','', '' from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id AND A.string = 'UserMarker' INNER JOIN rocpd_string B on B.id = rocpd_api.args_id AND rocpd_api.start/1000.0 != rocpd_api.end/1000.0 %s UNION ALL SELECT '1', end/1000.0, pid, tid, B.string as label, '','','', '' from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id AND A.string = 'UserMarker' INNER JOIN rocpd_string B on B.id = rocpd_api.args_id AND rocpd_api.start/1000.0 != rocpd_api.end/1000.0 %s UNION ALL SELECT '2', rocpd_api.start/1000.0, pid, tid, '' as label, gpuId, queueId, rocpd_op.start/1000.0, rocpd_op.end/1000.0 from rocpd_api_ops INNER JOIN rocpd_api ON rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op ON rocpd_api_ops.op_id = rocpd_op.id %s ORDER BY start/1000.0 asc"
364
+ % (rangeStringApi, rangeStringApi, rangeStringApi)
365
+ ):
366
+ try:
367
+ key = (row[2], row[3]) # Key is 'pid,tid'
368
+ if row[0] == "0": # Frame start
369
+ if key not in stacks:
370
+ stacks[key] = []
371
+ stack = stacks[key].append((row[1], row[4]))
372
+ # print(f"0: new api frame: pid_tid={key} -> stack={stacks}")
373
+
374
+ elif row[0] == "1": # Frame end
375
+ completed = stacks[key].pop()
376
+ # print(f"1: end api frame: pid_tid={key} -> stack={stacks}")
377
+
378
+ elif row[0] == "2": # API + Op
379
+ if key in stacks and len(stacks[key]) > 0:
380
+ frame = stacks[key][-1]
381
+ # print(f"2: Op on {frame} ({len(stacks[key])})")
382
+ gpuFrame = None
383
+ if key not in currentFrame: # First op under the current api frame
384
+ gpuFrame = GpuFrame()
385
+ gpuFrame.id = frame[0]
386
+ gpuFrame.name = frame[1]
387
+ gpuFrame.start = row[7]
388
+ gpuFrame.end = row[8]
389
+ gpuFrame.gpus.append((row[5], row[6]))
390
+ gpuFrame.totalOps = 1
391
+ # print(f"2a: new frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
392
+ else:
393
+ gpuFrame = currentFrame[key]
394
+ # Another op under the same frame -> union them (but only if they are butt together)
395
+ if (
396
+ gpuFrame.id == frame[0]
397
+ and gpuFrame.name == frame[1]
398
+ and (
399
+ abs(row[7] - gpuFrame.end) < 200
400
+ or abs(gpuFrame.start - row[8]) < 200
401
+ )
402
+ ):
403
+ # if gpuFrame.id == frame[0] and gpuFrame.name == frame[1]: # Another op under the same frame -> union them
404
+ # if False: # Turn off frame joining
405
+ if row[7] < gpuFrame.start:
406
+ gpuFrame.start = row[7]
407
+ if row[8] > gpuFrame.end:
408
+ gpuFrame.end = row[8]
409
+ if (row[5], row[6]) not in gpuFrame.gpus:
410
+ gpuFrame.gpus.append((row[5], row[6]))
411
+ gpuFrame.totalOps = gpuFrame.totalOps + 1
412
+ # print(f"2c: union frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
413
+
414
+ else: # This is a new frame - dump the last and make new
415
+ gpuFrame = currentFrame[key]
416
+ for dest in gpuFrame.gpus:
417
+ # print(f"2: OUTPUT: dest={dest} time={gpuFrame.start} -> {gpuFrame.end} Duration={gpuFrame.end - gpuFrame.start} TotalOps={gpuFrame.totalOps}")
418
+ outfile.write(
419
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
420
+ % (
421
+ dest[0],
422
+ dest[1],
423
+ gpuFrame.name.replace('"', ""),
424
+ gpuFrame.start - 1,
425
+ gpuFrame.end - gpuFrame.start + 1,
426
+ f"UserMarker frame: {gpuFrame.totalOps} ops",
427
+ )
428
+ )
429
+ currentFrame.pop(key)
430
+
431
+ # make the first op under the new frame
432
+ gpuFrame = GpuFrame()
433
+ gpuFrame.id = frame[0]
434
+ gpuFrame.name = frame[1]
435
+ gpuFrame.start = row[7]
436
+ gpuFrame.end = row[8]
437
+ gpuFrame.gpus.append((row[5], row[6]))
438
+ gpuFrame.totalOps = 1
439
+ # print(f"2b: new frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
440
+
441
+ currentFrame[key] = gpuFrame
442
+
443
+ except ValueError:
444
+ outfile.write("")
445
+
446
+ outfile.write("]\n")
447
+
448
+ if format == "object":
449
+ outfile.write("} \n")
450
+
451
+ outfile.close()
452
+ connection.close()
@@ -0,0 +1,71 @@
1
+ import logging
2
+ from typing import Any, Dict, List
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ import triton
7
+
8
# Module-level logger; all slow-rank-detector messages go through it.
logger = logging.getLogger(name=__name__)
9
+
10
+
11
def execute():
    """Benchmark every rank and report straggler ranks from rank 0.

    Each rank runs every benchmark listed in ``_BENCH_NAMES`` and records the
    measured time. The per-rank result dicts are gathered onto rank 0, which
    then analyzes them. This performs a ``gather_object`` collective on the
    default process group, so every rank must call it.
    """
    is_rank0 = dist.get_rank() == 0
    if is_rank0:
        logger.info("[slow_rank_detector] Start benchmarking...")

    local_metrics = {}
    for bench_name in _BENCH_NAMES:
        local_metrics[bench_name] = _compute_local_metric(bench_name)

    # gather_object only needs a receive list on the destination rank (0).
    gathered = [None] * dist.get_world_size() if is_rank0 else None
    dist.gather_object(local_metrics, gathered)

    if is_rank0:
        _analyze_metrics(gathered)
24
+
25
+
26
class _GemmExecutor:
    """Benchmark workload: one dense square matrix multiplication per call.

    Args:
        size: Side length of both square operands. The default (8192) is the
            original hard-coded workload.
        dtype: Element dtype of the operands (default ``torch.bfloat16``).
        device: Device holding the operands (default ``"cuda"``).
    """

    def __init__(
        self,
        size: int = 8192,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
    ):
        self.lhs = torch.randn((size, size), dtype=dtype, device=device)
        self.rhs = torch.randn((size, size), dtype=dtype, device=device)

    def __call__(self):
        # The product is deliberately discarded: the caller only times this call.
        self.lhs @ self.rhs
33
+
34
+
35
class _ElementwiseExecutor:
    """Benchmark workload: in-place increment of a large int32 vector per call.

    Args:
        numel: Number of elements in the vector. The default
            (``128 * 1024**2``) is the original hard-coded workload.
        device: Device holding the vector (default ``"cuda"``).
    """

    def __init__(self, numel: int = 128 * 1024**2, device: str = "cuda"):
        self.value = torch.randint(
            0, 10000, (numel,), dtype=torch.int32, device=device
        )

    def __call__(self):
        # In-place update, so no new tensor is allocated per call.
        self.value += 1
43
+
44
+
45
# Registry: benchmark name -> class that builds (and, when called, runs) its workload.
_EXECUTOR_CLS_OF_BENCH = dict(
    gemm=_GemmExecutor,
    elementwise=_ElementwiseExecutor,
)

# Benchmark names in registry insertion order; execute() runs them in this
# order on every rank.
_BENCH_NAMES = [*_EXECUTOR_CLS_OF_BENCH]
51
+
52
+
53
def _compute_local_metric(bench_name):
    """Time one registered benchmark on this rank.

    Builds the workload registered under ``bench_name`` and measures it with
    Triton's CUDA-graph-based benchmarker.

    Returns:
        The mean time per call in milliseconds, as reported by
        ``triton.testing.do_bench_cudagraph``.
    """
    executor_cls = _EXECUTOR_CLS_OF_BENCH[bench_name]
    return triton.testing.do_bench_cudagraph(
        executor_cls(), return_mode="mean", rep=20
    )
57
+
58
+
59
def _analyze_metrics(all_metrics: List[Dict[str, Any]], threshold: float = 0.9):
    """Log per-benchmark relative speeds and warn about straggler ranks.

    Args:
        all_metrics: One dict per rank, in the order produced by
            ``dist.gather_object``. Each dict maps every name in
            ``_BENCH_NAMES`` to that rank's measured time (lower is faster).
        threshold: A rank whose speed relative to the fastest rank falls below
            this fraction is reported as slow. Defaults to 0.9.
    """
    for bench_name in _BENCH_NAMES:
        time_of_rank = torch.tensor([m[bench_name] for m in all_metrics])
        speed_of_rank = 1 / time_of_rank
        # Normalize so the fastest rank is 1.0 and slower ranks are below 1.0.
        rel_speed_of_rank = speed_of_rank / speed_of_rank.max()
        slowest_rel_speed = rel_speed_of_rank.min().item()
        logger.info(
            f"[slow_rank_detector] {bench_name=} {slowest_rel_speed=} {rel_speed_of_rank=} {time_of_rank=}"
        )
        if slowest_rel_speed < threshold:
            # Name the offending ranks so the warning is actionable on its own.
            slow_ranks = [
                rank
                for rank, rel_speed in enumerate(rel_speed_of_rank.tolist())
                if rel_speed < threshold
            ]
            logger.warning(
                "[slow_rank_detector] Some ranks are too slow compared with others: "
                "bench_name=%s slow_ranks=%s threshold=%s",
                bench_name,
                slow_ranks,
                threshold,
            )
@@ -1,8 +1,6 @@
1
1
  import logging
2
- import threading
3
- import time
4
2
  from abc import ABC
5
- from contextlib import contextmanager, nullcontext
3
+ from contextlib import contextmanager
6
4
 
7
5
  try:
8
6
  import torch_memory_saver
@@ -40,7 +38,7 @@ class TorchMemorySaverAdapter(ABC):
40
38
  def configure_subprocess(self):
41
39
  raise NotImplementedError
42
40
 
43
- def region(self, tag: str):
41
+ def region(self, tag: str, enable_cpu_backup: bool = False):
44
42
  raise NotImplementedError
45
43
 
46
44
  def pause(self, tag: str):
@@ -60,8 +58,8 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
60
58
  def configure_subprocess(self):
61
59
  return torch_memory_saver.configure_subprocess()
62
60
 
63
- def region(self, tag: str):
64
- return _memory_saver.region(tag=tag)
61
+ def region(self, tag: str, enable_cpu_backup: bool = False):
62
+ return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)
65
63
 
66
64
  def pause(self, tag: str):
67
65
  return _memory_saver.pause(tag=tag)
@@ -80,7 +78,7 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
80
78
  yield
81
79
 
82
80
  @contextmanager
83
- def region(self, tag: str):
81
+ def region(self, tag: str, enable_cpu_backup: bool = False):
84
82
  yield
85
83
 
86
84
  def pause(self, tag: str):
sglang/srt/warmup.py CHANGED
@@ -1,20 +1,24 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
- from typing import List
4
+ from typing import TYPE_CHECKING, List
3
5
 
4
6
  import numpy as np
5
7
  import tqdm
6
8
 
7
9
  from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
8
10
  from sglang.srt.managers.io_struct import GenerateReqInput
9
- from sglang.srt.managers.tokenizer_manager import TokenizerManager
11
+
12
+ if TYPE_CHECKING:
13
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
10
14
 
11
15
  logger = logging.getLogger(__file__)
12
16
 
13
17
  _warmup_registry = {}
14
18
 
15
19
 
16
- def warmup(name: str) -> callable:
17
- def decorator(fn: callable):
20
+ def warmup(name: str):
21
+ def decorator(fn):
18
22
  _warmup_registry[name] = fn
19
23
  return fn
20
24
 
@@ -33,7 +33,7 @@ async def update_weights(
33
33
  """
34
34
  infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
35
35
  infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
36
- from sglang.srt.patch_torch import monkey_patch_torch_reductions
36
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
37
37
 
38
38
  monkey_patch_torch_reductions()
39
39
 
@@ -66,7 +66,7 @@ class MockModelRunner:
66
66
  enable_memory_saver=False,
67
67
  )
68
68
  # Required by torch native backend
69
- self.server_args = ServerArgs(model_path="fake_model_path")
69
+ self.server_args = ServerArgs(model_path="dummy")
70
70
 
71
71
 
72
72
  @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
@@ -4,7 +4,6 @@ import torch
4
4
 
5
5
  from sglang.srt.configs.model_config import AttentionArch
6
6
  from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
7
- from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
8
7
  from sglang.srt.layers.radix_attention import RadixAttention
9
8
  from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
10
9
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -2,8 +2,6 @@ import unittest
2
2
 
3
3
  import torch
4
4
 
5
- from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
6
- from sglang.srt.layers.radix_attention import RadixAttention
7
5
  from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
8
6
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
9
7
  from sglang.test.test_utils import CustomTestCase