sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -38,10 +38,14 @@ from sglang.srt.configs import (
38
38
  ChatGLMConfig,
39
39
  DbrxConfig,
40
40
  DeepseekVL2Config,
41
+ DotsOCRConfig,
42
+ DotsVLMConfig,
41
43
  ExaoneConfig,
44
+ FalconH1Config,
42
45
  KimiVLConfig,
43
46
  LongcatFlashConfig,
44
47
  MultiModalityConfig,
48
+ Qwen3NextConfig,
45
49
  Step3VLConfig,
46
50
  )
47
51
  from sglang.srt.configs.internvl import InternVLChatConfig
@@ -58,6 +62,10 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
58
62
  InternVLChatConfig.model_type: InternVLChatConfig,
59
63
  Step3VLConfig.model_type: Step3VLConfig,
60
64
  LongcatFlashConfig.model_type: LongcatFlashConfig,
65
+ Qwen3NextConfig.model_type: Qwen3NextConfig,
66
+ FalconH1Config.model_type: FalconH1Config,
67
+ DotsVLMConfig.model_type: DotsVLMConfig,
68
+ DotsOCRConfig.model_type: DotsOCRConfig,
61
69
  }
62
70
 
63
71
  for name, cls in _CONFIG_REGISTRY.items():
@@ -115,6 +123,38 @@ def get_hf_text_config(config: PretrainedConfig):
115
123
  return config
116
124
 
117
125
 
126
+ # Temporary hack for DeepSeek-V3.2 model
127
+ def _load_deepseek_v32_model(
128
+ model_path: str,
129
+ trust_remote_code: bool = False,
130
+ revision: Optional[str] = None,
131
+ **kwargs,
132
+ ):
133
+ # first get the local path
134
+ local_path = download_from_hf(model_path)
135
+ # then load the config file in json
136
+ config_file = os.path.join(local_path, "config.json")
137
+ if not os.path.exists(config_file):
138
+ raise RuntimeError(f"Can't find config file in {local_path}.")
139
+
140
+ with open(config_file, "r") as f:
141
+ config_json = json.load(f)
142
+
143
+ config_json["architectures"] = ["DeepseekV3ForCausalLM"]
144
+ config_json["model_type"] = "deepseek_v3"
145
+
146
+ tmp_path = os.path.join(local_path, "_tmp_config_folder")
147
+ os.makedirs(tmp_path, exist_ok=True)
148
+
149
+ unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
150
+ with open(unique_path, "w") as f:
151
+ json.dump(config_json, f)
152
+
153
+ return AutoConfig.from_pretrained(
154
+ unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
155
+ )
156
+
157
+
118
158
  @lru_cache_frozenset(maxsize=32)
