sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,341 @@
1
+ from typing import List, Optional
2
+
3
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
4
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
5
+ from sglang.srt.managers.utils import GenerationBatchResult
6
+ from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
7
+ from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj
8
+
9
+
10
+ class SchedulerPPMixin:
11
+
12
@DynamicGradMode()
def event_loop_pp(self):
    """A non-overlap scheduler loop for pipeline parallelism.

    The loop never returns. Each pass of the outer ``while`` visits
    ``self.pp_size`` microbatch slots in order. For every slot it:

    1. Restores the slot's running/last batch, receives and processes new
       requests, selects the next batch and runs it when one exists.
    2. On the last PP rank, packs the sampled token ids (plus the
       logprob-related fields when ``return_logprob`` is set) into a
       ``PPProxyTensors`` and sends it through ``self.pp_group``.
    3. When the *next* slot holds a batch, receives that batch's outputs
       from ``self.pp_group``, rebuilds a ``GenerationBatchResult`` and
       post-processes it.
    4. On every other rank, sends the carried outputs, the received
       requests and the hidden-state proxy tensors onward.

    If no slot produced a batch during a full pass,
    ``self.self_check_during_idle()`` is called.
    """
    # Per-slot state: the batch scheduled in the current pass and the batch
    # that was most recently post-processed for that slot.
    mbs = [None] * self.pp_size
    last_mbs = [None] * self.pp_size
    # One independent running batch per microbatch slot.
    self.running_mbs = [
        ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
    ]
    # Outputs carried from the previous slot iteration; forwarded by
    # non-last ranks further below.
    pp_outputs: Optional[PPProxyTensors] = None
    while True:
        server_is_idle = True
        for mb_id in range(self.pp_size):
            # Swap this slot's state into the scheduler attributes that the
            # helper methods below operate on.
            self.running_batch = self.running_mbs[mb_id]
            self.last_batch = last_mbs[mb_id]

            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            mbs[mb_id] = self.get_next_batch_to_run()
            # Save the (possibly replaced) running batch back into its slot.
            self.running_mbs[mb_id] = self.running_batch

            self.cur_batch = mbs[mb_id]
            if self.cur_batch:
                server_is_idle = False
                result = self.run_batch(self.cur_batch)

            # (last rank) send the outputs to the next step
            if self.pp_group.is_last_rank:
                if self.cur_batch:
                    next_token_ids = result.next_token_ids
                    if self.cur_batch.return_logprob:
                        # Flatten every attribute of logits_output into
                        # "logits_output.<name>" keys so that the receiver
                        # can rebuild the object from the tensor dict.
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                                "extend_input_len_per_req": result.extend_input_len_per_req,
                                "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                            }
                            | (
                                {
                                    f"logits_output.{k}": v
                                    for k, v in result.logits_output.__dict__.items()
                                }
                                if result.logits_output is not None
                                else {}
                            )
                        )
                    else:
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                    # send the output from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            # receive outputs and post-process (filter finished reqs) the coming microbatch
            next_mb_id = (mb_id + 1) % self.pp_size
            next_pp_outputs = None
            if mbs[next_mb_id] is not None:
                next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                    self.pp_group.recv_tensor_dict(
                        all_gather_group=self.attn_tp_group
                    )
                )
                mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                # Strip the "logits_output." prefix added by the sender to
                # recover the keyword arguments of LogitsProcessorOutput.
                logits_output_args = {
                    k[len("logits_output.") :]: v
                    for k, v in next_pp_outputs.tensors.items()
                    if k.startswith("logits_output.")
                }
                if len(logits_output_args) > 0:
                    logits_output = LogitsProcessorOutput(**logits_output_args)
                else:
                    logits_output = None

                # NOTE(review): `result` is only assigned above when this
                # slot had a batch to run, so the value read here can come
                # from an earlier slot iteration - confirm this is intended.
                output_result = GenerationBatchResult.from_pp_proxy(
                    logits_output=logits_output,
                    next_pp_outputs=next_pp_outputs,
                    can_run_cuda_graph=result.can_run_cuda_graph,
                )
                self.process_batch_result(mbs[next_mb_id], output_result)
                last_mbs[next_mb_id] = mbs[next_mb_id]

            # (not last rank)
            if not self.pp_group.is_last_rank:
                # carry the outputs to the next stage
                # send the outputs from the last round to let the next stage worker run post processing
                if pp_outputs:
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

                # send out reqs to the next stage
                # Only attention-TP rank 0 sends the request objects; the
                # rank arguments are offset by this worker's attention-DP
                # position. Presumably the last two arguments are the source
                # and destination ranks - verify against point_to_point_pyobj.
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                if self.attn_tp_rank == 0:
                    point_to_point_pyobj(
                        recv_reqs,
                        self.pp_rank * self.tp_size + dp_offset,
                        self.world_group.device_group,
                        self.pp_rank * self.tp_size + dp_offset,
                        (self.pp_rank + 1) * self.tp_size + dp_offset,
                    )

                # send out proxy tensors to the next stage
                if self.cur_batch:
                    # FIXME(lsyin): remove this assert
                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                    self.pp_group.send_tensor_dict(
                        result.pp_hidden_states_proxy_tensors.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            # What was received for the next slot becomes the carried output
            # of the following slot iteration.
            pp_outputs = next_pp_outputs

        # When the server is idle (no slot ran a batch in this pass),
        # do self-check and re-init some states
        if server_is_idle:
            self.self_check_during_idle()
133
+
134
+ @DynamicGradMode()
135
+ def event_loop_pp_disagg_prefill(self):
136
+ """
137
+ An event loop for the prefill server in pipeline parallelism.
138
+
139
+ Rules:
140
+ 1. Each stage runs in the same order and is notified by the previous stage.
141
+ 2. Each send/recv operation is blocking and matched by the neighboring stage.
142
+
143
+ Regular Schedule:
144
+ ====================================================================
145
+ Stage i | Stage i+1
146
+ send ith req | recv ith req
147
+ send ith proxy | recv ith proxy
148
+ send prev (i+1)th carry | recv prev (i+1)th carry
149
+ ====================================================================
150
+
151
+ Prefill Server Schedule:
152
+ ====================================================================
153
+ Stage i | Stage i+1
154
+ send ith req | recv ith req
155
+ send ith bootstrap req | recv ith bootstrap req
156
+ send ith transferred req | recv ith transferred req
157
+ send ith proxy | recv ith proxy
158
+ send prev (i+1)th carry | recv prev (i+1)th carry
159
+ send prev (i+1)th release req | recv prev (i+1)th release req
160
+ ====================================================================
161
+
162
+ There are two additional elements compared to the regular schedule:
163
+
164
+ 1. Bootstrap Requests:
165
+ a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
166
+ b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
167
+ c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
168
+
169
+ 2. Transferred Requests + Release Requests:
170
+ a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
171
+ b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
172
+ c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
173
+ """
174
+ mbs = [None] * self.pp_size
175
+ last_mbs = [None] * self.pp_size
176
+ self.running_mbs = [
177
+ ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
178
+ ]
179
+ pp_outputs: Optional[PPProxyTensors] = None
180
+
181
+ # Either success or failed
182
+ bootstrapped_rids: List[str] = []
183
+ transferred_rids: List[str] = []
184
+ release_rids: Optional[List[str]] = None
185
+
186
+ # transferred microbatch
187
+ tmbs = [None] * self.pp_size
188
+
189
+ ENABLE_RELEASE = True # For debug
190
+
191
+ while True:
192
+ server_is_idle = True
193
+
194
+ for mb_id in range(self.pp_size):
195
+ self.running_batch = self.running_mbs[mb_id]
196
+ self.last_batch = last_mbs[mb_id]
197
+
198
+ recv_reqs = self.recv_requests()
199
+
200
+ self.process_input_requests(recv_reqs)
201
+
202
+ if self.pp_group.is_first_rank:
203
+ # First rank, pop the bootstrap reqs from the bootstrap queue
204
+ bootstrapped_reqs, failed_reqs = (
205
+ self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
206
+ return_failed_reqs=True
207
+ )
208
+ )
209
+ bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
210
+ req.rid for req in failed_reqs
211
+ ]
212
+ self.waiting_queue.extend(bootstrapped_reqs)
213
+ else:
214
+ # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
215
+ bootstrapped_rids = self.recv_pyobj_from_prev_stage()
216
+ bootstrapped_reqs = (
217
+ self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
218
+ rids_to_check=bootstrapped_rids
219
+ )
220
+ )
221
+ self.waiting_queue.extend(bootstrapped_reqs)
222
+
223
+ if self.pp_group.is_first_rank:
224
+ transferred_rids = self.get_transferred_rids()
225
+ # Other ranks:
226
+ else:
227
+ # 1. recv previous stage's transferred reqs info
228
+ prev_transferred_rids = self.recv_pyobj_from_prev_stage()
229
+ # 2. get the current stage's transferred reqs info
230
+ curr_transferred_rids = self.get_transferred_rids()
231
+ # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
232
+ transferred_rids = list(
233
+ set(prev_transferred_rids) & set(curr_transferred_rids)
234
+ )
235
+
236
+ tmbs[mb_id] = transferred_rids
237
+
238
+ self.process_prefill_chunk()
239
+ mbs[mb_id] = self.get_new_batch_prefill()
240
+ self.running_mbs[mb_id] = self.running_batch
241
+
242
+ self.cur_batch = mbs[mb_id]
243
+ if self.cur_batch:
244
+ server_is_idle = False
245
+ result = self.run_batch(self.cur_batch)
246
+
247
+ # send the outputs to the next step
248
+ if self.pp_group.is_last_rank:
249
+ if self.cur_batch:
250
+ next_token_ids = result.next_token_ids
251
+ pp_outputs = PPProxyTensors(
252
+ {
253
+ "next_token_ids": next_token_ids,
254
+ }
255
+ )
256
+ # send the output from the last round to let the next stage worker run post processing
257
+ self.pp_group.send_tensor_dict(
258
+ pp_outputs.tensors,
259
+ all_gather_group=self.attn_tp_group,
260
+ )
261
+
262
+ if ENABLE_RELEASE:
263
+ if self.pp_group.is_last_rank:
264
+ # At the last stage, all stages have reached the consensus to release memory for transferred_rids
265
+ release_rids = transferred_rids
266
+ # send to the first rank
267
+ self.send_pyobj_to_next_stage(release_rids)
268
+
269
+ # receive outputs and post-process (filter finished reqs) the coming microbatch
270
+ next_mb_id = (mb_id + 1) % self.pp_size
271
+ next_pp_outputs = None
272
+ next_release_rids = None
273
+
274
+ if mbs[next_mb_id] is not None:
275
+ next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
276
+ self.pp_group.recv_tensor_dict(
277
+ all_gather_group=self.attn_tp_group
278
+ )
279
+ )
280
+ mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
281
+ output_result = GenerationBatchResult(
282
+ logits_output=None,
283
+ pp_hidden_states_proxy_tensors=None,
284
+ next_token_ids=next_pp_outputs["next_token_ids"],
285
+ extend_input_len_per_req=None,
286
+ extend_logprob_start_len_per_req=None,
287
+ can_run_cuda_graph=result.can_run_cuda_graph,
288
+ )
289
+ self.process_batch_result_disagg_prefill(
290
+ mbs[next_mb_id], output_result
291
+ )
292
+
293
+ last_mbs[next_mb_id] = mbs[next_mb_id]
294
+
295
+ if ENABLE_RELEASE:
296
+ if tmbs[next_mb_id] is not None:
297
+ # recv consensus rids from the previous rank
298
+ next_release_rids = self.recv_pyobj_from_prev_stage()
299
+ self.process_disagg_prefill_inflight_queue(next_release_rids)
300
+
301
+ # carry the outputs to the next stage
302
+ if not self.pp_group.is_last_rank:
303
+ if pp_outputs:
304
+ # send the outputs from the last round to let the next stage worker run post processing
305
+ self.pp_group.send_tensor_dict(
306
+ pp_outputs.tensors,
307
+ all_gather_group=self.attn_tp_group,
308
+ )
309
+ if ENABLE_RELEASE:
310
+ if release_rids is not None:
311
+ self.send_pyobj_to_next_stage(release_rids)
312
+
313
+ if not self.pp_group.is_last_rank:
314
+ # send out reqs to the next stage
315
+ self.send_pyobj_to_next_stage(recv_reqs)
316
+ self.send_pyobj_to_next_stage(bootstrapped_rids)
317
+ self.send_pyobj_to_next_stage(transferred_rids)
318
+
319
+ # send out proxy tensors to the next stage
320
+ if self.cur_batch:
321
+ # FIXME(lsyin): remove this assert
322
+ assert result.pp_hidden_states_proxy_tensors.tensors is not None
323
+ self.pp_group.send_tensor_dict(
324
+ result.pp_hidden_states_proxy_tensors.tensors,
325
+ all_gather_group=self.attn_tp_group,
326
+ )
327
+
328
+ pp_outputs = next_pp_outputs
329
+ release_rids = next_release_rids
330
+
331
+ self.running_batch.batch_is_full = False
332
+
333
+ if not ENABLE_RELEASE:
334
+ if len(self.disagg_prefill_inflight_queue) > 0:
335
+ self.process_disagg_prefill_inflight_queue()
336
+
337
+ # When the server is idle, self-check and re-init some states
338
+ if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
339
+ self.check_memory()
340
+ self.check_tree_cache()
341
+ self.new_token_ratio = self.init_new_token_ratio
@@ -9,6 +9,7 @@ import torch
9
9
  from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