119
159
  def get_config(
120
160
  model: str,
@@ -136,9 +176,17 @@ def get_config(
136
176
  client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
137
177
  model = client.get_local_dir()
138
178
 
139
- config = AutoConfig.from_pretrained(
140
- model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
141
- )
179
+ try:
180
+ config = AutoConfig.from_pretrained(
181
+ model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
182
+ )
183
+ except ValueError as e:
184
+ if not "deepseek_v32" in str(e):
185
+ raise e
186
+ config = _load_deepseek_v32_model(
187
+ model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
188
+ )
189
+
142
190
  if (
143
191
  config.architectures is not None
144
192
  and config.architectures[0] == "Phi4MMForCausalLM"
@@ -370,8 +418,8 @@ def get_processor(
370
418
  **kwargs,
371
419
  )
372
420
 
373
- # fix: for Qwen2-VL model, inject default 'size' if not provided.
374
- if config.model_type in {"qwen2_vl"}:
421
+ # fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
422
+ if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
375
423
  if "size" not in kwargs:
376
424
  kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
377
425
 
@@ -17,10 +17,18 @@ import torch
17
17
  from packaging import version
18
18
  from torch.multiprocessing import reductions
19
19
 
20
+ from sglang.srt.utils import is_npu
21
+
22
+ _is_npu = is_npu()
23
+
20
24
 
21
25
  def monkey_patch_torch_reductions():
22
26
  """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed"""
23
27
 
28
+ # NPU does not currently support UUID, so this patch is skipped on NPU for now; support is expected in the fourth quarter.
29
+ if _is_npu:
30
+ return
31
+
24
32
  if hasattr(reductions, "_reduce_tensor_original"):
25
33
  return
26
34
 
@@ -0,0 +1,452 @@
1
+ # https://raw.githubusercontent.com/ROCm/rocmProfileData/refs/heads/master/tools/rpd2tracing.py
2
+ # commit 92d13a08328625463e9ba944cece82fc5eea36e6
3
+ def rpd_to_chrome_trace(
4
+ input_rpd, output_json=None, start="0%", end="100%", format="object"
5
+ ):
6
+ import gzip
7
+ import sqlite3
8
+
9
+ if output_json is None:
10
+ import pathlib
11
+
12
+ output_json = pathlib.PurePath(input_rpd).with_suffix(".trace.json.gz")
13
+
14
+ connection = sqlite3.connect(input_rpd)
15
+
16
+ outfile = gzip.open(output_json, "wt", encoding="utf-8")
17
+
18
+ if format == "object":
19
+ outfile.write('{"traceEvents": ')
20
+
21
+ outfile.write("[ {}\n")
22
+
23
+ for row in connection.execute("select distinct gpuId from rocpd_op"):
24
+ try:
25
+ outfile.write(
26
+ ',{"name": "process_name", "ph": "M", "pid":"%s","args":{"name":"%s"}}\n'
27
+ % (row[0], "GPU" + str(row[0]))
28
+ )
29
+ outfile.write(
30
+ ',{"name": "process_sort_index", "ph": "M", "pid":"%s","args":{"sort_index":"%s"}}\n'
31
+ % (row[0], row[0] + 1000000)
32
+ )
33
+ except ValueError:
34
+ outfile.write("")
35
+
36
+ for row in connection.execute("select distinct pid, tid from rocpd_api"):
37
+ try:
38
+ outfile.write(
39
+ ',{"name":"thread_name","ph":"M","pid":"%s","tid":"%s","args":{"name":"%s"}}\n'
40
+ % (row[0], row[1], "Hip " + str(row[1]))
41
+ )
42
+ outfile.write(
43
+ ',{"name":"thread_sort_index","ph":"M","pid":"%s","tid":"%s","args":{"sort_index":"%s"}}\n'
44
+ % (row[0], row[1], row[1] * 2)
45
+ )
46
+ except ValueError:
47
+ outfile.write("")
48
+
49
+ try:
50
+ # FIXME - these aren't rendering correctly in chrome://tracing
51
+ for row in connection.execute("select distinct pid, tid from rocpd_hsaApi"):
52
+ try:
53
+ outfile.write(
54
+ ',{"name":"thread_name","ph":"M","pid":"%s","tid":"%s","args":{"name":"%s"}}\n'
55
+ % (row[0], row[1], "HSA " + str(row[1]))
56
+ )
57
+ outfile.write(
58
+ ',{"name":"thread_sort_index","ph":"M","pid":"%s","tid":"%s","args":{"sort_index":"%s"}}\n'
59
+ % (row[0], row[1], row[1] * 2 - 1)
60
+ )
61
+ except ValueError:
62
+ outfile.write("")
63
+ except:
64
+ pass
65
+
66
+ rangeStringApi = ""
67
+ rangeStringOp = ""
68
+ rangeStringMonitor = ""
69
+ min_time = connection.execute("select MIN(start) from rocpd_api;").fetchall()[0][0]
70
+ max_time = connection.execute("select MAX(end) from rocpd_api;").fetchall()[0][0]
71
+ if min_time == None:
72
+ raise Exception("Trace file is empty.")
73
+
74
+ print("Timestamps:")
75
+ print(f"\t first: \t{min_time/1000} us")
76
+ print(f"\t last: \t{max_time/1000} us")
77
+ print(f"\t duration: \t{(max_time-min_time) / 1000000000} seconds")
78
+
79
+ start_time = min_time / 1000
80
+ end_time = max_time / 1000
81
+
82
+ if start:
83
+ if "%" in start:
84
+ start_time = (
85
+ (max_time - min_time) * (int(start.replace("%", "")) / 100) + min_time
86
+ ) / 1000
87
+ else:
88
+ start_time = int(start)
89
+ rangeStringApi = "where rocpd_api.start/1000 >= %s" % (start_time)
90
+ rangeStringOp = "where rocpd_op.start/1000 >= %s" % (start_time)
91
+ rangeStringMonitor = "where start/1000 >= %s" % (start_time)
92
+ if end:
93
+ if "%" in end:
94
+ end_time = (
95
+ (max_time - min_time) * (int(end.replace("%", "")) / 100) + min_time
96
+ ) / 1000
97
+ else:
98
+ end_time = int(end)
99
+
100
+ rangeStringApi = (
101
+ rangeStringApi + " and rocpd_api.start/1000 <= %s" % (end_time)
102
+ if start != None
103
+ else "where rocpd_api.start/1000 <= %s" % (end_time)
104
+ )
105
+ rangeStringOp = (
106
+ rangeStringOp + " and rocpd_op.start/1000 <= %s" % (end_time)
107
+ if start != None
108
+ else "where rocpd_op.start/1000 <= %s" % (end_time)
109
+ )
110
+ rangeStringMonitor = (
111
+ rangeStringMonitor + " and start/1000 <= %s" % (end_time)
112
+ if start != None
113
+ else "where start/1000 <= %s" % (end_time)
114
+ )
115
+
116
+ print("\nFilter: %s" % (rangeStringApi))
117
+ print(f"Output duration: {(end_time-start_time)/1000000} seconds")
118
+
119
+ # Output Ops
120
+
121
+ for row in connection.execute(
122
+ "select A.string as optype, B.string as description, gpuId, queueId, rocpd_op.start/1000.0, (rocpd_op.end-rocpd_op.start) / 1000.0 from rocpd_op INNER JOIN rocpd_string A on A.id = rocpd_op.opType_id INNER Join rocpd_string B on B.id = rocpd_op.description_id %s"
123
+ % (rangeStringOp)
124
+ ):
125
+ try:
126
+ name = row[0] if len(row[1]) == 0 else row[1]
127
+ outfile.write(
128
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
129
+ % (row[2], row[3], name, row[4], row[5], row[0])
130
+ )
131
+ except ValueError:
132
+ outfile.write("")
133
+
134
+ # Output Graph executions on GPU
135
+ try:
136
+ for row in connection.execute(
137
+ "select graphExec, gpuId, queueId, min(start)/1000.0, (max(end)-min(start))/1000.0, count(*) from rocpd_graphLaunchapi A join rocpd_api_ops B on B.api_id = A.api_ptr_id join rocpd_op C on C.id = B.op_id %s group by api_ptr_id"
138
+ % (rangeStringMonitor)
139
+ ):
140
+ try:
141
+ outfile.write(
142
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"kernels":"%s"}}\n'
143
+ % (row[1], row[2], f"Graph {row[0]}", row[3], row[4], row[5])
144
+ )
145
+ except ValueError:
146
+ outfile.write("")
147
+ except:
148
+ pass
149
+
150
+ # Output apis
151
+ for row in connection.execute(
152
+ "select A.string as apiName, B.string as args, pid, tid, rocpd_api.start/1000.0, (rocpd_api.end-rocpd_api.start) / 1000.0, (rocpd_api.end != rocpd_api.start) as has_duration from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id INNER Join rocpd_string B on B.id = rocpd_api.args_id %s order by rocpd_api.id"
153
+ % (rangeStringApi)
154
+ ):
155
+ try:
156
+ if row[0] == "UserMarker":
157
+ if row[6] == 0: # instantaneous "mark" messages
158
+ outfile.write(
159
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","ph":"i","s":"p","args":{"desc":"%s"}}\n'
160
+ % (
161
+ row[2],
162
+ row[3],
163
+ row[1].replace('"', ""),
164
+ row[4],
165
+ row[1].replace('"', ""),
166
+ )
167
+ )
168
+ else:
169
+ outfile.write(
170
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
171
+ % (
172
+ row[2],
173
+ row[3],
174
+ row[1].replace('"', ""),
175
+ row[4],
176
+ row[5],
177
+ row[1].replace('"', ""),
178
+ )
179
+ )
180
+ else:
181
+ outfile.write(
182
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
183
+ % (
184
+ row[2],
185
+ row[3],
186
+ row[0],
187
+ row[4],
188
+ row[5],
189
+ row[1].replace('"', "").replace("\t", ""),
190
+ )
191
+ )
192
+ except ValueError:
193
+ outfile.write("")
194
+
195
+ # Output api->op linkage
196
+ for row in connection.execute(
197
+ "select rocpd_api_ops.id, pid, tid, gpuId, queueId, rocpd_api.end/1000.0 - 2, rocpd_op.start/1000.0 from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id %s"
198
+ % (rangeStringApi)
199
+ ):
200
+ try:
201
+ fromtime = row[5] if row[5] < row[6] else row[6]
202
+ outfile.write(
203
+ ',{"pid":"%s","tid":"%s","cat":"api_op","name":"api_op","ts":"%s","id":"%s","ph":"s"}\n'
204
+ % (row[1], row[2], fromtime, row[0])
205
+ )
206
+ outfile.write(
207
+ ',{"pid":"%s","tid":"%s","cat":"api_op","name":"api_op","ts":"%s","id":"%s","ph":"f", "bp":"e"}\n'
208
+ % (row[3], row[4], row[6], row[0])
209
+ )
210
+ except ValueError:
211
+ outfile.write("")
212
+
213
+ try:
214
+ for row in connection.execute(
215
+ "select A.string as apiName, B.string as args, pid, tid, rocpd_hsaApi.start/1000.0, (rocpd_hsaApi.end-rocpd_hsaApi.start) / 1000.0 from rocpd_hsaApi INNER JOIN rocpd_string A on A.id = rocpd_hsaApi.apiName_id INNER Join rocpd_string B on B.id = rocpd_hsaApi.args_id %s order by rocpd_hsaApi.id"
216
+ % (rangeStringApi)
217
+ ):
218
+ try:
219
+ outfile.write(
220
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
221
+ % (
222
+ row[2],
223
+ row[3] + 1,
224
+ row[0],
225
+ row[4],
226
+ row[5],
227
+ row[1].replace('"', ""),
228
+ )
229
+ )
230
+ except ValueError:
231
+ outfile.write("")
232
+ except:
233
+ pass
234
+
235
+ #
236
+ # Counters
237
+ #
238
+
239
+ # Counters should extend to the last event in the trace. This means they need to have a value at Tend.
240
+ # Figure out when that is
241
+
242
+ T_end = 0
243
+ for row in connection.execute(
244
+ "SELECT max(end)/1000 from (SELECT end from rocpd_api UNION ALL SELECT end from rocpd_op)"
245
+ ):
246
+ T_end = int(row[0])
247
+ if end:
248
+ T_end = end_time
249
+
250
+ # Loop over GPU for per-gpu counters
251
+ gpuIdsPresent = []
252
+ for row in connection.execute("SELECT DISTINCT gpuId FROM rocpd_op"):
253
+ gpuIdsPresent.append(row[0])
254
+
255
+ for gpuId in gpuIdsPresent:
256
+ # print(f"Creating counters for: {gpuId}")
257
+
258
+ # Create the queue depth counter
259
+ depth = 0
260
+ idle = 1
261
+ for row in connection.execute(
262
+ 'select * from (select rocpd_api.start/1000.0 as ts, "1" from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id AND rocpd_op.gpuId = %s %s UNION ALL select rocpd_op.end/1000.0, "-1" from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id AND rocpd_op.gpuId = %s %s) order by ts'
263
+ % (gpuId, rangeStringOp, gpuId, rangeStringOp)
264
+ ):
265
+ try:
266
+ if idle and int(row[1]) > 0:
267
+ idle = 0
268
+ outfile.write(
269
+ ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
270
+ % (gpuId, row[0], idle)
271
+ )
272
+ if depth == 1 and int(row[1]) < 0:
273
+ idle = 1
274
+ outfile.write(
275
+ ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
276
+ % (gpuId, row[0], idle)
277
+ )
278
+ depth = depth + int(row[1])
279
+ outfile.write(
280
+ ',{"pid":"%s","name":"QueueDepth","ph":"C","ts":%s,"args":{"depth":%s}}\n'
281
+ % (gpuId, row[0], depth)
282
+ )
283
+ except ValueError:
284
+ outfile.write("")
285
+ if T_end > 0:
286
+ outfile.write(
287
+ ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
288
+ % (gpuId, T_end, idle)
289
+ )
290
+ outfile.write(
291
+ ',{"pid":"%s","name":"QueueDepth","ph":"C","ts":%s,"args":{"depth":%s}}\n'
292
+ % (gpuId, T_end, depth)
293
+ )
294
+
295
+ # Create SMI counters
296
+ try:
297
+ for row in connection.execute(
298
+ "select deviceId, monitorType, start/1000.0, value from rocpd_monitor %s"
299
+ % (rangeStringMonitor)
300
+ ):
301
+ outfile.write(
302
+ ',{"pid":"%s","name":"%s","ph":"C","ts":%s,"args":{"%s":%s}}\n'
303
+ % (row[0], row[1], row[2], row[1], row[3])
304
+ )
305
+ # Output the endpoints of the last range
306
+ for row in connection.execute(
307
+ "select distinct deviceId, monitorType, max(end)/1000.0, value from rocpd_monitor %s group by deviceId, monitorType"
308
+ % (rangeStringMonitor)
309
+ ):
310
+ outfile.write(
311
+ ',{"pid":"%s","name":"%s","ph":"C","ts":%s,"args":{"%s":%s}}\n'
312
+ % (row[0], row[1], row[2], row[1], row[3])
313
+ )
314
+ except:
315
+ print("Did not find SMI data")
316
+
317
+ # Create the (global) memory counter
318
+ """
319
+ sizes = {} # address -> size
320
+ totalSize = 0
321
+ exp = re.compile("^ptr\((.*)\)\s+size\((.*)\)$")
322
+ exp2 = re.compile("^ptr\((.*)\)$")
323
+ for row in connection.execute("SELECT rocpd_api.end/1000.0 as ts, B.string, '1' FROM rocpd_api INNER JOIN rocpd_string A ON A.id=rocpd_api.apiName_id INNER JOIN rocpd_string B ON B.id=rocpd_api.args_id WHERE A.string='hipFree' UNION ALL SELECT rocpd_api.start/1000.0, B.string, '0' FROM rocpd_api INNER JOIN rocpd_string A ON A.id=rocpd_api.apiName_id INNER JOIN rocpd_string B ON B.id=rocpd_api.args_id WHERE A.string='hipMalloc' ORDER BY ts asc"):
324
+ try:
325
+ if row[2] == '0': #malloc
326
+ m = exp.match(row[1])
327
+ if m:
328
+ size = int(m.group(2), 16)
329
+ totalSize = totalSize + size
330
+ sizes[m.group(1)] = size
331
+ outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(row[0],totalSize))
332
+ else: #free
333
+ m = exp2.match(row[1])
334
+ if m:
335
+ try: # Sometimes free addresses are not valid or listed
336
+ size = sizes[m.group(1)]
337
+ sizes[m.group(1)] = 0
338
+ totalSize = totalSize - size;
339
+ outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(row[0],totalSize))
340
+ except KeyError:
341
+ pass
342
+ except ValueError:
343
+ outfile.write("")
344
+ if T_end > 0:
345
+ outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(T_end,totalSize))
346
+ """
347
+
348
+ # Create "faux calling stack frame" on gpu ops traces
349
+ stacks = {} # Call stacks built from UserMarker entries. Key is 'pid,tid'
350
+ currentFrame = {} # "Current GPU frame" (id, name, start, end). Key is 'pid,tid'
351
+
352
class GpuFrame:
    """Mutable record for one "faux" GPU stack frame derived from a UserMarker.

    The surrounding code fills the fields in after construction:

    * ``id`` / ``name`` -- copied from the ``(timestamp, label)`` tuple of
      the UserMarker frame on top of the per-thread stack.
    * ``start`` / ``end`` -- span covered by the GPU ops merged into this
      frame so far.
    * ``gpus`` -- list of ``(gpuId, queueId)`` pairs the frame is emitted on.
    * ``totalOps`` -- number of ops merged into the frame.
    """

    def __init__(self):
        # Scalars start zeroed/empty; the caller overwrites them right away.
        self.id, self.name, self.start, self.end = 0, "", 0, 0
        # A fresh list per instance so frames never share destinations.
        self.gpus, self.totalOps = [], 0
360
+
361
+ # FIXME: include 'start' (in ns) so we can ORDER BY it and break ties?
362
+ for row in connection.execute(
363
+ "SELECT '0', start/1000.0, pid, tid, B.string as label, '','','', '' from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id AND A.string = 'UserMarker' INNER JOIN rocpd_string B on B.id = rocpd_api.args_id AND rocpd_api.start/1000.0 != rocpd_api.end/1000.0 %s UNION ALL SELECT '1', end/1000.0, pid, tid, B.string as label, '','','', '' from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id AND A.string = 'UserMarker' INNER JOIN rocpd_string B on B.id = rocpd_api.args_id AND rocpd_api.start/1000.0 != rocpd_api.end/1000.0 %s UNION ALL SELECT '2', rocpd_api.start/1000.0, pid, tid, '' as label, gpuId, queueId, rocpd_op.start/1000.0, rocpd_op.end/1000.0 from rocpd_api_ops INNER JOIN rocpd_api ON rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op ON rocpd_api_ops.op_id = rocpd_op.id %s ORDER BY start/1000.0 asc"
364
+ % (rangeStringApi, rangeStringApi, rangeStringApi)
365
+ ):
366
+ try:
367
+ key = (row[2], row[3]) # Key is 'pid,tid'
368
+ if row[0] == "0": # Frame start
369
+ if key not in stacks:
370
+ stacks[key] = []
371
+ stack = stacks[key].append((row[1], row[4]))
372
+ # print(f"0: new api frame: pid_tid={key} -> stack={stacks}")
373
+
374
+ elif row[0] == "1": # Frame end
375
+ completed = stacks[key].pop()
376
+ # print(f"1: end api frame: pid_tid={key} -> stack={stacks}")
377
+
378
+ elif row[0] == "2": # API + Op
379
+ if key in stacks and len(stacks[key]) > 0:
380
+ frame = stacks[key][-1]
381
+ # print(f"2: Op on {frame} ({len(stacks[key])})")
382
+ gpuFrame = None
383
+ if key not in currentFrame: # First op under the current api frame
384
+ gpuFrame = GpuFrame()
385
+ gpuFrame.id = frame[0]
386
+ gpuFrame.name = frame[1]
387
+ gpuFrame.start = row[7]
388
+ gpuFrame.end = row[8]
389
+ gpuFrame.gpus.append((row[5], row[6]))
390
+ gpuFrame.totalOps = 1
391
+ # print(f"2a: new frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
392
+ else:
393
+ gpuFrame = currentFrame[key]
394
+ # Another op under the same frame -> union them (but only if they are butt together)
395
+ if (
396
+ gpuFrame.id == frame[0]
397
+ and gpuFrame.name == frame[1]
398
+ and (
399
+ abs(row[7] - gpuFrame.end) < 200
400
+ or abs(gpuFrame.start - row[8]) < 200
401
+ )
402
+ ):
403
+ # if gpuFrame.id == frame[0] and gpuFrame.name == frame[1]: # Another op under the same frame -> union them
404
+ # if False: # Turn off frame joining
405
+ if row[7] < gpuFrame.start:
406
+ gpuFrame.start = row[7]
407
+ if row[8] > gpuFrame.end:
408
+ gpuFrame.end = row[8]
409
+ if (row[5], row[6]) not in gpuFrame.gpus:
410
+ gpuFrame.gpus.append((row[5], row[6]))
411
+ gpuFrame.totalOps = gpuFrame.totalOps + 1
412
+ # print(f"2c: union frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
413
+
414
+ else: # This is a new frame - dump the last and make new
415
+ gpuFrame = currentFrame[key]
416
+ for dest in gpuFrame.gpus:
417
+ # print(f"2: OUTPUT: dest={dest} time={gpuFrame.start} -> {gpuFrame.end} Duration={gpuFrame.end - gpuFrame.start} TotalOps={gpuFrame.totalOps}")
418
+ outfile.write(
419
+ ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
420
+ % (
421
+ dest[0],
422
+ dest[1],
423
+ gpuFrame.name.replace('"', ""),
424
+ gpuFrame.start - 1,
425
+ gpuFrame.end - gpuFrame.start + 1,
426
+ f"UserMarker frame: {gpuFrame.totalOps} ops",
427
+ )
428
+ )
429
+ currentFrame.pop(key)
430
+
431
+ # make the first op under the new frame
432
+ gpuFrame = GpuFrame()
433
+ gpuFrame.id = frame[0]
434
+ gpuFrame.name = frame[1]
435
+ gpuFrame.start = row[7]
436
+ gpuFrame.end = row[8]
437
+ gpuFrame.gpus.append((row[5], row[6]))
438
+ gpuFrame.totalOps = 1
439
+ # print(f"2b: new frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
440
+
441
+ currentFrame[key] = gpuFrame
442
+
443
+ except ValueError:
444
+ outfile.write("")
445
+
446
+ outfile.write("]\n")
447
+
448
+ if format == "object":
449
+ outfile.write("} \n")
450
+
451
+ outfile.close()
452
+ connection.close()
@@ -0,0 +1,71 @@
1
+ import logging
2
+ from typing import Any, Dict, List
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ import triton
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
def execute():
    """Run the local benchmarks on this rank and report slow ranks on rank 0.

    Every rank measures each benchmark listed in ``_BENCH_NAMES`` and the
    per-rank result dicts are collected with ``dist.gather_object``.  Only
    rank 0 supplies a receive list and afterwards analyses the gathered
    metrics; the other ranks just contribute their own measurements.

    NOTE(review): presumably requires ``torch.distributed`` to be initialised
    before this is called -- confirm against the caller.
    """
    is_root = dist.get_rank() == 0

    if is_root:
        logger.info("[slow_rank_detector] Start benchmarking...")

    # bench name -> measurement taken on this rank
    local_metrics = {}
    for bench_name in _BENCH_NAMES:
        local_metrics[bench_name] = _compute_local_metric(bench_name)

    # One slot per rank; only the root actually passes it to the collective.
    gathered = [None] * dist.get_world_size()
    dist.gather_object(local_metrics, gathered if is_root else None)

    if is_root:
        _analyze_metrics(gathered)
24
+
25
+
26
+ class _GemmExecutor:
27
+ def __init__(self):
28
+ self.lhs = torch.randn((8192, 8192), dtype=torch.bfloat16, device="cuda")
29
+ self.rhs = torch.randn((8192, 8192), dtype=torch.bfloat16, device="cuda")
30
+
31
+ def __call__(self):
32
+ self.lhs @ self.rhs
33
+
34
+
35
+ class _ElementwiseExecutor:
36
+ def __init__(self):
37
+ self.value = torch.randint(
38
+ 0, 10000, (128 * 1024**2,), dtype=torch.int32, device="cuda"
39
+ )
40
+
41
+ def __call__(self):
42
+ self.value += 1
43
+
44
+
45
# Registry: benchmark name -> executor class whose instances are the
# zero-argument callables that get timed.
_EXECUTOR_CLS_OF_BENCH = dict(
    gemm=_GemmExecutor,
    elementwise=_ElementwiseExecutor,
)

# Benchmark names in registry insertion order.
_BENCH_NAMES = [*_EXECUTOR_CLS_OF_BENCH]
51
+
52
+
53
def _compute_local_metric(bench_name):
    """Time the benchmark registered under ``bench_name`` on this rank.

    A fresh executor is built from ``_EXECUTOR_CLS_OF_BENCH`` and handed to
    ``triton.testing.do_bench_cudagraph`` with ``return_mode="mean"`` and
    ``rep=20``; whatever that call returns is passed straight through.

    NOTE(review): the result is presumably a time in milliseconds (the
    original local was named ``ms``) -- confirm against the Triton docs.

    Raises:
        KeyError: if ``bench_name`` is not a key of ``_EXECUTOR_CLS_OF_BENCH``.
    """
    executor_cls = _EXECUTOR_CLS_OF_BENCH[bench_name]
    bench_fn = executor_cls()
    return triton.testing.do_bench_cudagraph(bench_fn, return_mode="mean", rep=20)
57
+
58
+
59
def _analyze_metrics(all_metrics: List[Dict[str, Any]]):
    """Log per-benchmark relative speeds and warn when a rank lags behind.

    Args:
        all_metrics: one dict per rank, each mapping a benchmark name from
            ``_BENCH_NAMES`` to that rank's measured time.

    For every benchmark the times are inverted into speeds and normalised by
    the fastest rank.  A warning is logged when the slowest rank's relative
    speed is below 0.9.
    """
    for bench_name in _BENCH_NAMES:
        per_rank_times = [rank_metrics[bench_name] for rank_metrics in all_metrics]

        # NOTE: the local names below appear verbatim in the log line through
        # the f-string ``=`` specifiers, so they must not be renamed.
        time_of_rank = torch.tensor(per_rank_times)
        speed_of_rank = 1 / time_of_rank
        rel_speed_of_rank = speed_of_rank / speed_of_rank.max()
        slowest_rel_speed = rel_speed_of_rank.min().item()

        summary = f"[slow_rank_detector] {bench_name=} {slowest_rel_speed=} {rel_speed_of_rank=} {time_of_rank=}"
        logger.info(summary)

        # Written as a negated "<" so a NaN still skips the warning.
        if not slowest_rel_speed < 0.9:
            continue
        logger.warning(
            "[slow_rank_detector] Some ranks are too slow compared with others"
        )
+ )
sglang/srt/warmup.py CHANGED
@@ -1,20 +1,24 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
- from typing import List
4
+ from typing import TYPE_CHECKING, List
3
5
 
4
6
  import numpy as np
5
7
  import tqdm
6
8
 
7
9
  from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
8
10
  from sglang.srt.managers.io_struct import GenerateReqInput
9
- from sglang.srt.managers.tokenizer_manager import TokenizerManager
11
+
12
+ if TYPE_CHECKING:
13
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
10
14
 
11
15
  logger = logging.getLogger(__file__)
12
16
 
13
17
  _warmup_registry = {}
14
18
 
15
19
 
16
- def warmup(name: str) -> callable:
17
- def decorator(fn: callable):
20
+ def warmup(name: str):
21
+ def decorator(fn):
18
22
  _warmup_registry[name] = fn
19
23
  return fn
20
24