10
10
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
11
11
  from sglang.srt.utils import is_npu
12
+ from sglang.srt.utils.profile_merger import ProfileMerger
12
13
 
13
14
  _is_npu = is_npu()
14
15
  if _is_npu:
@@ -25,10 +26,9 @@ logger = logging.getLogger(__name__)
25
26
 
26
27
 
27
28
  class SchedulerProfilerMixin:
28
-
29
29
  def init_profiler(self):
30
30
  self.torch_profiler = None
31
- self.torch_profiler_output_dir: Optional[str] = None
31
+ self.torch_profiler_output_dir: Optional[Path] = None
32
32
  self.profiler_activities: Optional[List[str]] = None
33
33
  self.profile_id: Optional[str] = None
34
34
  self.profiler_start_forward_ct: Optional[int] = None
@@ -41,6 +41,7 @@ class SchedulerProfilerMixin:
41
41
  self.profile_steps: Optional[int] = None
42
42
  self.profile_in_progress: bool = False
43
43
  self.rpd_profiler = None
44
+ self.merge_profiles = False
44
45
 
45
46
  def init_profile(
46
47
  self,
@@ -52,6 +53,7 @@ class SchedulerProfilerMixin:
52
53
  record_shapes: Optional[bool],
53
54
  profile_by_stage: bool,
54
55
  profile_id: str,
56
+ merge_profiles: bool = False,
55
57
  ) -> ProfileReqOutput:
56
58
  if self.profile_in_progress:
57
59
  return ProfileReqOutput(
@@ -60,13 +62,14 @@ class SchedulerProfilerMixin:
60
62
  )
61
63
 
62
64
  self.profile_by_stage = profile_by_stage
65
+ self.merge_profiles = merge_profiles
63
66
 
64
67
  if output_dir is None:
65
68
  output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
66
69
  if activities is None:
67
70
  activities = ["CPU", "GPU"]
68
71
 
69
- self.torch_profiler_output_dir = output_dir
72
+ self.torch_profiler_output_dir = Path(output_dir).expanduser()
70
73
  self.torch_profiler_with_stack = with_stack
71
74
  self.torch_profiler_record_shapes = record_shapes
72
75
  self.profiler_activities = activities
@@ -169,6 +172,38 @@ class SchedulerProfilerMixin:
169
172
 
170
173
  return ProfileReqOutput(success=True, message="Succeeded")
171
174
 
175
+ def _merge_profile_traces(self) -> str:
176
+ if not self.merge_profiles:
177
+ return ""
178
+
179
+ if self.tp_rank != 0:
180
+ return ""
181
+ if getattr(self, "dp_size", 1) > 1 and getattr(self, "dp_rank", 0) != 0:
182
+ return ""
183
+ if getattr(self, "pp_size", 1) > 1 and getattr(self, "pp_rank", 0) != 0:
184
+ return ""
185
+ if getattr(self, "moe_ep_size", 1) > 1 and getattr(self, "moe_ep_rank", 0) != 0:
186
+ return ""
187
+
188
+ try:
189
+ logger.info("Starting profile merge...")
190
+ merger = ProfileMerger(self.torch_profiler_output_dir, self.profile_id)
191
+ merged_path = merger.merge_chrome_traces()
192
+
193
+ summary = merger.get_merge_summary()
194
+ merge_message = (
195
+ f" Merged trace: {merged_path} "
196
+ f"(Events: {summary.get('total_events', '?')}, "
197
+ f"Files: {summary.get('total_files', '?')})"
198
+ )
199
+
200
+ logger.info(f"Profile merge completed: {merged_path}")
201
+ except Exception as e:
202
+ logger.error(f"Failed to merge profiles: {e}", exc_info=True)
203
+ return f" Merge failed: {e!s}"
204
+ else:
205
+ return merge_message
206
+
172
207
  def stop_profile(
173
208
  self, stage: Optional[ForwardMode] = None
174
209
  ) -> ProfileReqOutput | None:
@@ -178,22 +213,28 @@ class SchedulerProfilerMixin:
178
213
  message="Profiling is not in progress. Call /start_profile first.",
179
214
  )
180
215
 
181
- if not Path(self.torch_profiler_output_dir).exists():
182
- Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
216
+ self.torch_profiler_output_dir.mkdir(parents=True, exist_ok=True)
183
217
 
184
218
  stage_suffix = f"-{stage.name}" if stage else ""
185
219
  logger.info("Stop profiling" + stage_suffix + "...")
186
220
  if self.torch_profiler is not None:
187
221
  self.torch_profiler.stop()
188
222
  if not _is_npu:
223
+ # Build the filename: the TP rank is always included; other ranks are added only when that parallelism is enabled, to maintain backward compatibility
224
+ filename_parts = [self.profile_id, f"TP-{self.tp_rank}"]
225
+
226
+ # Only add other ranks if parallelism is enabled (size > 1)
227
+ if getattr(self, "dp_size", 1) > 1:
228
+ filename_parts.append(f"DP-{getattr(self, 'dp_rank', 0)}")
229
+ if getattr(self, "pp_size", 1) > 1:
230
+ filename_parts.append(f"PP-{getattr(self, 'pp_rank', 0)}")
231
+ if getattr(self, "moe_ep_size", 1) > 1:
232
+ filename_parts.append(f"EP-{getattr(self, 'moe_ep_rank', 0)}")
233
+
234
+ filename = "-".join(filename_parts) + stage_suffix + ".trace.json.gz"
235
+
189
236
  self.torch_profiler.export_chrome_trace(
190
- os.path.join(
191
- self.torch_profiler_output_dir,
192
- self.profile_id
193
- + f"-TP-{self.tp_rank}"
194
- + stage_suffix
195
- + ".trace.json.gz",
196
- )
237
+ os.path.join(self.torch_profiler_output_dir, filename)
197
238
  )
198
239
  torch.distributed.barrier(self.tp_cpu_group)
199
240
 
@@ -224,15 +265,18 @@ class SchedulerProfilerMixin:
224
265
  if "CUDA_PROFILER" in self.profiler_activities:
225
266
  torch.cuda.cudart().cudaProfilerStop()
226
267
 
268
+ merge_message = self._merge_profile_traces()
269
+
227
270
  logger.info(
228
- "Profiling done. Traces are saved to: %s",
271
+ "Profiling done. Traces are saved to: %s%s",
229
272
  self.torch_profiler_output_dir,
273
+ merge_message,
230
274
  )
231
275
  self.torch_profiler = None
232
276
  self.profile_in_progress = False
233
277
  self.profiler_start_forward_ct = None
234
278
 
235
- return ProfileReqOutput(success=True, message="Succeeded.")
279
+ return ProfileReqOutput(success=True, message=f"Succeeded.{merge_message}")
236
280
 
237
281
  def _profile_batch_predicate(self, batch):
238
282
  if self.profile_by_stage:
@@ -282,6 +326,7 @@ class SchedulerProfilerMixin:
282
326
  recv_req.record_shapes,
283
327
  recv_req.profile_by_stage,
284
328
  recv_req.profile_id,
329
+ recv_req.merge_profiles,
285
330
  )
286
331
  else:
287
332
  self.init_profile(
@@ -293,6 +338,7 @@ class SchedulerProfilerMixin:
293
338
  recv_req.record_shapes,
294
339
  recv_req.profile_by_stage,
295
340
  recv_req.profile_id,
341
+ recv_req.merge_profiles,
296
342
  )
297
343
  return self.start_profile()
298
344
  else: